{-# LANGUAGE RecordWildCards #-}

module SoPSat.Internal.SolverMonad
where

import Control.Monad (foldM)
import Control.Monad.Trans.State.Strict (
  StateT (..),
  evalStateT,
  get,
  gets,
  put,
 )

import Data.Map (Map)
import qualified Data.Map as M

import SoPSat.Internal.Range
import SoPSat.Internal.SoP (
  Product (..),
  SoP (..),
  Symbol (..),
 )
import SoPSat.Internal.Unify
import SoPSat.SoP
import qualified SoPSat.SoP as SoP

-- | Solver state: the range currently known for each atom, together with
-- the list of unifiers collected so far.
data (Ord f, Ord c) => State f c
  = State
      (Map (Atom f c) (Range f c))
      -- ^ Known range of each atom.
      [Unifier f c]
      -- ^ Unifiers accumulated by the solver.
  deriving Show

-- | Combine two states. The range maps are merged with 'M.union', so on a
-- key clash the range from the left-hand state wins. The unifier lists are
-- concatenated, left before right.
instance (Ord f, Ord c) => Semigroup (State f c) where
  State leftRanges leftUnifiers <> State rightRanges rightUnifiers =
    State
      (leftRanges `M.union` rightRanges)
      (leftUnifiers <> rightUnifiers)

-- | The empty state: no atom ranges and no unifiers.
instance (Ord f, Ord c) => Monoid (State f c) where
  mempty = State mempty mempty

-- TODO: Change Maybe to some MonadError for better error indication
-- | The solver monad: a strict state computation over 'State' running in
-- 'Maybe'. Any failure (e.g. 'fail' via 'maybeFail') yields 'Nothing' and
-- carries no error message.
type SolverState f c = StateT (State f c) Maybe

-- | Lift a 'Maybe' into any 'MonadFail'. @Just x@ succeeds with @x@, and
-- 'Nothing' calls 'fail' with an empty message.
maybeFail :: (MonadFail m) => Maybe a -> m a
maybeFail = maybe (fail "") pure

-- | Read the whole atom-to-range map from the current solver state.
getRanges :: (Ord f, Ord c) => SolverState f c (Map (Atom f c) (Range f c))
getRanges = gets rangesOf
  where
    rangesOf (State ranges _) = ranges

-- | Look up the stored range of an atom. Fails in the solver monad when no
-- range has been recorded for it.
getRange :: (Ord f, Ord c) => Atom f c -> SolverState f c (Range f c)
getRange c = do
  ranges <- getRanges
  maybeFail (M.lookup c ranges)

-- | Compute the range of a single symbol.
--
-- * @E b p@ (presumably @b@ raised to @p@): computes the range of @b@, then
--   the range of @p@, and combines them with 'rangeExp'. Fails if
--   'rangeExp' returns 'Nothing'.
-- * @I _@: the point range whose lower and upper bounds are both the
--   symbol itself.
-- * @A a@: the range stored for atom @a@. Fails if none is stored.
getRangeSymbol :: (Ord f, Ord c) => Symbol f c -> SolverState f c (Range f c)
getRangeSymbol sym = case sym of
  E base expo -> do
    baseRange <- getRangeSoP base
    expoRange <- getRangeProduct expo
    maybeFail (rangeExp baseRange expoRange)
  I _ ->
    let point = Bound (toSoP sym)
     in return (Range point point)
  A atom -> getRange atom

-- | Compute the range of a product. The ranges of all its symbols are
-- multiplied together with 'rangeMul', starting from the unit range
-- @[1, 1]@.
--
-- Fails in the solver monad if any symbol has no range, or if 'rangeMul'
-- cannot combine two of the ranges.
getRangeProduct :: (Ord f, Ord c) => Product f c -> SolverState f c (Range f c)
getRangeProduct p = do
  symbolRanges <- mapM getRangeSymbol (unP p)
  -- 'foldM' in 'Maybe' stops at the first failed multiplication. The
  -- previous lazy 'foldl' carried 'Nothing' through the rest of the list.
  maybeFail (foldM rangeMul oneRange symbolRanges)
  where
    one = Bound (SoP.int 1)
    oneRange = Range one one

-- | Compute the range of a sum of products. The ranges of all its products
-- are added together with 'rangeAdd', starting from the zero range
-- @[0, 0]@.
--
-- Fails in the solver monad if any product's range cannot be computed, or
-- if 'rangeAdd' cannot combine two of the ranges.
getRangeSoP :: (Ord f, Ord c) => SoP f c -> SolverState f c (Range f c)
getRangeSoP s = do
  productRanges <- mapM getRangeProduct (unS s)
  -- 'foldM' in 'Maybe' stops at the first failed addition. The previous
  -- lazy 'foldl' carried 'Nothing' through the rest of the list.
  maybeFail (foldM rangeAdd zeroRange productRanges)
  where
    zero = Bound (SoP.int 0)
    zeroRange = Range zero zero

-- | Store @range@ as the range of atom @symb@, replacing any previous entry.
--
-- If the range collapses to a single point (@lower == upper@ and @upper@ is
-- a 'Bound'), the atom is pinned to that value. A substitution is then
-- registered via 'putUnifiers' before the range itself is stored.
putRange :: (Ord f, Ord c) => Atom f c -> Range f c -> SolverState f c ()
putRange symb range@Range {lower = lo, upper = hi} = do
  -- Anti-symmetry: 5 <= x && x <= 5 implies x = 5.
  case (lo == hi, hi) of
    (True, Bound value) -> putUnifiers [Subst symb (toSoP value)]
    _ -> pure ()
  -- Read the state only now, so that any unifier added above is kept.
  State ranges unifiers <- get
  put (State (M.insert symb range ranges) unifiers)

-- | Read the list of unifiers accumulated in the current solver state.
getUnifiers :: (Ord f, Ord c) => SolverState f c [Unifier f c]
getUnifiers = gets unifiersOf
  where
    unifiersOf (State _ unifiers) = unifiers

-- | Add new unifiers to the state. The existing unifiers are first
-- rewritten with 'substsSubst' using the new ones, and the new unifiers are
-- then appended after them. Ranges are left untouched.
putUnifiers :: (Ord f, Ord c) => [Unifier f c] -> SolverState f c ()
putUnifiers us = do
  State ranges existing <- get
  let rewritten = substsSubst us existing
  put (State ranges (rewritten ++ us))

-- | Replace the entire solver state with the given one. Subsequent
-- statements run against it.
withState :: (Ord f, Ord c) => State f c -> SolverState f c ()
withState st = put st

-- | Run a solver computation starting from the empty state ('mempty').
-- Returns the result together with the final state, or 'Nothing' if the
-- computation failed.
runStatements :: (Ord f, Ord c) => SolverState f c a -> Maybe (a, State f c)
runStatements = flip runStateT mempty

-- | Like @runStatements@, starting from the empty state ('mempty'), but the
-- final state is discarded and only the result (or 'Nothing' on failure) is
-- returned.
evalStatements :: (Ord f, Ord c) => SolverState f c a -> Maybe a
evalStatements = (`evalStateT` mempty)