{-# LANGUAGE ViewPatterns #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE MonoLocalBinds #-}
{-|
   Equality-matching, implemented using a relational database
   (defined in 'Data.Equality.Matching.Database') according to the paper
   \"Relational E-Matching\" https://arxiv.org/abs/2108.02290.
 -}
module Data.Equality.Matching
    ( ematch
    , eGraphToDatabase
    , Match(..)

    -- * Compiling a Pattern to a Query
    , compileToQuery
    , VarsState(varNames), findVarName
    , userPatVars

    , module Data.Equality.Matching.Pattern
    )
    where

import Data.Maybe (mapMaybe)
import Data.Foldable (toList)
import Data.Containers.ListUtils

import Control.Monad
import Control.Monad.Trans.State.Strict

import qualified Data.Map.Strict    as M
import qualified Data.IntMap.Strict as IM
import qualified Data.IntSet as IS

import Data.Equality.Graph
import Data.Equality.Graph.Lens
import Data.Equality.Matching.Database
import Data.Equality.Matching.Pattern
import Data.Coerce (coerce)

-- | The result of matching a pattern on an e-graph: the e-class in which the
-- pattern was matched, together with an e-class substitution for every
-- 'VariablePattern' occurring in the pattern.
data Match = Match
  { matchSubst   :: !Subst
    -- ^ The substitution under which the pattern matched
  , matchClassId :: {-# UNPACK #-} !ClassId
    -- ^ The e-class in which the pattern was matched
  }

-- TODO: Perhaps e-graph could carry database and rebuild it on rebuild

-- | Match a pattern 'Query' (obtained by applying 'compileToQuery' to a
-- 'Pattern') against a 'Database' (built from an 'EGraph' with
-- 'eGraphToDatabase').
--
-- Returns one 'Match' per substitution produced by the join. Each match pairs
-- that substitution with the e-class bound to the query root variable, which
-- is the e-class in which the pattern matched. Empty substitutions produce no
-- match.
--
-- 'ematch' takes a 'Database' rather than an 'EGraph' so that the database
-- can be built once and shared across many matching calls.
ematch :: Language l
       => Database l
       -> (Query l, Var {- query root var -})
       -> [Match]
ematch db (q, root) = mapMaybe substToMatch (genericJoin db q)
  where
    -- Turn one join result into a 'Match' by reading the matched e-class off
    -- the root variable. An empty substitution is not a match. A non-empty
    -- substitution without the root is an internal invariant violation.
    substToMatch :: Subst -> Maybe Match
    substToMatch s
      | nullSubst s                    = Nothing
      | Just cls <- lookupSubst root s = Just (Match s cls)
      | otherwise                      = error "how is root not in map?"

-- | Convert an e-graph into a 'Database'.
--
-- Every entry of the e-graph's memo (an e-node together with its e-class id)
-- is inserted as the tuple @classId : children@ into the 'IntTrie' relation
-- keyed by the e-node's operator. The relation is created if it does not
-- exist yet.
eGraphToDatabase :: Language l => EGraph a l -> Database l
eGraphToDatabase egr = foldrWithKeyNM' insertENode (DB mempty) (egr^._memo)
  where
    -- Record one e-node, given its e-class, in the database.
    insertENode :: Language l => ENode l -> ClassId -> Database l -> Database l
    insertENode node cls (DB relations) =
      let tuple = cls : children node
       in DB (M.alter (Just . insertTuple tuple) (operator node) relations)

    -- Insert a path of ids into a trie, starting from an empty trie when
    -- there is none yet. At each level the head id is added to the key set,
    -- and the remaining ids go into the sub-trie under that id.
    insertTuple :: [ClassId] -> Maybe IntTrie -> IntTrie
    insertTuple []         mtrie = maybe emptyTrie id mtrie
    insertTuple (i : rest) mtrie =
      case maybe emptyTrie id mtrie of
        MkIntTrie keys subtries ->
          MkIntTrie (IS.insert i keys)
                    (IM.alter (Just . insertTuple rest) i subtries)

    -- A trie with no keys and no sub-tries.
    emptyTrie :: IntTrie
    emptyTrie = MkIntTrie mempty mempty
{-# INLINABLE eGraphToDatabase #-}


-- * Database related internals

-- | Intermediate result of 'compileToQuery': the variable bound to a
-- (sub-)pattern, paired with the atoms generated while compiling that
-- (sub-)pattern.
data AuxResult lang
  = {-# UNPACK #-} !Var :~ [Atom lang]


-- | Compile a 'Pattern' into a 'Query'. The query is returned together with
-- its root variable, plus the final 'VarsState', which maps the user's pattern
-- variable names to the internal 'Var's used in the query.
--
-- The substitutions found for the root variable are the e-classes in which
-- the pattern matched.
--
-- * A bare 'VariablePattern' compiles to a 'SelectAllQuery' on its variable.
--
-- * Otherwise, every 'NonVariablePattern' node gets a fresh variable and
--   contributes one 'Atom' relating that variable to the variables of its
--   children. The query selects the (deduplicated) root variable and all
--   user variables.
compileToQuery :: (Traversable lang) => Pattern lang -> ((Query lang, Var), VarsState)
compileToQuery (VariablePattern n) = flip runState emptyVarsState $ do
  v <- getVarName n
  return (SelectAllQuery v, v)
compileToQuery pa =
  let (root :~ atoms, varsState) = runState (aux pa) emptyVarsState
   in ((Query (nubVars $ root : userPatVars varsState) atoms, root), varsState)

  where

    aux :: (Traversable lang) => Pattern lang -> State VarsState (AuxResult lang)
    aux (VariablePattern x) = do
      v <- getVarName x
      -- Base case of the recursion, as defined in the relational e-matching
      -- paper: a variable is bound to itself and generates no atoms.
      return (v :~ [])
    aux (NonVariablePattern p) = do
      -- The fresh variable for this node is allocated before its children's,
      -- so a parent always gets a lower id than its descendants.
      v <- nextVar
      subResults <- traverse aux p
      let -- Replace each sub-pattern by the variable bound to it. 'traverse'
          -- keeps the shape of @p@, so mapping over its result keeps each
          -- child aligned with its own variable. This avoids index
          -- bookkeeping and the partial '(!!)'.
          p' = fmap (\(b :~ _) -> b) subResults
          childAtoms = concatMap (\(_ :~ a) -> a) (toList subResults)
      return (v :~ (Atom (CVar v) (fmap CVar p') : childAtoms))
{-# INLINABLE compileToQuery #-}

--------------------------------------------------------------------------------
-- ** Vars utils for compileToQuery
--------------------------------------------------------------------------------

-- | State threaded through 'compileToQuery'. It holds the internal 'Var'
-- assigned to each user-given variable name, plus the counter used to
-- allocate fresh 'Var's.
data VarsState = VarsState
  { varNames  :: !(M.Map String Var)
    -- ^ User-given variable names mapped to their internal 'Var's
  , nextVarId :: !Int
    -- ^ Id that the next fresh 'Var' will carry
  }

-- | A 'VarsState' with no named variables and a counter starting at 0.
emptyVarsState :: VarsState
emptyVarsState = VarsState { varNames = M.empty, nextVarId = 0 }

-- | Allocate a fresh internal 'Var' (a 'MatchVar' carrying the current
-- counter value) and advance the counter stored in the 'VarsState'.
nextVar :: State VarsState Var
nextVar = do
  vs <- get
  let n = nextVarId vs
  -- Strict put: the updated state is forced, as with 'modify''.
  put $! vs { nextVarId = n + 1 }
  return (MatchVar n)

-- | Get the internal 'Var' for a user-given variable name. The first time a
-- name is seen, a fresh 'Var' is allocated with 'nextVar' and recorded in
-- the 'VarsState'. Later lookups of the same name return that 'Var'.
getVarName :: String -> State VarsState Var
getVarName s = gets (M.lookup s . varNames) >>= maybe register return
  where
    -- Allocate a new 'Var' and remember it under the name @s@.
    register :: State VarsState Var
    register = do
      fresh <- nextVar
      modify' (\vs -> vs { varNames = M.insert s fresh (varNames vs) })
      return fresh

-- | Look up the internal 'Var' assigned to a user-given variable name.
--
-- The name must already be registered in the 'VarsState', for example by
-- compiling a pattern that mentions it. Otherwise this calls 'error' with a
-- message that names the missing variable. The previous @Map.!@ failure did
-- not name it.
findVarName :: VarsState -> String -> Var
findVarName vs s =
  case M.lookup s (varNames vs) of
    Just v  -> v
    Nothing ->
      error ("findVarName: pattern variable " ++ show s ++ " not found in VarsState")

-- | The internal 'Var's for the variables the user wrote in the pattern,
-- i.e. every name recorded in the 'VarsState'. They are listed in ascending
-- order of the variable names.
userPatVars :: VarsState -> [Var]
userPatVars vs = M.elems (varNames vs)

-- | Remove duplicate pattern 'Var's, keeping the first occurrence of each.
-- The 'Var's are coerced to 'Int' so that the efficient 'nubInt' can be used.
-- This relies on 'Var' being representationally equal to 'Int'.
nubVars :: [Var] -> [Var]
nubVars vars = coerce (nubInt (coerce vars :: [Int]))