{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE ExistentialQuantification #-}
module Test.Divide (
   testsVar,
   testsVarAny,
   SquareMatrix(SquareMatrix),
   determinant,
   ) where

import qualified Test.Generator as Gen
import qualified Test.Utility as Util
import Test.Generator ((<#\#>), (<#/#>))
import Test.Utility (Tagged, approxMatrix)

import qualified Numeric.LAPACK.Linear.LowerUpper as LU
import qualified Numeric.LAPACK.Matrix.Square as Square
import qualified Numeric.LAPACK.Matrix.Special as Special
import qualified Numeric.LAPACK.Matrix.Array as ArrMatrix
import qualified Numeric.LAPACK.Matrix.Shape.Box as Box
import qualified Numeric.LAPACK.Matrix as Matrix
import qualified Numeric.LAPACK.Scalar as Scalar
import Numeric.LAPACK.Matrix.Array (ArrayMatrix)
import Numeric.LAPACK.Matrix (Matrix, ShapeInt, (##/#), (##*#), (#*##), (#\##))
import Numeric.LAPACK.Scalar (RealOf)

import qualified Numeric.Netlib.Class as Class

import qualified Data.Array.Comfort.Shape as Shape

import Control.Applicative ((<$>))

import Data.Tuple.HT (mapSnd)

import qualified Test.QuickCheck as QC



-- | Property: the structure-aware 'Matrix.determinant' of @a@ agrees with
-- 'Square.determinant' computed on the same matrix after 'Matrix.toSquare'.
--
-- The tolerance comes from 'Scalar.selectReal': 1e-1 in one precision and
-- 1e-5 in the other.
determinant ::
   (ArrMatrix.Determinant shape, ArrMatrix.SquareShape shape,
    Box.HeightOf shape ~ ShapeInt, Box.WidthOf shape ~ ShapeInt,
    Class.Floating a, RealOf a ~ ar, Class.Real ar) =>
   ArrayMatrix shape a -> Bool
determinant a =
   Util.approx tolerance viaSpecial viaSquare
   where
      tolerance = Scalar.selectReal 1e-1 1e-5
      -- determinant computed by the matrix-type-specific implementation
      viaSpecial = Matrix.determinant a
      -- reference value: convert to a plain square matrix first
      viaSquare = Square.determinant (Matrix.toSquare a)


-- | Property: solving with 'Matrix.solve' under transposition mode @trans@ and
-- multiplying back with 'Matrix.multiplySquare' under the same mode
-- reproduces the right-hand side @b@ within tolerance 1e-2.
multiplySolveTrans ::
   (Shape.C size, Eq size, Class.Floating a, RealOf a ~ ar, Class.Real ar) =>
   LU.Transposition ->
   (SquareMatrix size a, Matrix.General size ShapeInt a) -> Bool
multiplySolveTrans trans (SquareMatrix a, b) =
   let solution = Matrix.solve trans a b
   in  approxMatrix 1e-2 b (Matrix.multiplySquare trans a solution)

-- | Property: left-dividing @b@ by @a@ (operator '#\##') and then multiplying
-- by @a@ from the left (operator '#*##') gives back @b@ within tolerance 1e-2.
multiplySolveRight ::
   (Shape.C size, Eq size, Class.Floating a, RealOf a ~ ar, Class.Real ar) =>
   (SquareMatrix size a, Matrix.General size ShapeInt a) -> Bool
multiplySolveRight (SquareMatrix a, b) =
   let solution = a #\## b
   in  approxMatrix 1e-2 b (a #*## solution)

-- | Property: right-dividing @b@ by @a@ (operator '##/#') and then multiplying
-- by @a@ from the right (operator '##*#') gives back @b@ within tolerance 1e-2.
multiplySolveLeft ::
   (Shape.C size, Eq size, Class.Floating a, RealOf a ~ ar, Class.Real ar) =>
   (Matrix.General ShapeInt size a, SquareMatrix size a) -> Bool
multiplySolveLeft (b, SquareMatrix a) =
   let solution = b ##/# a
   in  approxMatrix 1e-2 b (solution ##*# a)

-- | Property: multiplying @b@ from the left by the 'Special.Inverse' wrapper
-- of @a@, then by @a@ itself, gives back @b@ within tolerance 1e-2.
multiplyInverseRight ::
   (Shape.C size, Eq size, Class.Floating a, RealOf a ~ ar, Class.Real ar) =>
   (SquareMatrix size a, Matrix.General size ShapeInt a) -> Bool
multiplyInverseRight (SquareMatrix a, b) =
   let inverseTimesB = Special.Inverse a #*## b
   in  approxMatrix 1e-2 b (a #*## inverseTimesB)

-- | Property: multiplying @b@ from the right by the 'Special.Inverse' wrapper
-- of @a@, then by @a@ itself, gives back @b@ within tolerance 1e-2.
multiplyInverseLeft ::
   (Shape.C size, Eq size, Class.Floating a, RealOf a ~ ar, Class.Real ar) =>
   (Matrix.General ShapeInt size a, SquareMatrix size a) -> Bool
multiplyInverseLeft (b, SquareMatrix a) =
   let bTimesInverse = b ##*# Special.Inverse a
   in  approxMatrix 1e-2 b (bTimesInverse ##*# a)


-- | Check a property against every value drawn from the generator @gen@.
--
-- The generator is turned into something 'Util.checkForAll' accepts by
-- running it with the fixed arguments 3 and 5.
-- NOTE(review): the meaning of 3 and 5 is defined by 'Gen.run' in
-- Test.Generator (presumably size bounds); confirm there.
checkForAll ::
   (Show a, QC.Testable test) =>
   Gen.T dim tag a -> (a -> test) -> Tagged tag QC.Property
checkForAll gen prop =
   Util.checkForAll (Gen.run gen 3 5) prop


-- | A square matrix whose concrete matrix type @typ@ is hidden.
--
-- The constructor packs the dictionaries needed by the properties in this
-- module ('Matrix.Solve', 'Matrix.MultiplySquare'). It also fixes both
-- height and width to @size@, so a single property can be run against many
-- matrix structures.
--
-- The @Matrix typ a ~ matrix@ / @Show matrix@ pair stores a 'Show'
-- dictionary for the hidden matrix type.
data SquareMatrix size a where
   SquareMatrix ::
      (Matrix.Solve typ, Matrix.MultiplySquare typ,
       Matrix typ a ~ matrix, Show matrix,
       Matrix.HeightOf typ ~ size, Matrix.WidthOf typ ~ size) =>
      Matrix typ a -> SquareMatrix size a

-- | Show the wrapped matrix through its own 'Show' instance.
--
-- This forwards 'showsPrec' rather than defining only 'show'. With only
-- 'show' defined, the default 'showsPrec' ignores the precedence argument,
-- so the wrapper is not transparent when shown inside another constructor.
-- Forwarding the precedence lets the inner instance decide on parentheses.
instance Show (SquareMatrix size a) where
   showsPrec prec (SquareMatrix m) = showsPrec prec m


-- | Named properties that work for any hidden square-matrix type.
--
-- Each entry maps a generator of 'SquareMatrix' values to a tagged QuickCheck
-- property. The generator is combined with 'Gen.matrix' to produce the
-- (square, general) or (general, square) pair that the property expects.
-- NOTE(review): '<#\#>' and '<#/#>' are assumed to tie the dimensions of the
-- two generated matrices together; confirm in Test.Generator.
testsVarAny ::
   (Show a, Class.Floating a, Eq a, RealOf a ~ ar, Class.Real ar) =>
   [(String, Gen.MatrixInt a (SquareMatrix ShapeInt a) -> Tagged a QC.Property)]
testsVarAny =
   [ ("multiplySolveTrans", \gen ->
         -- the extra generated value feeds the LU.Transposition argument
         Gen.withExtra checkForAll QC.arbitraryBoundedEnum
            ((,) <$> gen <#\#> Gen.matrix) multiplySolveTrans)
   , ("multiplySolveRight", \gen ->
         checkForAll ((,) <$> gen <#\#> Gen.matrix) multiplySolveRight)
   , ("multiplySolveLeft", \gen ->
         checkForAll ((,) <$> Gen.matrix <#/#> gen) multiplySolveLeft)
   , ("multiplyInverseRight", \gen ->
         checkForAll ((,) <$> gen <#\#> Gen.matrix) multiplyInverseRight)
   , ("multiplyInverseLeft", \gen ->
         checkForAll ((,) <$> Gen.matrix <#/#> gen) multiplyInverseLeft)
   ]

-- | Specialise every property of 'testsVarAny' to one concrete matrix type.
--
-- Each matrix produced by @gen@ is wrapped in 'SquareMatrix'. That erases its
-- type while keeping the Solve/MultiplySquare/Show dictionaries the
-- properties need. The property names are passed through unchanged.
testsVar ::
   (Matrix.Solve typ, Matrix.MultiplySquare typ,
    Matrix typ a ~ matrix, Show matrix,
    Matrix.HeightOf typ ~ ShapeInt, Matrix.WidthOf typ ~ ShapeInt,
    Show a, Class.Floating a, Eq a, RealOf a ~ ar, Class.Real ar) =>
   Gen.MatrixInt a (Matrix typ a) -> [(String, Tagged a QC.Property)]
testsVar gen =
   [ (name, prop wrapped) | (name, prop) <- testsVarAny ]
   where
      wrapped = SquareMatrix <$> gen