{-|
  Copyright  :  (C) 2012-2016, University of Twente
                    2021-2026, QBayLogic B.V.
  License    :  BSD2 (see the file LICENSE)
  Maintainer :  QBayLogic B.V. <devops@qbaylogic.com>

  Rewriting combinators and traversals
-}

module Clash.Rewrite.Combinators
  ( allR
  , (!->)
  , (>-!)
  , (>-!->)
  , (>->)
  , bottomupR
  , repeatR
  , topdownR
  , topdownFixR
  ) where

import           Control.DeepSeq             (deepseq)
import           Control.Monad               ((>=>))
import qualified Control.Monad.Writer        as Writer
import qualified Data.Monoid                 as Monoid

import           Clash.Core.Term             (Term (..), CoreContext (..), primArg, patIds)
import           Clash.Core.VarEnv
  (extendInScopeSet, extendInScopeSetList)
import           Clash.Rewrite.Types

-- | Apply a transformation to the direct subterms of a term.
--
-- Only the immediate children are visited; 'allR' itself does not recurse
-- any deeper. Every child is transformed in a 'TransformContext' that is
--
--   * extended with the 'CoreContext' element describing the position of the
--     child inside its parent, and
--   * given an in-scope set that additionally contains the variables the
--     parent binds around that child (lambda binders, let binders, pattern
--     variables).
--
-- Any term not matched by one of the equations below is returned unchanged.
allR
  :: forall m
   . Monad m
  => Transform m
  -- ^ The transformation to apply to the subtrees
  -> Transform m
-- Term lambda: the binder is in scope in the body.
allR trans (TransformContext is c) (Lam v e) =
  Lam v <$> trans (TransformContext (extendInScopeSet is v) (LamBody v : c)) e

-- Type lambda: the type variable is in scope in the body.
allR trans (TransformContext is c) (TyLam tv e) =
  TyLam tv <$> trans (TransformContext (extendInScopeSet is tv) (TyLamBody tv : c)) e

-- Application: the function is transformed first, because the context of the
-- argument is derived from the /transformed/ function (via 'primArg').
allR trans (TransformContext is c) (App e1 e2) = do
  e1' <- trans (TransformContext is (AppFun : c)) e1
  e2' <- trans (TransformContext is (AppArg (primArg e1') : c)) e2
  pure (App e1' e2')

-- Type application: only the term part is a subterm, the type is kept as is.
allR trans (TransformContext is c) (TyApp e ty) =
  (\e' -> TyApp e' ty) <$> trans (TransformContext is (TyAppC : c)) e

-- Cast: only the term part is a subterm, both types are kept as is.
allR trans (TransformContext is c) (Cast e ty1 ty2) =
  (\e' -> Cast e' ty1 ty2) <$> trans (TransformContext is (CastBody : c)) e

