{-# LANGUAGE CPP #-}
{-# OPTIONS_GHC -fno-warn-missing-export-lists #-}

-- | Module    :  Debug.DiffExpr
-- Copyright   :  (C) 2023 Alexey Tochin
-- License     :  BSD3 (see the file LICENSE)
-- Maintainer  :  Alexey Tochin <Alexey.Tochin@gmail.com>
--
-- Typeclasses and instances for building named symbolic function expressions
-- that can be differentiated with "Numeric.InfBackprop".
module Debug.DiffExpr where

import Data.Fix (Fix (Fix))
import Debug.SimpleExpr.Expr
  ( SimpleExpr,
    SimpleExprF (SymbolicFuncF),
    unaryFunc,
  )
import Debug.SimpleExpr.Utils.Traced (Traced (MkTraced))
import Debug.Trace (trace)
import NumHask
  ( Additive,
    Distributive,
    Multiplicative,
    (*),
    (+),
  )
import Numeric.InfBackprop (RevDiff (MkRevDiff))
import Prelude (Show, String, show, ($), (<>))

-- | Create a binary function expression.
--
-- @twoArgFunc name x y@ builds a 'SymbolicFuncF' node named @name@ whose
-- argument list is @[x, y]@.
--
-- ==== __Examples__
--
-- >>> import Debug.SimpleExpr (variable)
--
-- >>> twoArgFunc "f" (variable "x") (variable "y")
-- f(x,y)
twoArgFunc :: String -> SimpleExpr -> SimpleExpr -> SimpleExpr
twoArgFunc name x y = Fix (SymbolicFuncF name [x, y])

-- | This typeclass is for creating symbolic unary function expressions.
--
-- It is used in conjunction with automatic differentiation to represent
-- functions symbolically. For the 'RevDiff' instance in this module, the
-- derivative of a function named @f@ is itself represented as a symbolic
-- function named @f'@ (as in the @f'(x)*1@ output below).
--
-- ==== __Examples__
--
-- >>> import Debug.SimpleExpr (variable)
-- >>> import Numeric.InfBackprop (simpleDerivative)
--
-- >>> :{
--  f :: SymbolicFunc a => a -> a
--  f = unarySymbolicFunc "f"
-- :}
--
-- >>> f (variable "x")
-- f(x)
--
-- >>> simpleDerivative f (variable "x")
-- f'(x)*1
class SymbolicFunc a where
  -- | @unarySymbolicFunc name x@ applies the unary symbolic function
  -- called @name@ to @x@.
  unarySymbolicFunc :: String -> a -> a

-- | `SimpleExpr` instance of `SymbolicFunc` typeclass.
--
-- Delegates to 'unaryFunc', so @unarySymbolicFunc "f" x@ is the expression
-- node printed as @f(x)@.
instance SymbolicFunc SimpleExpr where
  unarySymbolicFunc :: String -> SimpleExpr -> SimpleExpr
  unarySymbolicFunc = unaryFunc

-- | `RevDiff` instance of `SymbolicFunc` typeclass.
--
-- Forward pass: applies the symbolic function @funcName@ to the wrapped
-- value. Backward pass (chain rule): the incoming cotangent @cy@ is
-- multiplied by the symbolic derivative @funcName'(x)@ and handed to the
-- argument's backpropagator.
instance
  (SymbolicFunc a, Multiplicative a) =>
  SymbolicFunc (RevDiff t a a)
  where
  unarySymbolicFunc :: String -> RevDiff t a a -> RevDiff t a a
  unarySymbolicFunc funcName (MkRevDiff x bp) =
    MkRevDiff
      (unarySymbolicFunc funcName x)
      (\cy -> bp $ f' * cy)
    where
      -- The derivative is represented symbolically: same name with a
      -- prime appended, applied to the same argument.
      f' :: a
      f' = unarySymbolicFunc (funcName <> "'") x

-- | This typeclass is for creating symbolic binary function expressions.
--
-- It is used in conjunction with automatic differentiation to represent
-- functions symbolically. See `SymbolicFunc` for unary functions. For the
-- 'RevDiff' instance in this module, the partial derivatives of a function
-- named @f@ are represented as symbolic functions named @f'_1@ and @f'_2@.
class BinarySymbolicFunc a where
  -- | @binarySymbolicFunc name x y@ applies the binary symbolic function
  -- called @name@ to @x@ and @y@.
  binarySymbolicFunc :: String -> a -> a -> a

-- | `SimpleExpr` instance of `BinarySymbolicFunc` typeclass.
--
-- Delegates to 'twoArgFunc', so @binarySymbolicFunc "f" x y@ is the
-- expression node printed as @f(x,y)@.
instance BinarySymbolicFunc SimpleExpr where
  binarySymbolicFunc :: String -> SimpleExpr -> SimpleExpr -> SimpleExpr
  binarySymbolicFunc = twoArgFunc

-- | `RevDiff` instance of `BinarySymbolicFunc` typeclass.
--
-- Forward pass: applies the symbolic function @funcName@ to both wrapped
-- values. Backward pass (chain rule): the incoming cotangent @cz@ is
-- multiplied by each symbolic partial derivative, sent through the
-- corresponding argument's backpropagator, and the two results are summed.
instance
  (BinarySymbolicFunc a, Distributive a, Additive t) =>
  BinarySymbolicFunc (RevDiff t a a)
  where
  binarySymbolicFunc :: String -> RevDiff t a a -> RevDiff t a a -> RevDiff t a a
  binarySymbolicFunc funcName (MkRevDiff x bpx) (MkRevDiff y bpy) =
    MkRevDiff
      (binarySymbolicFunc funcName x y)
      (\cz -> bpx (f'1 * cz) + bpy (f'2 * cz))
    where
      -- Partial derivative with respect to the first argument.
      f'1 :: a
      f'1 = binarySymbolicFunc (funcName <> "'_1") x y
      -- Partial derivative with respect to the second argument.
      f'2 :: a
      f'2 = binarySymbolicFunc (funcName <> "'_2") x y

-- | 'SimpleExpr' wrapped in 'Traced', for debugging. The 'SymbolicFunc' and
-- 'BinarySymbolicFunc' instances for 'Traced' in this module log each
-- symbolic function application with 'trace'.
type TracedSimpleExpr = Traced SimpleExpr

-- | Short abbreviation of 'TracedSimpleExpr'.
type TSE = Traced SimpleExpr

-- | `Traced` instance of `SymbolicFunc` typeclass.
--
-- Applies the underlying instance to the unwrapped value and re-wraps the
-- result. A message naming the function and showing its argument is
-- emitted via 'trace' when the result is evaluated.
instance
  (SymbolicFunc a, Show a) =>
  SymbolicFunc (Traced a)
  where
  unarySymbolicFunc :: String -> Traced a -> Traced a
  unarySymbolicFunc name (MkTraced x) =
    trace (" <<< TRACING: Calculating " <> name <> " of " <> show x <> " >>>") $
      MkTraced $
        unarySymbolicFunc name x

-- | `Traced` instance of `BinarySymbolicFunc` typeclass.
--
-- Applies the underlying instance to the unwrapped values and re-wraps the
-- result. A message naming the function and showing both arguments is
-- emitted via 'trace' when the result is evaluated.
instance
  (BinarySymbolicFunc a, Show a) =>
  BinarySymbolicFunc (Traced a)
  where
  binarySymbolicFunc :: String -> Traced a -> Traced a -> Traced a
  binarySymbolicFunc name (MkTraced x) (MkTraced y) =
    trace
      ( " <<< TRACING: Calculating "
          <> name
          <> " of "
          <> show x
          <> " and "
          <> show y
          <> " >>>"
      )
      $ MkTraced
      $ binarySymbolicFunc name x y