{-# LANGUAGE CPP #-}
{-# OPTIONS_GHC -fno-warn-missing-export-lists #-}
module Debug.DiffExpr where
import Data.Fix (Fix (Fix))
import Debug.SimpleExpr.Expr
( SimpleExpr,
SimpleExprF (SymbolicFuncF),
unaryFunc,
)
import Debug.SimpleExpr.Utils.Traced (Traced (MkTraced))
import Debug.Trace (trace)
import NumHask
( Additive,
Distributive,
Multiplicative,
(*),
(+),
)
import Numeric.InfBackprop (RevDiff (MkRevDiff))
import Prelude (Show, String, show, ($), (<>))
-- | Build a 'SimpleExpr' node that applies the symbolic function called
-- @name@ to the two arguments @x@ and @y@, in that order.
--
-- The node is a 'SymbolicFuncF' carrying the function name and the argument
-- list @[x, y]@, tied into the recursive expression type with 'Fix'.
twoArgFunc :: String -> SimpleExpr -> SimpleExpr -> SimpleExpr
twoArgFunc name x y = Fix (SymbolicFuncF name [x, y])
-- | Types that can represent the application of a named, opaque
-- one-argument function.
class SymbolicFunc a where
  -- | @unarySymbolicFunc name x@ applies the function called @name@ to @x@.
  unarySymbolicFunc
    :: String -- ^ name of the function
    -> a      -- ^ argument
    -> a      -- ^ the application
-- | For plain symbolic expressions, a unary function application is built
-- by delegating to 'unaryFunc' from "Debug.SimpleExpr.Expr".
instance SymbolicFunc SimpleExpr where
  unarySymbolicFunc :: String -> SimpleExpr -> SimpleExpr
  unarySymbolicFunc = unaryFunc
-- | Reverse-mode differentiation through an opaque unary function.
--
-- * Forward value: @unarySymbolicFunc funcName x@.
-- * Derivative: represented symbolically as the function named
--   @funcName <> "'"@ applied to @x@ (e.g. @"f"@ becomes @"f'"@).
-- * Backward pass: the incoming cotangent @cy@ is multiplied by that
--   derivative (chain rule) and passed to the wrapped backpropagator @bp@.
instance
  (SymbolicFunc a, Multiplicative a) =>
  SymbolicFunc (RevDiff t a a)
  where
  unarySymbolicFunc :: String -> RevDiff t a a -> RevDiff t a a
  unarySymbolicFunc funcName (MkRevDiff x bp) =
    MkRevDiff
      (unarySymbolicFunc funcName x)
      (\cy -> bp (f' * cy))
    where
      -- Symbolic derivative of the function at @x@; it has type @a@ (inferred,
      -- so no ScopedTypeVariables is needed). It is kept as the left operand
      -- of '*' since 'Multiplicative' is not required to be commutative.
      f' = unarySymbolicFunc (funcName <> "'") x
-- | Types that can represent the application of a named, opaque
-- two-argument function.
class BinarySymbolicFunc a where
  -- | @binarySymbolicFunc name x y@ applies the function called @name@ to
  -- @x@ and @y@.
  binarySymbolicFunc
    :: String -- ^ name of the function
    -> a      -- ^ first argument
    -> a      -- ^ second argument
    -> a      -- ^ the application
-- | For plain symbolic expressions, a binary function application is a
-- 'SymbolicFuncF' node with two arguments, built by 'twoArgFunc'.
instance BinarySymbolicFunc SimpleExpr where
  binarySymbolicFunc :: String -> SimpleExpr -> SimpleExpr -> SimpleExpr
  binarySymbolicFunc = twoArgFunc
-- | Reverse-mode differentiation through an opaque binary function.
--
-- * Forward value: @binarySymbolicFunc funcName x y@.
-- * Partial derivatives: represented symbolically as the functions named
--   @funcName <> "'_1"@ (with respect to the first argument) and
--   @funcName <> "'_2"@ (with respect to the second), both applied to
--   @x@ and @y@.
-- * Backward pass: the incoming cotangent @cz@ is scaled by each partial
--   derivative and sent through the matching backpropagator; the two
--   contributions are summed with '+' on @t@.
instance
  (BinarySymbolicFunc a, Distributive a, Additive t) =>
  BinarySymbolicFunc (RevDiff t a a)
  where
  binarySymbolicFunc :: String -> RevDiff t a a -> RevDiff t a a -> RevDiff t a a
  binarySymbolicFunc funcName (MkRevDiff x bpx) (MkRevDiff y bpy) =
    MkRevDiff
      (binarySymbolicFunc funcName x y)
      (\cz -> bpx (f'1 * cz) + bpy (f'2 * cz))
    where
      -- Symbolic node for a derivative named @funcName@ plus @suffix@,
      -- evaluated at the same arguments as the forward value.
      partial suffix = binarySymbolicFunc (funcName <> suffix) x y
      f'1 = partial "'_1"
      f'2 = partial "'_2"
-- | A 'SimpleExpr' wrapped in 'Traced'. The 'SymbolicFunc' and
-- 'BinarySymbolicFunc' instances for 'Traced' in this module emit a
-- "Debug.Trace" message for each symbolic function application.
type TracedSimpleExpr = Traced SimpleExpr

-- | Short alias for 'TracedSimpleExpr'.
type TSE = Traced SimpleExpr
-- | Wraps a unary symbolic function application with a debug trace.
--
-- The result is @MkTraced (unarySymbolicFunc name x)@, passed through
-- 'trace' with a message naming the function and showing its argument. As
-- with any use of 'trace', the message is emitted when the result is
-- evaluated, not when this function is called.
instance
  (SymbolicFunc a, Show a) =>
  SymbolicFunc (Traced a)
  where
  unarySymbolicFunc :: String -> Traced a -> Traced a
  unarySymbolicFunc name (MkTraced x) =
    trace
      (" <<< TRACING: Calculating " <> name <> " of " <> show x <> " >>>")
      (MkTraced (unarySymbolicFunc name x))
-- | Wraps a binary symbolic function application with a debug trace.
--
-- The result is @MkTraced (binarySymbolicFunc name x y)@, passed through
-- 'trace' with a message naming the function and showing both arguments. As
-- with any use of 'trace', the message is emitted when the result is
-- evaluated, not when this function is called.
instance
  (BinarySymbolicFunc a, Show a) =>
  BinarySymbolicFunc (Traced a)
  where
  binarySymbolicFunc :: String -> Traced a -> Traced a -> Traced a
  binarySymbolicFunc name (MkTraced x) (MkTraced y) =
    trace
      ( " <<< TRACING: Calculating "
          <> name
          <> " of "
          <> show x
          <> " and "
          <> show y
          <> " >>>"
      )
      (MkTraced (binarySymbolicFunc name x y))