{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE Rank2Types #-}
{-# LANGUAGE EmptyDataDecls #-}
module Test.Generator where

import qualified Test.Logic as Logic
import qualified Test.Utility as Util
import Test.Logic (Dim, MatchMode(DontForceMatch,ForceMatch), (=!=), (<!=))
import Test.Utility (Match)

import qualified UniqueLogic.ST.TF.System.Simple as Sys

import qualified Numeric.LAPACK.Matrix.Shape as MatrixShape
import qualified Numeric.LAPACK.Matrix.Triangular as Triangular
import qualified Numeric.LAPACK.Matrix.Square as Square
import qualified Numeric.LAPACK.Matrix as Matrix
import qualified Numeric.LAPACK.Vector as Vector
import Numeric.LAPACK.Matrix.Hermitian (Hermitian)
import Numeric.LAPACK.Matrix (ZeroInt, zeroInt)
import Numeric.LAPACK.Scalar (RealOf, fromReal, one)

import qualified Numeric.Netlib.Class as Class

import qualified Data.Array.Comfort.Shape as Shape
import Data.Array.Comfort.Storable (Array)
import Data.Array.Comfort.Shape ((:+:))

import qualified Control.Monad.Trans.RWS as MRWS
import qualified Control.Monad.Trans.Class as MT
import qualified Control.Applicative.HT as AppHT
import Control.Monad.Trans.RWS (RWST, evalRWST)
import Control.Applicative (liftA2, (<*>), (<$>))

import qualified Data.Ref as Ref
import Data.Semigroup ((<>))
import Data.Monoid (Monoid, mempty)
import Data.Tuple.HT (swap, mapFst, mapSnd)

import qualified Test.QuickCheck as QC



{- |
@Cons generator@ with @generator maxElem@.
@generator@ constructs an array and maintains relations between the dimensions.
Dimensions will be chosen arbitrarily from the range @(0,maxDim)@.
Elements are chosen from the range @(-maxElem,maxElem)@.

I moved the 's' tag to within the 'Cons' constructor
and furthermore defined 'TaggedVariables' to strip the 's' tag
from the Variables in 'dim'.
This way, we can easily define 'checkForAll' in the test modules.
Otherwise there would not be a way to quantify 'dim' while containing 's' tags.
That is, we would have to reset 'dim' to () before every call to 'checkForAll'.

The @tag@ parameter is a phantom: it does not occur in the representation.
-}
newtype T tag dim array =
   Cons {
      decons :: forall s.
         RWST Integer (Logic.System s) () (Logic.M s)
            (TaggedVariables s dim, Logic.M s array)
         -- ^ Reader environment: @maxElem@.
         -- Writer: the accumulated dimension constraints.
         -- Result: the dimension variables and an action that builds the
         -- array; 'run' executes it only after solving the constraints.
   }

-- | Type-level marker for one dimension variable of type @dim@ inside the
-- @dim@ index of 'T'. It has no values;
-- 'TaggedVariables' maps it to @'Logic.Variable' s dim@.
data Variable dim

-- | Replace every 'Variable' marker in a dimension description by the
-- @s@-tagged 'Logic.Variable'.
-- @()@ stays @()@ and pairs are translated component-wise.
type family TaggedVariables s tuple
type instance TaggedVariables s (Variable dim) = Logic.Variable s dim
type instance TaggedVariables s () = ()
type instance TaggedVariables s (a,b) =
                  (TaggedVariables s a, TaggedVariables s b)

-- | Map over the generated array; dimension variables and constraints
-- are left unchanged.
instance Functor (T tag dim) where
   fmap f (Cons gen) =
      Cons $ do
         ~(dims, query) <- gen
         return (dims, fmap f query)

-- | Create a new variable via 'Sys.globalVariable' in the underlying monad
-- and lift it into the 'RWST' layer (no constraint is written).
newVariable :: (Ref.C m, Monoid w) => RWST r w () m (Sys.Variable m a)
newVariable = MT.lift Sys.globalVariable

-- | Turn a generator into a tagged QuickCheck generator.
--
-- A 'MatchMode' is chosen uniformly from 'DontForceMatch' and 'ForceMatch'.
-- The generator is then run with @maxElem@ as its reader environment,
-- the collected constraint system is solved, and finally the array query
-- is executed. @maxDim@ is passed on to 'Logic.runSTInGen'
-- (presumably as the upper bound for dimensions).
run ::
   T tag dim array -> Integer -> Int ->
   Util.TaggedGen tag (array, Match)
run gen maxElem maxDim =
   Util.Tagged $ do
      matchMode <- QC.elements [DontForceMatch, ForceMatch]
      Logic.runSTInGen
         (do ((_dims, arrayQuery), constraints) <-
                evalRWST (decons gen) maxElem ()
             Logic.solve constraints
             arrayQuery)
         maxDim matchMode

-- | Pair every generated value @b@ with an extra value drawn from @genA@,
-- then let @checkForAll@ run the curried @test@ on both.
withExtra ::
   (T tag dim (a,b) -> ((a,b) -> c) -> io) ->
   QC.Gen a -> T tag dim b -> (a -> b -> c) -> io
withExtra checkForAll genA genB test =
   let paired = mapGen (\_maxElem b -> fmap (\a -> (a, b)) genA) genB
   in  checkForAll paired (\ ~(a, b) -> test a b)


-- | Post-process the generated value with a QuickCheck generator.
-- The generator receives @maxElem@ (the reader environment) and the
-- value produced so far; dimensions are left unchanged.
mapGen ::
   (Integer -> a -> QC.Gen b) ->
   T tag dim a -> T tag dim b
mapGen f (Cons gen) =
   Cons $ do
      maxElem <- MRWS.ask
      ~(dims, query) <- gen
      return (dims, query >>= Logic.liftGen . f maxElem)

-- | Like 'mapGen', but the post-processing generator also receives
-- @maxDim@, read from the environment of 'Logic.M'
-- (the match mode stored next to it is ignored).
mapGenDim ::
   (Integer -> Int -> a -> QC.Gen b) ->
   T tag dim a -> T tag dim b
mapGenDim f (Cons gen) =
   Cons $ do
      maxElem <- MRWS.ask
      (maxDim, _matchMode) <- MT.lift $ Logic.M MRWS.ask
      ~(dims, query) <- gen
      return (dims, query >>= Logic.liftGen . f maxElem maxDim)


-- | Combine two generators applicatively.
--
-- @combineDim@ computes the result's dimension variables from those of
-- both operands. It also returns the constraints relating them, which are
-- added to the system through the writer.
combine ::
   (forall s.
    TaggedVariables s dimF -> TaggedVariables s dimA ->
    (TaggedVariables s dimB, Logic.System s)) ->
   T tag dimF (a -> b) ->
   T tag dimA a ->
   T tag dimB b
combine combineDim (Cons genF) (Cons genA) =
   Cons $ do
      (dimsF, queryF) <- genF
      (dimsA, queryA) <- genA
      let combined = combineDim dimsF dimsA
      MRWS.tell (snd combined)
      return (fst combined, queryF <*> queryA)


-- | Generator for a value that carries no dimension variables.
type Scalar tag = T tag ()

-- | A single element, generated by 'Util.genElement' with the @maxElem@
-- bound taken from the reader environment.
scalar :: (Class.Floating a) => Scalar a a
scalar =
   Cons
      (MRWS.asks
         (\maxElem -> ((), Logic.liftGen (Util.genElement maxElem))))

-- | Combine two vector generators whose sizes are constrained to be equal.
-- The result carries no dimension variables.
(<.*.>) ::
   (Dim size, Eq size) =>
   Vector tag size (a -> b) ->
   Vector tag size a ->
   Scalar tag b
f <.*.> a = combine (\sizeF sizeA -> ((), sizeF =!= sizeA)) f a


-- | Query a dimension variable and wrap its value as a 'ZeroInt' shape.
queryZeroInt :: Logic.Variable s Int -> Logic.M s ZeroInt
queryZeroInt = fmap zeroInt . Logic.query

-- | Generator whose dimension is a single size variable.
type Vector tag size = T tag (Variable size)

-- | Allocate a new size variable and yield its solved value as a
-- 'ZeroInt' shape.
vectorDim :: Vector a Int ZeroInt
vectorDim =
   Cons $ (\size -> (size, queryZeroInt size)) <$> newVariable

-- | Vector over the shape from 'vectorDim', filled by 'Util.genArray'
-- using the @maxElem@ bound.
vector :: (Class.Floating a) => Vector a Int (Vector.Vector ZeroInt a)
vector = flip mapGen vectorDim $ \maxElem size -> Util.genArray maxElem size

-- | Like 'vector', but with elements of the real type @'RealOf' a@.
vectorReal ::
   (Class.Floating a, RealOf a ~ ar, Class.Real ar) =>
   Vector a Int (Vector.Vector ZeroInt ar)
vectorReal =
   flip mapGen vectorDim $ \maxElem size -> Util.genArray maxElem size

-- | Vector-times-matrix combination: the vector size must equal the matrix
-- height, and the result has the matrix width.
(<.*|>) ::
   (Dim height, Eq height) =>
   Vector tag height (a -> b) ->
   Matrix tag height width a ->
   Vector tag width b
v <.*|> m =
   combine (\size (height, width) -> (width, size =!= height)) v m

-- | Matrix-times-vector combination: the matrix width must equal the vector
-- size, and the result has the matrix height.
(<|*.>) ::
   (Dim width, Eq width) =>
   Matrix tag height width (a -> b) ->
   Vector tag width a ->
   Vector tag height b
m <|*.> v =
   combine (\(height, width) size -> (height, width =!= size)) m v

-- | Element-wise combination of two vectors of equal size.
(<.=.>) ::
   (Dim size, Eq size) =>
   Vector tag size (a -> b) ->
   Vector tag size a ->
   Vector tag size b
f <.=.> a =
   combine (\sizeLeft sizeRight -> (sizeLeft, sizeLeft =!= sizeRight)) f a


-- | Generator whose dimensions are a (height, width) pair of variables.
type Matrix tag height width = T tag (Variable height, Variable width)

-- | Two independent dimension variables (height first, then width),
-- yielding their solved values as 'ZeroInt' shapes.
matrixDims :: Matrix a Int Int (ZeroInt, ZeroInt)
matrixDims =
   Cons $ do
      height <- newVariable
      width <- newVariable
      return
         ((height, width),
          liftA2 (,) (queryZeroInt height) (queryZeroInt width))

-- | General matrix of unconstrained height and width, in a random storage
-- order.
matrix ::
   (Class.Floating a) => Matrix a Int Int (Matrix.General ZeroInt ZeroInt a)
matrix =
   mapGen
      (\maxElem dims ->
         Util.genOrder >>= \order ->
         Util.genArray maxElem (uncurry (MatrixShape.general order) dims))
      matrixDims


-- | One size variable used as both height and width.
squareDim :: Matrix a Int Int ZeroInt
squareDim =
   Cons $ (\size -> ((size, size), queryZeroInt size)) <$> newVariable

-- | Square matrix whose shape is built by @shape@ from a random storage
-- order and the square size.
squareShaped ::
   (Shape.C sh, Class.Floating a) =>
   (MatrixShape.Order -> ZeroInt -> sh) -> Matrix a Int Int (Array sh a)
squareShaped shape =
   mapGen
      (\maxElem size ->
         Util.genOrder >>= \order ->
         Util.genArray maxElem (shape order size))
      squareDim

-- | General square matrix.
square :: (Class.Floating a) => Matrix a Int Int (Square.Square ZeroInt a)
square = squareShaped (\order size -> MatrixShape.square order size)

-- | Square matrix satisfying @cond@.
-- The storage order is chosen once; only the array is regenerated by
-- 'QC.suchThat' until @cond@ holds.
squareCond ::
   (Class.Floating a) =>
   (Square.Square ZeroInt a -> Bool) ->
   Matrix a Int Int (Square.Square ZeroInt a)
squareCond cond =
   mapGen
      (\maxElem size ->
         Util.genOrder >>= \order ->
         QC.suchThat
            (Util.genArray maxElem (MatrixShape.square order size))
            cond)
      squareDim

-- | Square matrix accepted by 'Util.invertible'.
invertible ::
   (Class.Floating a, RealOf a ~ ar, Class.Real ar) =>
   Matrix a Int Int (Square.Square ZeroInt a)
invertible = squareCond (\m -> Util.invertible m)

-- | Square diagonal matrix.
diagonal ::
   (Class.Floating a) => Matrix a Int Int (Triangular.Diagonal ZeroInt a)
diagonal = squareShaped (\order size -> MatrixShape.diagonal order size)

-- | Identity matrix of an unconstrained square size, in a random storage
-- order. The @maxElem@ bound is not used.
identity ::
   (MatrixShape.Content lo, MatrixShape.Content up, Class.Floating a) =>
   Matrix a Int Int (Triangular.Triangular lo MatrixShape.Unit up ZeroInt a)
identity =
   mapGen
      (\_maxElem size ->
         (\order -> Triangular.identity order size) <$> Util.genOrder)
      squareDim

-- | Triangular matrix (with automatic diagonal and uplo) satisfying @cond@.
-- The storage order is chosen once; only the array is regenerated by
-- 'QC.suchThat'.
triangularCond ::
   (MatrixShape.Content up, MatrixShape.Content lo, MatrixShape.TriDiag diag,
    Class.Floating a) =>
   (Triangular.Triangular lo diag up ZeroInt a -> Bool) ->
   Matrix a Int Int (Triangular.Triangular lo diag up ZeroInt a)
triangularCond cond =
   mapGen
      (\maxElem size ->
         Util.genOrder >>= \order ->
         QC.suchThat
            (genTriangularArray maxElem $
             MatrixShape.Triangular
                MatrixShape.autoDiag MatrixShape.autoUplo order size)
            cond)
      squareDim

-- | Triangular matrix without an extra condition.
triangular ::
   (MatrixShape.Content up, MatrixShape.Content lo, MatrixShape.TriDiag diag,
    Class.Floating a) =>
   Matrix a Int Int (Triangular.Triangular lo diag up ZeroInt a)
triangular = triangularCond (\_ -> True)


-- | Shape-to-generator function with the @diag@ index as the last type
-- parameter. This lets 'MatrixShape.switchTriDiag' choose between the
-- 'MatrixShape.Unit' and 'MatrixShape.Arbitrary' cases
-- (see 'genTriangularArray').
newtype GenTriangularDiag lo up a diag =
   GenTriangularDiag {
      runGenTriangularDiag ::
         MatrixShape.Triangular lo diag up ZeroInt ->
         QC.Gen (Triangular.Triangular lo diag up ZeroInt a)
   }

-- | Generate a triangular array for the given shape.
-- For a unit diagonal, the diagonal elements are fixed to 'one' via
-- 'Util.genArrayExtraDiag'. Otherwise all elements come from
-- 'Util.genArray'.
genTriangularArray ::
   (MatrixShape.Content lo, MatrixShape.Content up, MatrixShape.TriDiag diag,
    Class.Floating a) =>
   Integer ->
   MatrixShape.Triangular lo diag up ZeroInt ->
   QC.Gen (Triangular.Triangular lo diag up ZeroInt a)
genTriangularArray maxElem =
   runGenTriangularDiag $ MatrixShape.switchTriDiag unitDiag arbitraryDiag
   where
      unitDiag =
         GenTriangularDiag $ \shape ->
            Util.genArrayExtraDiag maxElem shape (\_ -> return one)
      arbitraryDiag = GenTriangularDiag (Util.genArray maxElem)


-- | Height and width variables, constrained so that width <= height.
tallDims :: Matrix a Int Int (ZeroInt, ZeroInt)
tallDims =
   Cons $ do
      dims@(height, width) <- liftA2 (,) newVariable newVariable
      MRWS.tell $ width <!= height
      return (dims, AppHT.mapPair (queryZeroInt, queryZeroInt) dims)

-- | Tall matrix over the dimensions from 'tallDims', in a random storage
-- order.
tall ::
   (Class.Floating a) =>
   Matrix a Int Int (Matrix.Tall ZeroInt ZeroInt a)
tall =
   mapGen
      (\maxElem dims ->
         Util.genOrder >>= \order ->
         Util.genArray maxElem (uncurry (MatrixShape.tall order) dims))
      tallDims

-- | Tall matrix accepted by 'Util.fullRankTall'.
-- The order is chosen once; only the array is regenerated by 'QC.suchThat'.
fullRankTall ::
   (Class.Floating a, RealOf a ~ ar, Class.Real ar) =>
   Matrix a Int Int (Matrix.Tall ZeroInt ZeroInt a)
fullRankTall =
   mapGen
      (\maxElem dims ->
         Util.genOrder >>= \order ->
         QC.suchThat
            (Util.genArray maxElem (uncurry (MatrixShape.tall order) dims))
            Util.fullRankTall)
      tallDims


-- | Wide matrix, obtained by transposing a 'tall' one together with its
-- dimension variables.
wide ::
   (Class.Floating a) =>
   Matrix a Int Int (Matrix.Wide ZeroInt ZeroInt a)
wide = fmap Matrix.transpose (transpose tall)

-- | Wide matrix, obtained by transposing a 'fullRankTall' one together with
-- its dimension variables.
fullRankWide ::
   (Class.Floating a, RealOf a ~ ar, Class.Real ar) =>
   Matrix a Int Int (Matrix.Wide ZeroInt ZeroInt a)
fullRankWide = fmap Matrix.transpose (transpose fullRankTall)


-- | Hermitian matrix without an extra condition.
hermitian ::
   (Class.Floating a, RealOf a ~ ar, Class.Real ar) =>
   Matrix a Int Int (Hermitian ZeroInt a)
hermitian = hermitianCond (\_ -> True)

-- | Hermitian matrix satisfying @cond@.
-- Diagonal elements are generated by 'Util.genReal' and converted with
-- 'fromReal'. The storage order is chosen once; only the array is
-- regenerated by 'QC.suchThat'.
hermitianCond ::
   (Class.Floating a, RealOf a ~ ar, Class.Real ar) =>
   (Hermitian ZeroInt a -> Bool) ->
   Matrix a Int Int (Hermitian ZeroInt a)
hermitianCond cond =
   mapGen
      (\maxElem size ->
         Util.genOrder >>= \order ->
         QC.suchThat
            (Util.genArrayExtraDiag maxElem (MatrixShape.hermitian order size)
               (\_ -> fromReal <$> Util.genReal maxElem))
            cond)
      squareDim


{-
There cannot be a pure/point function.
-}
-- | Matrix-times-matrix combination: the left width must equal the right
-- height. The result has the left height and the right width.
(<|*|>) ::
   (Dim fuse, Eq fuse) =>
   Matrix tag height fuse (a -> b) ->
   Matrix tag fuse width a ->
   Matrix tag height width b
f <|*|> a =
   combine
      (\(height, fuseF) (fuseA, width) -> ((height, width), fuseF =!= fuseA))
      f a

-- | Swap the height and width variables; the generated value is unchanged.
transpose ::
   Matrix tag height width a ->
   Matrix tag width height a
transpose (Cons gen) =
   Cons $ (\ ~(dims, query) -> (swap dims, query)) <$> gen

-- | Combination for solving-style operations: the left operand is
-- transposed, then combined with '<|*|>'. Both operands must share the
-- height.
(<|\|>) ::
   (Dim height, Eq height) =>
   Matrix tag height width (a -> b) ->
   Matrix tag height nrhs a ->
   Matrix tag width nrhs b
(<|\|>) = (<|*|>) . transpose

-- | Outer-product-style combination: the two vector sizes become the
-- height and width. No constraint is added.
(<***>) ::
   Vector tag height (a -> b) ->
   Vector tag width a ->
   Matrix tag height width b
f <***> a = combine (\height width -> ((height, width), mempty)) f a


-- | Element-wise combination of two matrices with equal height and width.
(<|=|>) ::
   (Dim height, Eq height) =>
   (Dim width, Eq width) =>
   Matrix tag height width (a -> b) ->
   Matrix tag height width a ->
   Matrix tag height width b
f <|=|> a =
   combine
      (\(heightF, widthF) (heightA, widthA) ->
         ((heightF, widthF), (heightF =!= heightA) <> (widthF =!= widthA)))
      f a


-- | Create a new variable for the concatenated dimension @a :+: b@ and
-- record the 'Logic.!+!' constraint linking it to @a@ and @b@.
(!+!) ::
   Logic.Variable s dimA ->
   Logic.Variable s dimB ->
   RWST r (Logic.System s) () (Logic.M s) (Logic.Variable s (dimA :+: dimB))
(!+!) a b = do
   ab <- newVariable
   MRWS.tell ((Logic.!+!) a b ab)
   return ab

-- | Vertical stacking: both operands must have the same width.
-- The result height is the concatenation of both heights.
(<===>) ::
   (Dim width, Eq width) =>
   Matrix tag heightA width (a -> b) ->
   Matrix tag heightB width a ->
   Matrix tag (heightA:+:heightB) width b
(<===>) (Cons genTop) (Cons genBottom) =
   Cons $ do
      ((heightTop, widthTop), top) <- genTop
      ((heightBottom, widthBottom), bottom) <- genBottom
      MRWS.tell $ widthTop =!= widthBottom
      height <- heightTop !+! heightBottom
      return ((height, widthTop), top <*> bottom)

-- | Horizontal stacking, expressed as vertical stacking of the transposed
-- operands: both must have the same height, and the widths are
-- concatenated.
(<|||>) ::
   (Dim height, Eq height) =>
   Matrix tag height widthA (a -> b) ->
   Matrix tag height widthB a ->
   Matrix tag height (widthA:+:widthB) b
left <|||> right = transpose (transpose left <===> transpose right)



-- | Generate three matrices whose dimensions fit the block layout
-- @[[a, b], [0, c]]@: @a@ and @b@ share their height, and @b@ and @c@ share
-- their width. The result height is height(a) + height(c), the result
-- width is width(a) + width(c). Only the triple is returned.
stack3 ::
   (Dim heightA, Eq heightA) =>
   (Dim widthB, Eq widthB) =>
   Matrix tag heightA widthA a ->
   Matrix tag heightA widthB b ->
   Matrix tag heightB widthB c ->
   Matrix tag (heightA:+:heightB) (widthA:+:widthB) (a,b,c)
stack3 (Cons genTopLeft) (Cons genTopRight) (Cons genBottomRight) =
   Cons $ do
      ((heightTop, widthLeft), topLeft) <- genTopLeft
      ((heightTopRight, widthRight), topRight) <- genTopRight
      ((heightBottom, widthBottomRight), bottomRight) <- genBottomRight
      MRWS.tell $
         (heightTop =!= heightTopRight) <> (widthRight =!= widthBottomRight)
      height <- heightTop !+! heightBottom
      width <- widthLeft !+! widthBottomRight
      return
         ((height, width), (,,) <$> topLeft <*> topRight <*> bottomRight)


-- Same fixity as '<$>' and '<*>' (infixl 4), so that chains such as
-- @f <$> a <|*|> b@ group from the left like ordinary applicative code.
infixl 4 <.*.>, <.*|>, <|*.>, <|*|>, <|\|>, <***>, <.=.>, <|=|>, <===>, <|||>