-- Letrec: all binders are in scope in every binding and in the body. The
-- bindings are transformed (in order) before the body.
allR trans (TransformContext is c) (Letrec xes e) = do
  xes' <- traverse rewriteBind xes
  -- The body context records the original bindings, not the rewritten ones.
  e'   <- trans (TransformContext is' (LetBody xes : c)) e
  pure (Letrec xes' e')
 where
  bndrs = map fst xes
  is'   = extendInScopeSetList is bndrs

  -- Transform the right-hand side of a single binding, keeping its binder.
  rewriteBind (b, rhs) =
    (\rhs' -> (b, rhs')) <$>
      trans (TransformContext is' (LetBinding b bndrs : c)) rhs

-- Case: the scrutinee is transformed first, followed by the alternatives in
-- order; the result type is kept as is.
allR trans (TransformContext is c) (Case scrut ty alts) = do
  scrut' <- trans (TransformContext is (CaseScrut : c)) scrut
  alts'  <- traverse rewriteAlt alts
  pure (Case scrut' ty alts')
 where
  -- Transform the right-hand side of one alternative. The type variables and
  -- term variables bound by its pattern are in scope there.
  rewriteAlt (p, rhs) =
    let (tvs, ids) = patIds p
        is'        = extendInScopeSetList (extendInScopeSetList is tvs) ids
    in  (\rhs' -> (p, rhs')) <$>
          trans (TransformContext is' (CaseAlt p : c)) rhs

-- Tick: the annotation is kept, the annotated term is transformed.
allR trans (TransformContext is c) (Tick sp e) =
  Tick sp <$> trans (TransformContext is (TickC sp : c)) e

-- Every other term: nothing to descend into.
allR _ _ tm = pure tm

infixr 6 >->
-- | Apply two transformations in succession.
--
-- Both transformations are run in the same 'TransformContext'; the term
-- produced by the first one is the input of the second one.
(>->) :: Monad m => Transform m -> Transform m -> Transform m
-- NOTE(review): deliberately kept as a lambda with no arguments on the
-- left-hand side; presumably so the INLINE pragma also fires for partial
-- applications -- confirm before eta-expanding.
(>->) = \r1 r2 c -> r1 c >=> r2 c
{-# INLINE (>->) #-}

infixr 6 >-!->
-- | Apply two transformations in succession, and perform a deepseq in between.
--
-- Like '>->', but the result of the first transformation is passed through
-- 'deepseq' before the second transformation is started on it.
(>-!->) :: Monad m => Transform m -> Transform m -> Transform m
-- NOTE(review): deliberately kept as a lambda with no arguments on the
-- left-hand side; presumably so the INLINE pragma also fires for partial
-- applications -- confirm before eta-expanding.
(>-!->) = \r1 r2 c e -> do
  e' <- r1 c e
  deepseq e' (r2 c e')
{-# INLINE (>-!->) #-}

{-
Note [topdown repeatR]
~~~~~~~~~~~~~~~~~~~~~~
In a topdown traversal we need to repeat the transformation r because
if r replaces a parent node with one of its children
we should still apply r to that child, before continuing with its children.

Example: topdownR (inlineBinders (\_ _ -> return True))
on:
> letrec
>   x = 1
> in letrec
>      y = 2
>    in f x y

inlineBinders would inline x and return:
> letrec
>   y = 2
> in f 1 y

Then we must repeat the transformation to let it also inline y.
-}

-- | Apply a transformation in a topdown traversal.
--
-- At every node the transformation is first repeated (with 'repeatR') until
-- it no longer succeeds, and only then are the subterms of the resulting node
-- traversed in the same way.
topdownR :: Rewrite m -> Rewrite m
-- See Note [topdown repeatR]
topdownR r = repeatR r >-> allR (topdownR r)

{-
Note [topdownFixR]
~~~~~~~~~~~~~~~~~~
'topdownFixR r' is an optimized alternative to some uses of
'repeatR (topdownR r)'. It repeats 'r' top-down, but when a child changes it
only rechecks the ancestors of that child instead of restarting traversal from
the root.

For example, suppose 'r' can rewrite both:

> let x = True in x

to:

> True

and:

> case True of { True -> a; False -> b }

to:

> a

When traversing:

> h (case (let x = True in x) of { True -> a; False -> b })

'topdownFixR r' first cannot rewrite the 'case', so it descends into the
scrutinee. Rewriting the scrutinee exposes a new redex at the parent 'case', so
the parent is checked again immediately and rewritten to 'a'. That change then
bubbles up to 'h a'. With 'repeatR (topdownR r)' the same result is reached by
starting another complete traversal from 'h'.

Only use 'topdownFixR' as a replacement for 'repeatR (topdownR r)' when 'r' is
local and context-stable: it should fire or fail based on the current node, and
the relevant parts of 'TransformContext' should not change when sibling
subtrees are rewritten. Rewrites that inspect let-bound context whose binding
terms may have changed, for example through 'whnfRW', still need an outer
repeat or a normal repeated top-down traversal.
-}

-- | Apply a transformation in a repeated top-down traversal.
--
-- Optimized for local, context-stable transformations. See Note [topdownFixR].
topdownFixR :: Rewrite m -> Rewrite m
topdownFixR r = go True
 where
  -- @go tryParent ctx term@ rewrites @term@ and everything below it.
  --
  -- @tryParent@ says whether @r@ still has to be tried on @term@ itself
  -- before descending. It is 'False' only when we loop back right after
  -- having repeated @r@ on this very node, so trying again would be wasted.
  go tryParent ctx term = do
    -- 1. Rewrite the node itself until @r@ no longer succeeds.
    term1 <-
      if tryParent
        then repeatR r ctx term
        else pure term
    -- 2. Rewrite all children, and observe whether any of that reported a
    --    change.
    (term2, childChanged) <- Writer.listen (allR (go True) ctx term1)
    if Monoid.getAny childChanged
      then do
        -- 3. A changed child may have exposed a new redex at this node, so
        --    try @r@ here again.
        (term3, parentChanged) <- Writer.listen (repeatR r ctx term2)
        if Monoid.getAny parentChanged
          -- The node changed once more: its (new) children must be visited
          -- again, but @r@ has just been exhausted on the node itself.
          then go False ctx term3
          else pure term3
      else pure term2
{-# INLINE topdownFixR #-}

-- | Apply a transformation in a bottomup traversal.
--
-- The subterms of a node are traversed first (each of them bottom-up as
-- well); the transformation is applied to the node itself afterwards, once.
bottomupR :: Monad m => Transform m -> Transform m
bottomupR r = allR (bottomupR r) >-> r

infixr 5 !->
-- | Only apply the second transformation if the first one succeeds.
--
-- The first transformation counts as having succeeded when the 'Monoid.Any'
-- it wrote, as observed with 'Writer.listen', is @True@. If it did not
-- succeed, its result is returned as is.
(!->) :: Rewrite m -> Rewrite m -> Rewrite m
-- NOTE(review): deliberately kept as a lambda with no arguments on the
-- left-hand side; presumably so the INLINE pragma also fires for partial
-- applications -- confirm before eta-expanding.
(!->) = \r1 r2 c expr -> do
  (expr', changed) <- Writer.listen (r1 c expr)
  if Monoid.getAny changed
    then r2 c expr'
    else pure expr'
{-# INLINE (!->) #-}

infixr 5 >-!
-- | Only apply the second transformation if the first one fails.
--
-- The first transformation counts as having failed when the 'Monoid.Any' it
-- wrote, as observed with 'Writer.listen', is @False@. If it did succeed, its
-- result is returned as is. In the failure case the second transformation is
-- run on the term returned by the first one.
(>-!) :: Rewrite m -> Rewrite m -> Rewrite m
-- NOTE(review): deliberately kept as a lambda with no arguments on the
-- left-hand side; presumably so the INLINE pragma also fires for partial
-- applications -- confirm before eta-expanding.
(>-!) = \r1 r2 c expr -> do
  (expr', changed) <- Writer.listen (r1 c expr)
  if Monoid.getAny changed
    then pure expr'
    else r2 c expr'
{-# INLINE (>-!) #-}

-- | Keep applying a transformation until it fails.
--
-- The transformation is re-applied to its own result for as long as it
-- succeeds in the sense of '!->'; the result of the first application that
-- does not succeed is returned.
repeatR :: Rewrite m -> Rewrite m
-- NOTE(review): the recursion goes through the local @go@, which calls back
-- into 'repeatR', instead of 'repeatR' taking an argument directly. This
-- looks intentional in combination with the INLINE pragma -- confirm before
-- simplifying it.
repeatR = let go r = r !-> repeatR r in go
{-# INLINE repeatR #-}