-- | This module defines a "kind-checking" pass. This requires some explanation, since we don't have
--     an *explicit* notion of kinds in Covenant:
--
--     With the addition of type constructors for datatypes to ValT comes a new set of things that can
--     "go wrong". In particular:
--       - Someone may try to use a type constructor which is not defined anywhere
--       - A type may be applied to an incorrect number of arguments
--       - The "count" - the number of bound tyvars in the `ValT.Datatype` representation - may be incorrect (i.e. inconsistent with the count in the declaration)
--
--     The checks to detect these errors are entirely independent of the checks performed during typechecking or renaming, so we do them in a separate pass.
module Covenant.Internal.KindCheck
  ( checkDataDecls,
    checkValT,
    KindCheckError (..),
    EncodingArgErr (..),
    cycleCheck,
    checkEncodingArgs,
  )
where

import Control.Monad (unless)
import Control.Monad.Except (ExceptT, MonadError (throwError), runExceptT)
import Control.Monad.Reader
  ( MonadReader (local),
    ReaderT (ReaderT),
    asks,
    runReaderT,
  )
import Covenant.Data (everythingOf)
import Covenant.Index (Count, intCount)
import Covenant.Internal.Strategy
  ( DataEncoding (SOP),
  )
import Covenant.Internal.Type
  ( AbstractTy,
    CompT (CompT),
    CompTBody (CompTBody),
    Constructor (Constructor),
    DataDeclaration (DataDeclaration, OpaqueData),
    TyName,
    ValT (Abstraction, BuiltinFlat, Datatype, ThunkT),
    checkStrategy,
    datatype,
  )
import Data.Foldable (traverse_)
import Data.Functor.Identity (Identity, runIdentity)
import Data.Kind (Type)
import Data.Map.Strict (Map)
import Data.Map.Strict qualified as M
import Data.Maybe (mapMaybe)
import Data.Set (Set)
import Data.Set qualified as Set
import Data.Vector (Vector)
import Data.Vector qualified as V
import Optics.Core (A_Lens, LabelOptic (labelOptic), folded, lens, preview, review, to, toListOf, view, (%))

{- TODO: Explicitly separate the kind checker into two check functions:
     - One which kind checks `ValT`s to ensure:
       1. All TyCons in the ValT exist
       2. All TyCons in the ValT have the correct arity

     - One which checks *datatype declarations* to ensure:
       1. Everything satisfies the above ValT checks
       2. No thunk arguments to ctors!
       3. No mutual recursion (cycles)
-}

-- | Everything the kind-checking pass can report.
data KindCheckError
  = -- | A 'Datatype' mentions a type name that has no declaration in the
    -- context.
    UnknownType TyName
  | -- | A type constructor was applied to the wrong number of arguments.
    -- Fields: the type name, the expected count (taken from the
    -- declaration), and the arguments actually supplied.
    IncorrectNumArgs TyName (Count "tyvar") (Vector (ValT AbstractTy))
  | -- | A constructor argument of a data declaration is a thunk; constructor
    -- arguments may not be (possibly polymorphic) functions.
    ThunkConstructorArg (CompT AbstractTy)
  | -- | Mutually recursive declarations were found; carries the type names
    -- involved. NOTE(review): the throw site is not visible in this chunk.
    MutualRecursionDetected (Set TyName)
  | -- | The encoding strategy declared for the named type is not valid for
    -- its declaration.
    InvalidStrategy TyName
  | -- | Wraps an 'EncodingArgErr'. NOTE(review): produced by the encoding
    -- argument checks, whose definitions are outside this view.
    EncodingMismatch (EncodingArgErr AbstractTy)
  deriving stock (Eq, Show)

newtype KindCheckContext a = KindCheckContext (Map TyName (DataDeclaration a))

instance
  (k ~ A_Lens, a ~ Map TyName (DataDeclaration c), b ~ Map TyName (DataDeclaration c)) =>
  LabelOptic "kindCheckContext" k (KindCheckContext c) (KindCheckContext c) a b
  where
  {-# INLINEABLE labelOptic #-}
  labelOptic :: Optic k NoIx (KindCheckContext c) (KindCheckContext c) a b
labelOptic =
    (KindCheckContext c -> a)
-> (KindCheckContext c -> b -> KindCheckContext c)
-> Lens (KindCheckContext c) (KindCheckContext c) a b
forall s a b t. (s -> a) -> (s -> b -> t) -> Lens s t a b
lens
      (\(KindCheckContext Map TyName (DataDeclaration c)
x) -> a
Map TyName (DataDeclaration c)
x)
      (\(KindCheckContext Map TyName (DataDeclaration c)
_) b
x' -> Map TyName (DataDeclaration c) -> KindCheckContext c
forall a. Map TyName (DataDeclaration a) -> KindCheckContext a
KindCheckContext b
Map TyName (DataDeclaration c)
x')

-- | The kind-checking monad: read access to the declaration map plus
-- short-circuiting failure with a 'KindCheckError'. Every instance is
-- borrowed from the underlying @ReaderT@/@ExceptT@ stack.
newtype KindCheckM t a
  = KindCheckM (ReaderT (KindCheckContext t) (ExceptT KindCheckError Identity) a)
  deriving
    ( Functor,
      Applicative,
      Monad,
      MonadReader (KindCheckContext t),
      MonadError KindCheckError
    )
    via (ReaderT (KindCheckContext t) (ExceptT KindCheckError Identity))

-- | Runs a kind-checking computation with the given declarations as its
-- context, yielding either the first error raised or the result.
runKindCheckM ::
  forall (t :: Type) (a :: Type).
  Map TyName (DataDeclaration t) ->
  KindCheckM t a ->
  Either KindCheckError a
runKindCheckM dtypes (KindCheckM act) =
  runIdentity (runExceptT (runReaderT act ctx))
  where
    ctx = KindCheckContext dtypes

-- | Finds the declaration for a type name in the ambient context, failing
-- with 'UnknownType' when no declaration with that name exists.
lookupDeclaration :: forall (t :: Type). TyName -> KindCheckM t (DataDeclaration t)
lookupDeclaration tn =
  asks (M.lookup tn . view #kindCheckContext)
    >>= maybe (throwError (UnknownType tn)) pure

{- The functions below sanity-check datatype declarations using the criteria enumerated above.
-}

-- | Checks that all the data declarations in the argument \'make sense\'.
-- Specifically:
--
-- * The strategy declared for the datatype is valid for it
-- * There are no mutually recursive datatype declarations
-- * Constructor arguments are not thunks, and only mention known type
-- constructors applied to the right number of arguments
-- * The number of type variables in any constructor isn't greater than we
-- expect
--
-- Declarations are visited in ascending order of their 'TyName' key, and
-- the first failure encountered is returned.
--
-- @since 1.1.0
checkDataDecls :: Map TyName (DataDeclaration AbstractTy) -> Either KindCheckError ()
checkDataDecls decls =
  -- Folding over the Map visits its values in key order, exactly like
  -- 'M.elems'; the same map also serves as the lookup context.
  runKindCheckM decls (traverse_ checkDataDecl decls)

-- | Sanity-checks a single declaration. Opaque declarations are accepted
-- as-is. For a regular declaration the checks run in this order:
--
-- 1. its encoding strategy is valid ('InvalidStrategy' otherwise);
-- 2. the cycle check starting from this declaration;
-- 3. every constructor argument, in declaration order, is kind-checked in
--    'CheckDataDecl' mode (which also rejects thunks);
-- 4. the encoding-argument check for the declaration.
checkDataDecl :: DataDeclaration AbstractTy -> KindCheckM AbstractTy ()
checkDataDecl = \case
  OpaqueData {} -> pure ()
  decl@(DataDeclaration tn _ ctors _) -> do
    unless (checkStrategy decl) $
      throwError (InvalidStrategy tn)
    cycleCheck' mempty decl
    -- Outer traversal walks the constructors, inner one walks each
    -- constructor's arguments: same order as concatenating all arguments.
    traverse_ (traverse_ (checkKinds CheckDataDecl) . view #constructorArgs) ctors
    checkEncodingArgsInDataDecl decl

-- | Which set of rules 'checkKinds' applies.
data KindCheckMode
  = -- | Checking the constructor arguments of a data declaration: thunks
    -- are rejected with 'ThunkConstructorArg'.
    CheckDataDecl
  | -- | Checking a type annotation: thunks are allowed, and the types
    -- inside them are checked recursively.
    CheckValT
  deriving stock (Show, Eq, Ord)

-- | Walks a 'ValT' and reports the first structural problem, according to
-- the given mode. This is not a kind checker in the usual sense. It only
-- verifies that:
--
-- * every 'Datatype' names a declared type ('UnknownType' otherwise);
-- * non-opaque datatypes get exactly as many arguments as their
--   declaration binds ('IncorrectNumArgs' otherwise);
-- * in 'CheckDataDecl' mode, no thunk appears ('ThunkConstructorArg').
--
-- Abstractions and builtin flat types always pass. Datatypes whose
-- declaration is opaque are accepted without examining their arguments.
checkKinds :: KindCheckMode -> ValT AbstractTy -> KindCheckM AbstractTy ()
checkKinds mode = go
  where
    go :: ValT AbstractTy -> KindCheckM AbstractTy ()
    go = \case
      Abstraction _ -> pure ()
      BuiltinFlat {} -> pure ()
      ThunkT compT@(CompT _ (CompTBody nev)) -> case mode of
        CheckDataDecl -> throwError (ThunkConstructorArg compT)
        -- Outside data declarations a thunk is fine, but the types in its
        -- body must themselves be well-formed.
        CheckValT -> traverse_ go nev
      Datatype tn args -> do
        decl <- lookupDeclaration tn
        case decl of
          OpaqueData {} -> pure ()
          DataDeclaration _ numVars _ _ -> do
            let expected = review intCount numVars
                actual = V.length args
            unless (actual == expected) $
              throwError (IncorrectNumArgs tn numVars args)
            traverse_ go args

-- | Checks a type annotation from the ASG (/not/ a datatype declaration)
-- against the given declarations, with thunks permitted. Returns the first
-- error found, or 'Nothing' when the type is well-formed.
checkValT :: Map TyName (DataDeclaration AbstractTy) -> ValT AbstractTy -> Maybe KindCheckError
checkValT dtypes ty =
  case runKindCheckM dtypes (checkKinds CheckValT ty) of
    Left err -> Just err
    Right () -> Nothing

-- | Verifies that no types in the argument are mutually recursive. Each
-- declaration is checked in turn (in key order), starting from an empty
-- visited set. The first error is returned, or 'Nothing' if all pass.
--
-- @since 1.1.0
cycleCheck :: forall (a :: Type). (Ord a) => Map TyName (DataDeclaration a) -> Maybe KindCheckError
cycleCheck decls =
  case runKindCheckM decls checkAll of
    Left err -> Just err
    Right () -> Nothing
  where
    checkAll :: KindCheckM a ()
    -- 'runKindCheckM' already installs @decls@ as the context, so this
    -- 'local' only restates it; it is kept to mirror the original.
    checkAll = local (const (KindCheckContext decls)) $ do
      allDecls <- asks (view (#kindCheckContext % to M.elems))
      traverse_ (cycleCheck' mempty) allDecls

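{- This is a bit subtle because we don't want to fail on auto-recursive types (types that refer to
   themselves). Each constructor's own type name is therefore excluded from its dependency set, and
   the declaration currently being examined is only added to the "visited" set after its
   dependencies have been checked against that set, right before we descend into them.
-}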
-- | Walk the dependency graph rooted at a declaration and fail with
-- 'MutualRecursionDetected' if a type constructor reachable from it reappears
-- on the current path.
--
-- The first argument is the set of type names already on the path from the
-- root declaration. Opaque declarations have no constructors and are accepted
-- immediately. A constructor argument that mentions the declaration's own type
-- name (direct self-recursion) is not treated as a cycle.
cycleCheck' :: forall (a :: Type). (Ord a) => Set TyName -> DataDeclaration a -> KindCheckM a ()
cycleCheck' visited = \case
  OpaqueData {} -> pure ()
  DataDeclaration self _ constructors _ -> mapM_ (checkConstructor self) constructors
  where
    -- Check one constructor of the declaration named by the first argument.
    checkConstructor :: TyName -> Constructor a -> KindCheckM a ()
    checkConstructor self (Constructor _ ctorArgs) = do
      let -- Every value type occurring anywhere inside the constructor's arguments.
          subterms :: [ValT a]
          subterms = Set.toList (foldMap everythingOf ctorArgs)
          -- Type constructors mentioned by those subterms, minus the
          -- declaration's own name: autorecursion is fine/necessary.
          dependencies :: Set TyName
          dependencies =
            Set.fromList
              [ name
              | Just (name, _) <- preview datatype <$> subterms,
                name /= self
              ]
          -- Dependencies that are already on the path from the root: a cycle.
          revisited :: Set TyName
          revisited = dependencies `Set.intersection` visited
      unless (Set.null revisited) $
        throwError (MutualRecursionDetected revisited)
      -- Resolve every dependency before descending into any of them.
      dependencyDecls <- mapM lookupDeclaration (Set.toList dependencies)
      mapM_ (cycleCheck' (Set.insert self visited)) dependencyDecls

{- Arguably the closest thing to a real kind checker in the module.

   Checks whether the arguments to type constructors (ValT 'Datatype's) conform with their encoding.

-}

-- | Error reported by 'checkEncodingArgs' when a type argument is not
-- permitted by the encoding of the type constructor it is applied to.
data EncodingArgErr a
  = EncodingArgMismatch
      TyName
      -- ^ Name of the type constructor that received the bad argument.
      (ValT a)
      -- ^ The bad argument itself.
  deriving stock (Eq, Show)

-- | Verifies that a value type (third argument) respects the encodings of the
-- datatypes it mentions, as provided by the first two arguments (projection
-- and metadata). The check recurses into the argument types of thunks and into
-- the arguments of every datatype application.
--
-- For each datatype application @T args@:
--
-- * If @T@ is SOP-encoded, no thunk may appear anywhere inside @args@
--   (a temporary restriction; see the note in the body).
-- * Otherwise (explicitly data-encoded or builtin), no element of @args@ may be
--   a thunk or an application of an SOP-encoded datatype.
--
-- Violations are reported as 'EncodingArgMismatch', naming the type
-- constructor that received the bad argument.
--
-- = Note
--
-- If the type being validated refers to datatypes, we assume that they exist
-- in the metadata 'Map'. A missing entry is a programming error and results in
-- a runtime 'error' naming the missing type constructor.
--
-- @since 1.1.0
checkEncodingArgs ::
  forall (a :: Type) (info :: Type).
  (info -> DataEncoding) -> -- this lets us not care about whether we're doing this w/ a DataDeclaration or DatatypeInfo
  Map TyName info ->
  ValT a ->
  Either (EncodingArgErr a) ()
checkEncodingArgs getEncoding tyDict = \case
  Abstraction {} -> pure ()
  BuiltinFlat {} -> pure ()
  ThunkT (CompT _ (CompTBody args)) -> traverse_ go args
  Datatype tn args ->
    -- The encoding is looked up before anything else, as it always has been.
    case encodingOf tn of
      -- Might as well check all the way down
      SOP -> do
        {- NOTE Sean 7/2/25: We are *temporarily* disallowing thunk arguments to SOPs to speed up development and
                             create consistency. We disallow Thunk arguments to constructors in datatype declarations,
                             and while we could very well allow them outside of those declarations, it creates a strange situation
                             where the same function might be safe/unsafe depending on whether it is used on a ValT inside of a data
                             declaration vs (e.g.) a type annotation in the ASG.

                             To remove this restriction, delete the `traverse_ (isValidSOPArg tn) args` line below
        -}
        traverse_ go args
        traverse_ (isValidSOPArg tn) args
      -- Both explicit data encodings and builtins should be "morally data encoded"
      _ -> do
        traverse_ go args
        traverse_ (isValidDataArg tn) args
  where
    go :: ValT a -> Either (EncodingArgErr a) ()
    go = checkEncodingArgs getEncoding tyDict

    -- Encoding of a type constructor according to the metadata map. The map is
    -- required to contain every referenced type constructor (see the Note above).
    encodingOf :: TyName -> DataEncoding
    encodingOf tn = case M.lookup tn tyDict of
      Just info -> getEncoding info
      Nothing ->
        error $ "checkEncodingArgs: no datatype information for " <> show tn

    -- Is the argument allowed directly under a data-encoded (or builtin)
    -- type constructor @tn@? Thunks and SOP-encoded datatypes are not.
    isValidDataArg :: TyName -> ValT a -> Either (EncodingArgErr a) ()
    isValidDataArg tn = \case
      Abstraction {} -> pure ()
      BuiltinFlat {} -> pure ()
      thunk@ThunkT {} -> throwError $ EncodingArgMismatch tn thunk
      dt@(Datatype tn' _) -> case encodingOf tn' of
        SOP -> throwError $ EncodingArgMismatch tn dt
        -- This check only runs after `traverse_ go args` has already succeeded
        -- on the enclosing arguments, which includes `dt` and therefore its own
        -- arguments. Re-running `go` on them here would return the same result
        -- while costing time exponential in the nesting depth.
        _ -> pure ()

    -- Does the argument (of SOP-encoded @tn@) avoid thunks at every depth?
    isValidSOPArg :: TyName -> ValT a -> Either (EncodingArgErr a) ()
    isValidSOPArg tn = \case
      Abstraction {} -> pure ()
      BuiltinFlat {} -> pure ()
      thunk@ThunkT {} -> throwError $ EncodingArgMismatch tn thunk
      Datatype tn' args' -> traverse_ (isValidSOPArg tn') args'

-- | Run 'checkEncodingArgs' over every constructor argument of a declaration.
-- The declarations in the kind-checking context supply the encoding metadata.
-- The first encoding violation found is rethrown as 'EncodingMismatch'.
-- NOTE(review): the `#datatypeConstructors` optic is a fold, presumably empty
-- for opaque declarations (which would then always pass) -- confirm.
checkEncodingArgsInDataDecl :: DataDeclaration AbstractTy -> KindCheckM AbstractTy ()
checkEncodingArgsInDataDecl decl = do
  tyDict <- asks (view #kindCheckContext)
  -- traverse_ rather than traverse: only the first failure matters, so there
  -- is no reason to build (and then discard) a `Vector ()` of results.
  either (throwError . EncodingMismatch) pure $
    traverse_ (checkEncodingArgs (view #datatypeEncoding) tyDict) allConstructorArgs
  where
    -- The arguments of all constructors of the declaration, in order.
    allConstructorArgs :: Vector (ValT AbstractTy)
    allConstructorArgs =
      V.concat $ toListOf (#datatypeConstructors % folded % #constructorArgs) decl