{-# LANGUAGE ViewPatterns #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE MonoLocalBinds #-}
module Data.Equality.Matching
( ematch
, eGraphToDatabase
, Match(..)
, compileToQuery
, VarsState(varNames), findVarName
, userPatVars
, module Data.Equality.Matching.Pattern
)
where
import Data.Maybe (mapMaybe)
import Data.Foldable (toList)
import Data.Containers.ListUtils
import Control.Monad
import Control.Monad.Trans.State.Strict
import qualified Data.Map.Strict as M
import qualified Data.IntMap.Strict as IM
import qualified Data.IntSet as IS
import Data.Equality.Graph
import Data.Equality.Graph.Lens
import Data.Equality.Matching.Database
import Data.Equality.Matching.Pattern
import Data.Coerce (coerce)
-- | A single result of e-matching a query against a database.
data Match = Match
    { matchSubst :: !Subst
      -- ^ The substitution produced for this match.
    , matchClassId :: {-# UNPACK #-} !ClassId
      -- ^ The e-class id that the query's root variable is mapped to by
      -- 'matchSubst' (see 'ematch').
    }
-- | Run a compiled query against a database and collect the matches.
--
-- Every substitution returned by 'genericJoin' is turned into a 'Match' whose
-- 'matchClassId' is the value bound to the query's root variable.
-- Substitutions for which 'nullSubst' holds are dropped.
--
-- Calls 'error' if a non-null substitution does not bind the root variable.
ematch :: Language l
       => Database l
       -> (Query l, Var)  -- ^ The query together with its root variable
       -> [Match]
ematch db (q, root) = mapMaybe toMatch (genericJoin db q)
  where
    -- Pair a substitution with the class id bound to the root variable.
    toMatch :: Subst -> Maybe Match
    toMatch s
      | nullSubst s = Nothing
      | otherwise =
          case lookupSubst root s of
            Nothing    -> error "how is root not in map?"
            Just found -> Just (Match s found)
-- | Build the matching 'Database' from an e-graph.
--
-- Folds over the e-graph's memo ('_memo'), and for every e-node records the
-- path @classId : children@ in the 'IntTrie' stored under the e-node's
-- operator.
eGraphToDatabase :: Language l => EGraph a l -> Database l
eGraphToDatabase egr = foldrWithKeyNM' addENodeToDB (DB mempty) (egr ^. _memo)
  where
    -- Insert one e-node (and the class id it belongs to) into the database,
    -- creating the trie for its operator if there is none yet.
    addENodeToDB :: Language l => ENode l -> ClassId -> Database l -> Database l
    addENodeToDB enode classid (DB m) =
        DB $ M.alter (Just . populate (classid : children enode)) (operator enode) m

    -- Insert a path of class ids into an (optional) existing trie.
    -- At each level the head of the path is added to the level's key set and
    -- the rest of the path is inserted into the sub-trie stored under it.
    populate :: [ClassId] -> Maybe IntTrie -> IntTrie
    -- Empty path, no trie yet: an empty trie.
    populate []     Nothing   = MkIntTrie mempty mempty
    -- Non-empty path, no trie yet: a fresh single-branch trie.
    populate (x:xs) Nothing   =
        MkIntTrie (IS.singleton x) $ IM.singleton x (populate xs Nothing)
    -- Empty path, existing trie: nothing to add.
    populate []     (Just it) = it
    -- Non-empty path, existing trie: extend the key set and the branch at x.
    populate (x:xs) (Just (MkIntTrie k m)) =
        MkIntTrie (x `IS.insert` k) $ IM.alter (Just . populate xs) x m
{-# INLINABLE eGraphToDatabase #-}
-- | Intermediate result of compiling one (sub-)pattern in 'compileToQuery':
-- the variable standing for the root of that sub-pattern, paired with the
-- atoms emitted for it.
data AuxResult lang = {-# UNPACK #-} !Var :~ [Atom lang]
-- | Compile a 'Pattern' into a 'Query'.
--
-- Returns the query paired with its root variable, together with the final
-- 'VarsState' (which holds the mapping from pattern variable names to the
-- 'Var's assigned to them).
--
-- * A pattern that is just a variable compiles to a 'SelectAllQuery' on that
--   variable.
-- * Any other pattern compiles to a 'Query' with one 'Atom' per non-variable
--   node. The query's variable list is the root variable followed by the
--   user-named pattern variables, with duplicates removed.
compileToQuery :: (Traversable lang) => Pattern lang -> ((Query lang, Var), VarsState)
compileToQuery (VariablePattern n) = flip runState emptyVarsState $ do
    v <- getVarName n
    return (SelectAllQuery v, v)
compileToQuery pa =
    let (root :~ atoms, varsState) = runState (aux pa) emptyVarsState
     in ((Query (nubVars $ root : userPatVars varsState) atoms, root), varsState)
  where
    -- Compile one sub-pattern, returning the variable that stands for it and
    -- the atoms required to constrain it.
    aux :: (Traversable lang) => Pattern lang -> State VarsState (AuxResult lang)
    aux (VariablePattern x) = do
        -- A named variable contributes no atoms of its own.
        v <- getVarName x
        return (v :~ [])
    aux (NonVariablePattern p) = do
        -- Allocate this node's variable before compiling its children.
        v <- nextVar
        results <- traverse aux p
        let -- Atoms of all children, in traversal order.
            childAtoms = join [ sub | _ :~ sub <- toList results ]
            -- The node itself with every child replaced by the variable that
            -- stands for it. Mapping over the traversed structure avoids
            -- re-indexing a list of variables with the partial (!!).
            node = fmap (\(b :~ _) -> CVar b) results
        return (v :~ (Atom (CVar v) node : childAtoms))
{-# INLINABLE compileToQuery #-}
-- | State threaded through pattern compilation.
data VarsState = VarsState
    { varNames :: !(M.Map String Var)
      -- ^ The 'Var' assigned to each pattern variable name seen so far.
    , nextVarId :: !Int
      -- ^ The id that the next call to 'nextVar' will hand out.
    }
-- | Initial compilation state: no named variables, ids starting at 0.
emptyVarsState :: VarsState
emptyVarsState = VarsState mempty 0
-- | Allocate a fresh 'Var': return a 'MatchVar' for the current 'nextVarId'
-- and increment the counter.
nextVar :: State VarsState Var
nextVar = do
    n <- gets nextVarId
    modify' (\vs -> vs { nextVarId = nextVarId vs + 1 })
    return (MatchVar n)
-- | Get the 'Var' for a pattern variable name.
--
-- If the name has been seen before, its existing 'Var' is returned; otherwise
-- a fresh one is allocated with 'nextVar' and recorded in 'varNames'.
getVarName :: String -> State VarsState Var
getVarName s = do
    known <- gets varNames
    case M.lookup s known of
        Just v  -> return v
        Nothing -> do
            fresh <- nextVar
            modify' (\vs -> vs { varNames = M.insert s fresh (varNames vs) })
            return fresh
-- | Look up the 'Var' assigned to a pattern variable name.
--
-- Calls 'error' if the name is not present in 'varNames'; the message names
-- the missing variable.
findVarName :: VarsState -> String -> Var
findVarName vs s =
    case M.lookup s (varNames vs) of
        Just v  -> v
        Nothing -> error ("findVarName: pattern variable " ++ show s
                          ++ " is not bound in the given VarsState")
-- | All 'Var's that were assigned to named pattern variables, in ascending
-- order of the variable names (the order of 'M.elems').
userPatVars :: VarsState -> [Var]
userPatVars = M.elems . varNames
-- | Remove duplicate variables, keeping the first occurrence of each.
--
-- Works by coercing to the underlying 'Int's so that 'nubInt' can be used.
nubVars :: [Var] -> [Var]
nubVars = coerce . nubInt . coerce