{-# LANGUAGE DeriveGeneric #-}
{-# LANGUAGE TypeOperators #-}
{-# OPTIONS_GHC -fno-warn-unused-imports #-}

-- | Module    :  Numeric.InfBackprop.Core
-- Copyright   :  (C) 2025 Alexey Tochin
-- License     :  BSD3 (see the file LICENSE)
-- Maintainer  :  Alexey Tochin <Alexey.Tochin@gmail.com>
--
-- Backpropagation differentiation core types and functions.
module Numeric.InfBackprop.Core
  ( -- * Common

    -- ** Base
    Tangent,
    Dual,
    Cotangent,
    CT,
    RevDiff (MkRevDiff, value, backprop),
    RevDiff',
    DifferentiableFunc,
    initDiff,
    call,
    derivativeOp,
    toLensOps,
    constDiff,
    StopDiff (stopDiff),
    HasConstant (constant),
    simpleDifferentiableFunc,

    -- ** Relation to lens and profunctors
    toLens,
    fromLens,
    fromProfunctors,
    toProfunctors,
    fromVanLaarhoven,
    toVanLaarhoven,

    -- ** Derivative operators
    AutoDifferentiableArgument,
    DerivativeRoot,
    DerivativeCoarg,
    DerivativeArg,
    AutoDifferentiableValue,
    DerivativeValue,
    autoArg,
    autoVal,
    sameTypeDerivative,
    simpleDerivative,
    simpleValueAndDerivative,
    customArgDerivative,
    customValDerivative,
    customArgValDerivative,

    -- * Differentiable functions

    -- ** Basic
    differentiableSum,
    differentiableSub,
    differentiableNegate,
    differentiableMult,
    differentiableDiv,
    differentiableRecip,
    differentiableMultAction,
    differentiableConv,

    -- ** Exponential and logarithmic functions
    differentiablePow,
    differentiableExp,
    differentiableLog,
    differentiableLogBase,
    differentiableSqrt,

    -- ** Trigonometric functions
    differentiableSin,
    differentiableCos,
    differentiableTan,
    differentiableSinh,
    differentiableCosh,
    differentiableTanh,
    differentiableAsin,
    differentiableAcos,
    differentiableAtan,
    differentiableAtan2,
    differentiableAsinh,
    differentiableAcosh,
    differentiableAtanh,

    -- * Differentiable types

    -- ** Scalar
    scalarArg,
    scalarVal,
    scalarArgDerivative,
    scalarValDerivative,

    -- ** Tuple
    mkTupleArg,
    tupleArg,
    tupleArgDerivative,
    tupleDerivativeOverX,
    tupleDerivativeOverY,
    twoArgsDerivative,
    twoArgsDerivativeOverX,
    twoArgsDerivativeOverY,
    mkTupleVal,
    tupleVal,
    tupleValDerivative,

    -- ** Triple
    threeArgsToTriple,
    tripleArg,
    mkTripleArg,
    tripleArgDerivative,
    tripleDerivativeOverX,
    tripleDerivativeOverY,
    tripleDerivativeOverZ,
    threeArgsDerivative,
    derivative3ArgsOverX,
    derivative3ArgsOverY,
    derivative3ArgsOverZ,
    mkTripleVal,
    tripleVal,
    tripleValDerivative,

    -- ** BoxedVector
    boxedVectorArg,
    mkBoxedVectorArg,
    boxedVectorArgDerivative,
    boxedVectorVal,
    mkBoxedVectorVal,
    boxedVectorValDerivative,

    -- ** Stream
    streamArg,
    mkStreamArg,
    streamArgDerivative,
    streamVal,
    mkStreamVal,
    streamValDerivative,

    -- ** FiniteSupportStream
    finiteSupportStreamArg,
    mkFiniteSupportStreamArg,
    finiteSupportStreamArgDerivative,
    finiteSupportStreamVal,
    mkFiniteSupportStreamVal,
    finiteSupportStreamValDerivative,

    -- ** Maybe
    maybeArg,
    mkMaybeArg,
    maybeArgDerivative,
    maybeVal,
    mkMaybeVal,
    maybeValDerivative,
  )
where

import Control.Applicative ((<$>), (<*>))
import Control.Comonad.Identity (Identity (Identity, runIdentity))
import Control.ExtendableMap (ExtandableMap, extendMap)
import qualified Control.Lens as CL
import Control.Monad.ST (runST)
import Data.Bifunctor (first)
import Data.Coerce (coerce)
import Data.Composition ((.:))
import Data.Finite (Finite)
import Data.FiniteSupportStream (FiniteSupportStream (MkFiniteSupportStream, toVector), cons, empty, head, singleton, tail, unsafeMap)
import Data.Function (on)
import Data.Functor.Compose (Compose (Compose, getCompose))
import Data.Functor.Const (Const (Const, getConst))
import Data.Int (Int16, Int32, Int64, Int8)
import Data.List.NonEmpty (xor)
import Data.Primitive (Prim)
import Data.Profunctor (Profunctor (dimap))
import Data.Profunctor.Strong (Costrong (unfirst, unsecond))
import Data.Proxy (Proxy (Proxy))
import Data.Stream (Stream)
import qualified Data.Stream as DS
import Data.Tuple (curry, fst, snd, uncurry)
import Data.Tuple.Extra ((***))
import Data.Type.Equality (type (~))
import qualified Data.Vector as DV
import qualified Data.Vector.Fixed.Boxed as DVFB
import Data.Vector.Fusion.Util (Box (Box, unBox))
import qualified Data.Vector.Generic as DVG
import Data.Vector.Generic.Base
  ( Vector
      ( basicLength,
        basicUnsafeCopy,
        basicUnsafeFreeze,
        basicUnsafeIndexM,
        basicUnsafeSlice,
        basicUnsafeThaw,
        elemseq
      ),
  )
import qualified Data.Vector.Generic.Base as DVGB
import qualified Data.Vector.Generic.Mutable as DVGM
import qualified Data.Vector.Generic.Mutable.Base as DVGBM
import qualified Data.Vector.Generic.Sized as DVGS
import qualified Data.Vector.Generic.Sized.Internal as DVGSI
import qualified Data.Vector.Primitive as DVP
import qualified Data.Vector.Unboxed as DVU
import qualified Data.Vector.Unboxed.Mutable as DVUM
import Data.Word (Word, Word16, Word32, Word64, Word8)
import Debug.SimpleExpr (SimpleExpr, SimpleExprF)
import Debug.SimpleExpr.Expr (SE, number)
import Debug.SimpleExpr.Utils.Algebra
  ( AlgebraicPower ((^^)),
    Convolution ((|*|)),
    IntegerPower,
    MultiplicativeAction ((*|)),
    (^),
  )
import Debug.SimpleExpr.Utils.Traced (Traced (MkTraced))
import Debug.Trace (trace)
import Foreign (oneBits)
import GHC.Base
  ( Applicative,
    Eq ((==)),
    Float,
    Functor,
    Int,
    Maybe (Just, Nothing),
    Ord (compare, max, min, (<), (<=), (>), (>=)),
    Type,
    const,
    flip,
    fmap,
    id,
    pure,
    return,
    undefined,
    ($),
    (++),
    (.),
    (<*>),
  )
import GHC.Generics (C, Generic, type (:.:) (unComp1))
import GHC.Integer (Integer)
import GHC.Natural (Natural)
import qualified GHC.Num as GHCN
import GHC.Real (Integral, fromIntegral, realToFrac, toInteger)
import qualified GHC.Real as GHCR
import GHC.Show (Show (show))
import GHC.TypeLits (KnownChar)
import GHC.TypeNats (KnownNat, Nat)
import GHC.Types (Int)
import NumHask
  ( Additive,
    AdditiveAction,
    Complex,
    Distributive,
    Divisive,
    ExpField,
    Field,
    FromInteger (fromInteger),
    FromIntegral,
    Multiplicative,
    Subtractive,
    TrigField,
    acos,
    acosh,
    asin,
    asinh,
    atan,
    atan2,
    atanh,
    cos,
    cosh,
    exp,
    fromIntegral,
    log,
    logBase,
    negate,
    one,
    pi,
    recip,
    sin,
    sinh,
    sqrt,
    tan,
    tanh,
    two,
    zero,
    (*),
    (**),
    (+),
    (-),
    (/),
  )
import NumHask.Data.Integral (FromInteger)
import Numeric.InfBackprop.Instances.NumHask ()
import Numeric.InfBackprop.Utils.SizedVector (BoxedVector, boxedVectorBasis, boxedVectorSum)
import Numeric.InfBackprop.Utils.Tuple (cross, cross3, curry3, fork, fork3, uncurry3)
import Optics (Lens, Lens', getting, lens, set, simple, view, (%))

-- | Type-level map from a type to the type of its tangent space.
--
-- The scalar types below are their own tangent spaces; for every wrapper and
-- container the tangent space is taken component-wise.
type family Tangent (a :: Type) :: Type

-- Scalar types: each is its own tangent space.
type instance Tangent Float = Float

type instance Tangent GHCN.Integer = GHCN.Integer

type instance Tangent SimpleExpr = SimpleExpr

-- Wrappers: the tangent space is taken of the wrapped type.
type instance Tangent (Traced a) = Traced (Tangent a)

type instance Tangent (Complex a) = Complex (Tangent a)

type instance Tangent (Maybe a) = Maybe (Tangent a)

-- Products: the tangent space is taken component-wise.
type instance Tangent (x, y) = (Tangent x, Tangent y)

type instance Tangent (x, y, z) = (Tangent x, Tangent y, Tangent z)

-- Containers: the tangent space is taken element-wise.
type instance Tangent [a] = [Tangent a]

type instance Tangent (DVFB.Vec n a) = DVFB.Vec n (Tangent a)

type instance Tangent (DVGS.Vector v n a) = DVGS.Vector v n (Tangent a)

type instance Tangent (Stream a) = Stream (Tangent a)

type instance Tangent (FiniteSupportStream a) = FiniteSupportStream (Tangent a)

-- | Type-level map from a type to the type of its dual space.
--
-- The scalar types below are self-dual and most containers are dualised
-- element-wise. The exception is the pair 'Stream' / 'FiniteSupportStream',
-- which are mapped to each other.
type family Dual (x :: Type) :: Type

-- Self-dual scalar types.
type instance Dual Float = Float

type instance Dual GHCN.Integer = GHCN.Integer

type instance Dual SimpleExpr = SimpleExpr

-- Wrappers: dualise the wrapped type.
type instance Dual (SimpleExprF a) = SimpleExprF (Dual a)

type instance Dual (Traced a) = Traced (Dual a)

type instance Dual (Complex a) = Complex (Dual a)

type instance Dual (Maybe a) = Maybe (Dual a)

-- Products: dualise component-wise.
type instance Dual (x, y) = (Dual x, Dual y)

type instance Dual (x, y, z) = (Dual x, Dual y, Dual z)

-- Containers that keep their shape under dualisation.
type instance Dual [a] = [Dual a]

type instance Dual (DVFB.Vec n a) = DVFB.Vec n (Dual a)

type instance Dual (DVGS.Vector v n a) = DVGS.Vector v n (Dual a)

-- 'Stream' and 'FiniteSupportStream' are each other's duals.
type instance Dual (Stream a) = FiniteSupportStream (Dual a)

type instance Dual (FiniteSupportStream a) = Stream (Dual a)

-- | The cotangent space of @x@, i.e. the dual of its tangent space.
-- 'RevDiff'' uses it for the types of its backpropagation functions.
type Cotangent x = Dual (Tangent x)

-- | Short alias for 'Cotangent'.
type CT x = Cotangent x

-- | Base type for differentiable instances with the backpropagation.
--
-- ==== __Examples__
--
-- >>> :{
--  differentiableSin_ :: RevDiff t Float Float -> RevDiff t Float Float
--  differentiableSin_ (MkRevDiff v bp) = MkRevDiff (sin v) (bp . (cos v *))
-- :}
--
-- >>> value $ differentiableSin_ (MkRevDiff 0.0 id)
-- 0.0
--
-- >>> backprop (differentiableSin_ (MkRevDiff 0.0 id)) 1.0
-- 1.0
--
-- === `GHC.Num.Num` typeclass instance
--
-- This instance enables the use of standard numeric operations and literals
-- directly with `RevDiff` values, simplifying the syntax for
-- automatic differentiation computations.
--
-- The instance supports `GHC.Num.Num` operations including the arithmetic
-- operators @(+), (-), (*)@, the sign-related functions `GHC.Num.abs` and
-- `GHC.Num.signum`, and automatic conversion from integer literals via
-- `fromInteger`.
--
-- ==== __Examples__
--
-- >>> import Debug.SimpleExpr (variable, SE, simplify)
-- >>> import GHC.Integer (Integer)
--
-- >>> x = variable "x"
--
-- ===== Using numeric literals in automatic differentiation
--
-- This instance allows `RevDiff` values to be created directly from integer
-- literals, eliminating the need for explicit conversion functions.
--
-- Consider computing the partial derivative:
--
-- \[
--  \left.\frac{\partial}{\partial y} (x \cdot y)\right|_{y=2}
-- \]
--
-- Without the `GHC.Num.Num` instance, we would need explicit conversion:
--
-- >>> simplify $ twoArgsDerivativeOverY (*) x (stopDiff $ number 2) :: SE
-- x
--
-- With the `GHC.Num.Num` instance for `RevDiff`, this simplifies to:
--
-- >>> simplify $ twoArgsDerivativeOverY (*) x (number 2) :: SE
-- x
--
-- And combined with the `GHC.Num.Num` instance for `SE`,
-- we achieve the most concise form:
--
-- >>> simplify $ twoArgsDerivativeOverY (*) x 2
-- x
--
-- This progression shows how the typeclass instances work together to enable
-- increasingly natural mathematical notation.
--
-- ===== Power function differentiation
--
-- The instance enables natural exponentiation syntax with automatic differentiation:
--
-- >>> x ** 3 :: SE
-- x^3
-- >>> simplify $ simpleDerivative (** 3) x :: SE
-- 3*(x^2)
-- >>> simplify $ simpleDerivative (simpleDerivative (** 3)) x :: SE
-- (2*x)*3
--
-- ===== Absolute value and signum functions
--
-- The instance provides symbolic differentiation for absolute value and signum:
--
-- >>> simplify $ simpleDerivative GHCN.abs (variable "x") :: SE
-- sign(x)
--
-- >>> simplify $ simpleDerivative GHCN.signum (variable "x") :: SE
-- 0
--
-- For numeric evaluation, the second derivative of absolute value at a point
-- gives the expected result:
--
-- >>> (simpleDerivative (simpleDerivative GHCN.abs)) (1 :: Float) :: Float
-- 0.0
--
-- Notice that the derivative of the signum function is computed as zero for all arguments, including zero.
--
-- >>> simpleDerivative GHCN.signum (0 :: Float) :: Float
-- 0.0
--
-- >>> simplify $ (simpleDerivative (simpleDerivative GHCN.abs)) (variable "x") :: SE
-- 0
--
-- === `GHCR.Fractional` typeclass instance
--
-- Thanks to this instance, fractional literals such as @1.0@ and @2.0@ can be
-- used directly with `RevDiff` values; see the examples below.
--
-- ==== __Examples__
--
-- >>> import GHC.Float (Float)
-- >>> import Debug.SimpleExpr (variable, SE, simplify)
--
-- >>> f x = 8 / x
-- >>> simpleDerivative f (2.0 :: Float)
-- -2.0
-- >>> simplify $ simpleDerivative f (variable "x") :: SE
-- -((8/x)/x)
data RevDiff a b c = MkRevDiff
  { -- | Value produced by the forward computation.
    value :: c,
    -- | Backpropagation function: maps a gradient of type @b@ at the output
    -- to a gradient of type @a@.
    backprop :: b -> a
  }
  deriving (Generic)

-- | The common specialisation of 'RevDiff' in which the gradients at both
-- ends live in cotangent spaces.
--
-- @RevDiff' a b@ carries a value of type @b@ and backpropagates from @CT b@
-- to @CT a@.
type RevDiff' a b = RevDiff (CT a) (CT b) b

-- Tangent and dual spaces of a 'RevDiff' act on the backpropagation input
-- (second parameter) and on the value (third parameter). The final gradient
-- type (first parameter) is left unchanged.
type instance Tangent (RevDiff t cy y) = RevDiff t (Tangent cy) (Tangent y)

type instance Dual (RevDiff t cy y) = RevDiff t (Dual cy) (Dual y)

-- | Converts a differentiable function into a regular function.
--
-- The argument is seeded with 'initDiff', and only the 'value' of the result
-- is kept; the backpropagation part is discarded.
--
-- ==== __Examples__
--
-- >>> import Debug.SimpleExpr (variable)
-- >>> import Debug.DiffExpr (unarySymbolicFunc)
--
-- >>> :{
--  differentiableCos_ :: RevDiff t Float Float -> RevDiff t Float Float
--  differentiableCos_ (MkRevDiff v bp) = MkRevDiff (cos v) (bp . negate . (sin v *))
-- :}
--
-- >>> call differentiableCos_ 0.0
-- 1.0
--
-- >>> x = variable "x"
-- >>> f = unarySymbolicFunc "f"
-- >>> f x
-- f(x)
--
-- >>> call f x
-- f(x)
call :: (RevDiff' a a -> RevDiff' a b) -> a -> b
call f = value . f . initDiff

-- | Converts a differentiable function into its derivative in the form of a
-- multiplicative operator.
--
-- For a point @x@, @derivativeOp f x@ is the backpropagation function of @f@
-- at @x@: it maps an output cotangent @CT b@ to an input cotangent @CT a@.
--
-- ==== __Examples__
--
-- >>> import Debug.SimpleExpr (variable)
-- >>> import Debug.DiffExpr (unarySymbolicFunc)
--
-- >>> :{
--   differentiableSin_ :: RevDiff t Float Float -> RevDiff t Float Float
--   differentiableSin_ (MkRevDiff v bp) = MkRevDiff (sin v) (bp . (cos v *))
-- :}
--
-- >>> (derivativeOp differentiableSin_ 0.0) 1.0
-- 1.0
--
-- >>> c = variable "c"
-- >>> x = variable "x"
-- >>> f = unarySymbolicFunc "f"
-- >>> f x
-- f(x)
-- >>> (derivativeOp f x) c
-- f'(x)*c
derivativeOp :: (RevDiff' a a -> RevDiff' a b) -> a -> CT b -> CT a
derivativeOp f = backprop . f . initDiff

-- | Converts a function into a pair of its value and backpropagation function,
-- which are the lens get and set functions, respectively.
--
-- The argument is seeded with 'initDiff' before being passed to @f@.
toLensOps :: (RevDiff ca ca a -> RevDiff ca cb b) -> a -> (b, cb -> ca)
toLensOps f x = (y, bp)
  where
    MkRevDiff y bp = f (initDiff x)

-- | Creates a differentiable function from a function and its derivative.
-- This is a convenience function for defining new differentiable operations.
--
-- Given @f@ and its derivative @f'@, the result maps @MkRevDiff x bp@ to a
-- 'RevDiff' whose value is @f x@. Its backpropagation multiplies the incoming
-- gradient by @f' x@ (as the left factor) before passing it on to @bp@.
--
-- ==== __Examples__
--
-- >>> :{
--  differentiableCos_ :: RevDiff t Float Float -> RevDiff t Float Float
--  differentiableCos_ = simpleDifferentiableFunc cos (negate . sin)
-- :}
--
-- >>> call differentiableCos_ 0.0
-- 1.0
--
-- >>> simpleDerivative differentiableCos_ 0.0
-- -0.0
simpleDifferentiableFunc ::
  (Multiplicative b) =>
  (b -> b) ->
  (b -> b) ->
  RevDiff a b b ->
  RevDiff a b b
simpleDifferentiableFunc f f' (MkRevDiff x bpc) =
  -- The derivative is kept as the left factor of the product; the order is
  -- visible for symbolic values, where '*' is rendered as written.
  MkRevDiff (f x) (\cy -> bpc $ f' x * cy)

-- | Initializes a `MkRevDiff` instance with given value
-- and identity backpropagation function.
-- This is useful for starting the backpropagation chain.
--
-- ==== __Examples__
--
-- >>> :{
--   differentiableCos_ :: RevDiff t Float Float -> RevDiff t Float Float
--   differentiableCos_ (MkRevDiff v bp) = MkRevDiff (cos v) (bp . negate . (sin v *))
-- :}
--
-- >>> value $ differentiableCos_ (initDiff 0.0)
-- 1.0
--
-- >>> backprop (differentiableCos_ (initDiff 0.0)) 1.0
-- -0.0
initDiff :: a -> RevDiff b b a
initDiff x = MkRevDiff x id

-- | Converts a differentiable function into a /law-breaking/ 'Lens'.
-- This is mutually inverse with 'fromLens'.
--
-- The getter returns the 'value' of @f@ at the focus. The setter applies the
-- 'backprop' of @f@ at the focus to the supplied cotangent. Both seed the
-- argument with 'initDiff'.
--
-- ==== __Examples__
--
-- >>> import Optics (Lens', lens, view, set, getting, (%))
-- >>> import Debug.SimpleExpr (variable, SE)
--
-- >>> sinLens = toLens sin :: Lens' SE SE
-- >>> x = variable "x"
-- >>> c = variable "c"
-- >>> (view . getting) sinLens x
-- sin(x)
-- >>> set sinLens c x
-- cos(x)*c
-- >>> squareLens = toLens (^2) :: Lens' SE SE
-- >>> (view . getting) (squareLens % sinLens) x
-- sin(x^2)
toLens :: (RevDiff b b a -> RevDiff b d c) -> Lens a b c d
toLens f = lens (value . run) (backprop . run)
  where
    -- Seed the argument with an identity backpropagation and push it through f.
    run = f . initDiff

-- | Converts a /law-breaking/ 'Lens' into a differentiable function.
-- This is mutually inverse with 'toLens'.
--
-- The lens getter supplies the value. The lens setter, applied to the
-- incoming cotangent at the current point, supplies the gradient that is
-- handed on to the upstream backpropagation.
--
-- ==== __Examples__
--
-- >>> import Optics (lens)
-- >>> import Debug.SimpleExpr (variable, SE, simplify)
--
-- >>> sinV2 = fromLens $ lens sin (\x -> (cos x *))
-- >>> x = variable "x"
-- >>> c = variable "c"
-- >>> call sinV2 x
-- sin(x)
-- >>> simplify $ simpleDerivative sinV2 x :: SE
-- cos(x)
fromLens :: Lens a (CT a) b (CT b) -> RevDiff' a a -> RevDiff' a b
fromLens l (MkRevDiff x bp) =
  MkRevDiff ((view . getting) l x) (\cy -> bp $ set l cy x)

-- | Profunctor instance for `RevDiff`.
--
-- 'dimap' post-composes the value with @g@ and pre-composes the
-- backpropagation function with @f@.
instance Profunctor (RevDiff t) where
  dimap :: (a -> b) -> (c -> d) -> RevDiff t b c -> RevDiff t a d
  dimap f g (MkRevDiff v bp) = MkRevDiff (g v) (bp . f)

-- | Costrong instance for `RevDiff`.
--
-- 'unfirst' keeps the first component of the value. Before backpropagation it
-- pairs the incoming gradient with the second component of the value.
-- 'unsecond' is the mirror image.
instance Costrong (RevDiff t) where
  unfirst :: RevDiff t (a, d) (b, d) -> RevDiff t a b
  unfirst (MkRevDiff v bp) = MkRevDiff (fst v) (bp . (,snd v))
  unsecond :: RevDiff t (d, a) (d, b) -> RevDiff t a b
  unsecond (MkRevDiff v bp) = MkRevDiff (snd v) (bp . (fst v,))

-- | @DifferentiableFunc a b@ represents a differentiable function from @a@
-- to @b@.
--
-- By definition, the composition `(.)` of a @DifferentiableFunc b c@ and a
-- @DifferentiableFunc a b@ is a @DifferentiableFunc a c@. The quantification
-- over @t@ lets the same function be used with any final gradient type.
--
-- See `fromProfunctors`, `toProfunctors`, `fromVanLaarhoven` and
-- `toVanLaarhoven` for illustrations of how to use this type.
--
-- ==== __Examples__
--
-- >>> :{
--  differentiableCos_ :: DifferentiableFunc Float Float
--  differentiableCos_ (MkRevDiff x bpc) = MkRevDiff (cos x) (bpc . ((negate $ sin x) *))
-- :}
--
-- >>> call differentiableCos_ 0.0
-- 1.0
--
-- >>> simpleDerivative differentiableCos_ 0.0
-- -0.0
type DifferentiableFunc a b = forall t. RevDiff t (CT a) a -> RevDiff t (CT b) b

-- Profunctor and Van Laarhoven representations.

-- | Transforms a profunctor ('Costrong') map into a 'RevDiff' map.
-- Inverse of 'toProfunctors'.
fromProfunctors ::
  (forall p. (Costrong p) => p (CT a) a -> p (CT b) b) -> DifferentiableFunc a b
-- The argument is simply instantiated at @p = RevDiff t@. The eta-expanded
-- form (rather than @fromProfunctors = id@) keeps this well-typed without
-- relying on deep subsumption.
fromProfunctors f = f

-- | Profunctor representation of a `RevDiff` map, in the spirit of the
-- profunctor encoding of lenses in optics.
-- Inverse of `fromProfunctors`.
--
-- The value of @f@ is computed alongside a copy of the argument. The
-- backpropagation of @f@ is then applied to the pair
-- (argument, output cotangent), and 'unsecond' removes the argument component.
toProfunctors ::
  (Costrong p) =>
  DifferentiableFunc a b ->
  p (CT a) a ->
  p (CT b) b
toProfunctors f = unsecond . dimap (uncurry backward) (fork id forward)
  where
    -- Forward pass: the plain value of f.
    forward = call f
    -- Backward pass at a point: output cotangent to input cotangent.
    backward = derivativeOp f

-- Van Laarhoven representation of the `RevDiff` type.

-- | Converts a Van Laarhoven representation to a function over `RevDiff` types.
-- Inverse of `toVanLaarhoven`.
--
-- The lens is run with the functor @Compose ((,) b) ((->) (CT b))@. This
-- collects the focused value together with the map from its cotangent back to
-- the cotangent of the argument. That map is then composed with the incoming
-- backpropagation.
fromVanLaarhoven ::
  forall a b.
  (forall f. (Functor f) => (b -> f (CT b)) -> a -> f (CT a)) ->
  DifferentiableFunc a b
fromVanLaarhoven vll (MkRevDiff x bpx) = MkRevDiff y (bpx . bp)
  where
    -- Pair the focus with an identity map on its cotangent; vll then fmaps
    -- that map into the argument's cotangent.
    (y, bp) = getCompose $ vll (\y_ -> Compose (y_, id)) x

-- | Converts a function over `RevDiff` types into a Van Laarhoven representation.
-- Inverse of `fromVanLaarhoven`.
--
-- The argument is seeded with 'initDiff'. The focus function is applied to the
-- value of @g@, and the resulting cotangent is mapped back through the
-- backpropagation function of @g@.
toVanLaarhoven ::
  (Functor f) =>
  DifferentiableFunc a b ->
  (b -> f (CT b)) ->
  a ->
  f (CT a)
toVanLaarhoven g f x = fmap bp (f y)
  where
    MkRevDiff y bp = g (initDiff x)

-- -- | Performs backpropagation starting from 'one' and returns the result.
-- -- In particular,
-- -- for constant functions, this will return zero since their derivative is zero.
-- --
-- -- ==== __Examples__
-- --
-- -- >>> diff $ initDiff (42.0 :: Float) :: Float
-- -- 1.0
-- --
-- -- >>> diff (constDiff 42.0 :: RevDiff Float Float Float) :: Float
-- -- 0.0
-- diff :: (Multiplicative b) => RevDiff a b c -> a
-- diff x = backprop x one

-- | Creates a constant differentiable function.
-- The derivative of a constant function is always zero.
--
-- The backpropagation ignores its input and returns 'zero'.
--
-- ==== __Examples__
--
-- >>> value (constDiff 42.0 :: RevDiff' Float Float)
-- 42.0
--
-- >>> backprop (constDiff 42.0 :: RevDiff' Float Float) 1.0
-- 0.0
constDiff :: (Additive a) => c -> RevDiff a b c
constDiff x = MkRevDiff x (const zero)

-- | Derivative for a scalar-to-scalar function.
--
-- The argument is lifted with `initDiff`, the function is applied, and the
-- resulting backpropagator is seeded with `one`.
--
-- ==== __Examples__
--
-- >>> import Debug.SimpleExpr (variable, simplify, SimpleExpr)
-- >>> import Debug.DiffExpr (unarySymbolicFunc)
--
-- >>> simpleDerivative sin (0.0 :: Float)
-- 1.0
--
-- >>> x = variable "x"
--
-- >>> simplify $ simpleDerivative (^ 2) x
-- 2*x
--
-- >>> f = unarySymbolicFunc "f"
--
-- >>> simplify $ simpleDerivative f x :: SimpleExpr
-- f'(x)
simpleDerivative ::
  forall a b.
  (Multiplicative (CT b)) =>
  (RevDiff' a a -> RevDiff' a b) ->
  a ->
  CT a
simpleDerivative f x = backprop (f (initDiff x)) one

-- | Derivative of a function from any type to the same type.
-- The type structure of the input and output values must be the same.
--
-- This is `simpleDerivative` specialized to functions whose input and
-- output types coincide.
--
-- ==== __Examples__
--
-- >>> f = sin :: TrigField a => a -> a
-- >>> f' = sameTypeDerivative f :: Float -> Float
--
-- >>> f' 0.0
-- 1.0
sameTypeDerivative ::
  (Multiplicative (CT a)) =>
  (RevDiff (CT a) (CT a) a -> RevDiff (CT a) (CT a) a) ->
  a ->
  CT a
sameTypeDerivative = simpleDerivative

-- | Returns both the value and the derivative for a scalar-to-scalar function.
--
-- ==== __Examples__
--
-- >>> import Debug.SimpleExpr (variable, simplify, SimpleExpr)
-- >>> import Debug.DiffExpr (unarySymbolicFunc)
--
-- >>> simpleValueAndDerivative sin (0.0 :: Float)
-- (0.0,1.0)
--
-- >>> x = variable "x"
-- >>> f = unarySymbolicFunc "f"
--
-- >>> simplify $ simpleValueAndDerivative f x :: (SimpleExpr, SimpleExpr)
-- (f(x),f'(x))
simpleValueAndDerivative ::
  forall a b.
  (Multiplicative (CT b)) =>
  (RevDiff' a a -> RevDiff' a b) ->
  a ->
  (b, CT a)
simpleValueAndDerivative f x = (value out, backprop out one)
  where
    -- Shared binding: the function is applied once and both the forward
    -- value and the backpropagator are read from the same result.
    out = f (initDiff x)

-- | Derivative of a function from any type to any type.
-- The type structure of the input and output values must be specified
-- in the first and second arguments, respectively.
-- The output type of the derivative is then inferred automatically.
--
-- ==== __Examples__
--
-- >>> :{
--    sphericToVec :: (TrigField a) =>
--      (a, a) -> BoxedVector 3 a
--    sphericToVec (theta, phi) = DVGS.fromTuple (cos theta * cos phi, cos theta * sin phi, sin theta)
-- :}
--
-- >>> sphericToVec' = customArgValDerivative tupleArg boxedVectorVal sphericToVec
--
-- Here 'tupleArg' indicates that the argument type is a tuple.
-- The second term 'boxedVectorVal' specifies that the output value type is a boxed vector.
--
-- >>> sphericToVec' (0 :: Float, 0 :: Float)
-- Vector [(0.0,0.0),(0.0,1.0),(1.0,0.0)]
customArgValDerivative ::
  -- | Argument descriptor, applied to the lifted input.
  (RevDiff (CT a) (CT a) a -> b) ->
  -- | Value descriptor, applied to the function output.
  (c -> d) ->
  -- | Function to differentiate.
  (b -> c) ->
  -- | Point at which the derivative is taken.
  a ->
  d
-- Pipeline: lift the input, shape it with the argument descriptor, apply
-- the function, then extract the derivative with the value descriptor.
customArgValDerivative argTerm valTerm f = valTerm . f . argTerm . initDiff

-- | Auxiliary type for building nested argument structure descriptors:
-- a function from a `RevDiff` to an arbitrary argument structure @d@.
type RevDiffArg a b c d = RevDiff a b c -> d

-- | Typeclass needed for the automatic argument descriptor derivation.
-- See instance implementations for `RevDiff`, tuple and `BoxedVector` below.
--
-- ==== __Examples__
--
-- >>> :{
--  sphericToVector :: (TrigField a) =>
--    (a, a) -> BoxedVector 3 a
--  sphericToVector (theta, phi) =
--    DVGS.fromTuple (cos theta * cos phi, cos theta * sin phi, sin theta)
-- :}
--
-- >>> sphericToVector' = customArgValDerivative autoArg boxedVectorVal sphericToVector
-- >>> sphericToVector' (0 :: Float, 0 :: Float)
-- Vector [(0.0,0.0),(0.0,1.0),(1.0,0.0)]
class
  (Additive (DerivativeRoot a), Additive (DerivativeCoarg a)) =>
  AutoDifferentiableArgument a
  where
  -- | Differentiable function root
  type DerivativeRoot a :: Type

  -- | Differentiable function coargument
  type DerivativeCoarg a :: Type

  -- | Differentiable function argument
  type DerivativeArg a :: Type

  -- | Automatic argument descriptor.
  autoArg :: RevDiff (DerivativeRoot a) (DerivativeCoarg a) (DerivativeArg a) -> a

-- | `AutoDifferentiableArgument` instance for the scalar argument term.
--
-- A scalar argument needs no restructuring, so the descriptor is `id`.
instance
  (Additive a, Additive b) =>
  AutoDifferentiableArgument (RevDiff a b c)
  where
  type DerivativeRoot (RevDiff a b c) = a
  type DerivativeCoarg (RevDiff a b c) = b
  type DerivativeArg (RevDiff a b c) = c
  autoArg = id

-- | Class of output structures from which a derivative can be extracted
-- without passing an explicit value descriptor.
--
-- ==== __Examples__
--
-- >>> :{
--    sphericToVector :: (TrigField a) =>
--      (a, a) -> BoxedVector 3 a
--    sphericToVector (theta, phi) = DVGS.fromTuple (cos theta * cos phi, cos theta * sin phi, sin theta)
-- :}
--
-- >>> sphericToVector' = customArgValDerivative tupleArg autoVal sphericToVector
-- >>> sphericToVector' (0 :: Float, 0 :: Float)
-- Vector [(0.0,0.0),(0.0,1.0),(1.0,0.0)]
class AutoDifferentiableValue v where
  -- | Type of the derivative extracted from an output of type @v@.
  type DerivativeValue v :: Type

  -- | Value descriptor selected automatically from the output type.
  autoVal :: v -> DerivativeValue v

-- | Scalar value term.
--
-- Extracts the derivative from a scalar-valued `RevDiff` by seeding its
-- backpropagator with `one`. The forward value is ignored.
--
-- ==== __Examples__
--
-- >>> :{
--    product :: (Multiplicative a) => (a, a) -> a
--    product (x, y) = x * y
-- :}
--
-- >>> product' = customArgValDerivative tupleArg scalarVal product
--
-- >>> product' (2 :: Float, 3 :: Float)
-- (3.0,2.0)
--
-- >>> import Debug.SimpleExpr (variable, simplify, SimpleExpr)
-- >>> x = variable "x"
-- >>> y = variable "y"
-- >>> simplify $ product' (x, y) :: (SimpleExpr, SimpleExpr)
-- (y,x)
scalarVal ::
  (Multiplicative b) =>
  RevDiff a b c ->
  a
scalarVal (MkRevDiff _ bp) = bp one

-- | `AutoDifferentiableValue` instance for the scalar value term.
--
-- A scalar output is handled by `scalarVal`.
instance
  (Multiplicative b) =>
  AutoDifferentiableValue (RevDiff a b c)
  where
  type DerivativeValue (RevDiff a b c) = a
  autoVal :: RevDiff a b c -> a
  autoVal = scalarVal

-- | Derivative operator for a function with a specified argument type,
-- but with the value type derived automatically (via `autoVal`).
--
-- ==== __Examples__
--
-- >>> :{
--    sphericToVec :: (TrigField a) =>
--      (a, a) -> BoxedVector 3 a
--    sphericToVec (theta, phi) = DVGS.fromTuple (cos theta * cos phi, cos theta * sin phi, sin theta)
-- :}
--
-- >>> sphericToVec' = customArgDerivative tupleArg sphericToVec
--
-- Here 'tupleArg' indicates that the argument type is a tuple.
--
-- >>> sphericToVec' (0 :: Float, 0 :: Float)
-- Vector [(0.0,0.0),(0.0,1.0),(1.0,0.0)]
customArgDerivative ::
  (AutoDifferentiableValue c) =>
  (RevDiff (CT a) (CT a) a -> b) ->
  (b -> c) ->
  a ->
  DerivativeValue c
customArgDerivative arg = customArgValDerivative arg autoVal

-- | Derivative operator for a function with a specified value type,
-- but with the argument type derived automatically (via `autoArg`).
--
-- ==== __Examples__
--
-- >>> :{
--    sphericToVector :: (TrigField a) =>
--      (a, a) -> BoxedVector 3 a
--    sphericToVector (theta, phi) = DVGS.fromTuple (cos theta * cos phi, cos theta * sin phi, sin theta)
-- :}
--
-- >>> sphericToVector' = customValDerivative boxedVectorVal sphericToVector
--
-- The term 'boxedVectorVal' specifies that the output value type is a boxed vector.
--
-- >>> sphericToVector' (0 :: Float, 0 :: Float)
-- Vector [(0.0,0.0),(0.0,1.0),(1.0,0.0)]
customValDerivative ::
  ( DerivativeRoot b ~ CT (DerivativeArg b),
    DerivativeCoarg b ~ CT (DerivativeArg b),
    AutoDifferentiableArgument b
  ) =>
  (c -> d) ->
  (b -> c) ->
  DerivativeArg b ->
  d
customValDerivative = customArgValDerivative autoArg

-- Scalar

-- | Scalar (trivial) argument descriptor for differentiable functions.
-- It passes the lifted argument through unchanged.
--
-- ==== __Examples__
--
-- >>> import Debug.DiffExpr (unarySymbolicFunc, SymbolicFunc)
-- >>> import Debug.SimpleExpr (variable, SimpleExpr, simplify, SE)
--
-- >>> scalarArgDerivative = customArgDerivative scalarArg
--
-- >>> t = variable "t"
-- >>> :{
--   v :: SymbolicFunc  a => a -> BoxedVector 3 a
--   v t = DVGS.fromTuple (
--      unarySymbolicFunc "v_x" t,
--      unarySymbolicFunc "v_y" t,
--      unarySymbolicFunc "v_z" t
--    )
-- :}
--
-- >>> v t
-- Vector [v_x(t),v_y(t),v_z(t)]
--
-- >>> v' = simplify . scalarArgDerivative v :: SE -> BoxedVector 3 SE
-- >>> v' t
-- Vector [v_x'(t),v_y'(t),v_z'(t)]
scalarArg :: RevDiff a b c -> RevDiff a b c
scalarArg = id

-- | Derivative operator for a function from a scalar to any supported value type.
-- The argument is used as is, and the value descriptor is chosen by `autoVal`.
--
-- ==== __Examples__
--
-- >>> import Debug.SimpleExpr (variable, simplify, SE)
--
-- >>> :{
--   f :: TrigField a => a -> (a, a)
--   f t = (cos t, sin t)
-- :}
--
-- >>> f' = scalarArgDerivative f
--
-- >>> f (0 :: Float)
-- (1.0,0.0)
-- >>> f' (0 :: Float)
-- (-0.0,1.0)
--
-- >>> t = variable "t"
-- >>> f t
-- (cos(t),sin(t))
-- >>> simplify $ scalarArgDerivative f t :: (SE, SE)
-- (-(sin(t)),cos(t))
scalarArgDerivative ::
  (AutoDifferentiableValue c) =>
  (RevDiff' a a -> c) ->
  a ->
  DerivativeValue c
scalarArgDerivative = customArgValDerivative id autoVal

-- | Derivative operator for a function
-- from any supported argument type to a scalar value.
-- The argument descriptor is chosen by `autoArg`, and the scalar output is
-- handled by `scalarVal`.
--
-- ==== __Examples__
--
-- >>> import Debug.SimpleExpr (variable, simplify, SE)
--
-- >>> :{
--   f :: Additive a => (a, a) -> a
--   f (x, y) = x + y
-- :}
--
-- >>> f (2 :: Float, 3 :: Float)
-- 5.0
-- >>> x = variable "x"
-- >>> y = variable "y"
-- >>> f (x, y)
-- x+y
--
-- >>> :{
--   f' :: (Additive a, Distributive (CT a)) => (a, a) -> (CT a, CT a)
--   f' = scalarValDerivative f
-- :}
--
-- >>> f' (2 :: Float, 3 :: Float)
-- (1.0,1.0)
-- >>> simplify $ f' (x, y) :: (SE, SE)
-- (1,1)
scalarValDerivative ::
  ( DerivativeRoot b ~ CT a,
    DerivativeCoarg b ~ CT a,
    DerivativeArg b ~ a,
    Multiplicative (CT c),
    AutoDifferentiableArgument b
  ) =>
  (b -> RevDiff d (CT c) c) ->
  a ->
  d
scalarValDerivative = customArgValDerivative autoArg scalarVal

-- RevDiff type instances

-- | `RevDiff` instance for the `Show` typeclass.
--
-- Requires @Show (b -> a)@, i.e. a `Show` instance for the
-- backpropagation function type itself.
instance (Show (b -> a), Show c) => Show (RevDiff a b c) where
  show (MkRevDiff x bp) = "MkRevDiff " ++ show x ++ " " ++ show bp

-- | Typeclass for automatically interrupting backpropagation.
--
-- ==== __Examples__
--
-- >>> :{
--    simpleDerivative
--      (\x -> x * twoArgsDerivativeOverY (+) x (stopDiff (1 :: Float)))
--      (2024 :: Float)
-- :}
-- 1.0
class StopDiff a b where
  -- | Stops differentiation by converting a nested `RevDiff` type
  -- into a non-differentiable type.
  stopDiff :: a -> b

-- | Base case: stopping differentiation for the same type.
-- The value is returned unchanged.
instance StopDiff a a where
  stopDiff = id

-- | Recursive case: stopping differentiation for `RevDiff` type.
--
-- The inner value is converted recursively and then wrapped with
-- `constDiff`, so no gradient flows back through it.
instance
  (StopDiff a d, Additive b) =>
  StopDiff a (RevDiff b c d)
  where
  stopDiff = constDiff . stopDiff

-- | Typeclass for building a constant (non-differentiated) value whose
-- `RevDiff` nesting mirrors that of a reference value.
class HasConstant p x s r where
  -- | @constant proxy x ref@ embeds @x@ as a constant.
  --
  -- With the instances in this module, @ref@ is only inspected for its
  -- nesting of `RevDiff` layers. The recursion stops when the type of
  -- @ref@ equals the type selected by @proxy@.
  constant :: Proxy p -> x -> s -> r

-- | Base case: constant function for the same type.
-- The embedded value is returned as is; the proxy and reference are ignored.
instance HasConstant a b a b where
  constant _ x _ = x

-- | Recursive case: constant function for `RevDiff` type.
--
-- One `RevDiff` layer is peeled off the reference value. The constant is
-- built recursively against the inner value and wrapped with `constDiff`,
-- so it contributes no gradient.
instance
  forall a b c d e f t.
  (HasConstant a b c d, Additive t, e ~ CT c, f ~ CT d) =>
  HasConstant a b (RevDiff t e c) (RevDiff t f d)
  where
  constant _ x (MkRevDiff v _) = constDiff $ constant (Proxy @a) x v

-- | Differentiable version of sum `(+)` for the `RevDiff` type.
--
-- This function implements automatic differentiation for addition by applying
-- the sum rule:
-- \[
--  \frac{d}{dx} (f(x) + g(x)) = \frac{df(x)}{dx} + \frac{dg(x)}{dx}.
-- \]
-- The gradient flows equally to
-- both operands during backpropagation.
differentiableSum ::
  (Additive c) =>
  RevDiff a (b, b) (c, c) ->
  RevDiff a b c
differentiableSum (MkRevDiff (x0, x1) bpc) =
  -- The incoming cotangent is copied unchanged to both operands.
  MkRevDiff (x0 + x1) (\cy -> bpc (cy, cy))

-- | `Additive` instance for the `RevDiff` type.
--
-- `zero` is a constant. For `(+)`, `twoArgsToTuple` pairs the two operands
-- into a single `RevDiff` over a tuple, which `differentiableSum` then consumes.
instance
  (Additive a, Additive c) =>
  Additive (RevDiff a b c)
  where
  zero = constDiff zero
  (+) = differentiableSum .: twoArgsToTuple

-- | Differentiable version of subtraction `(-)` for the `RevDiff` type.
--
-- Implements the difference rule:
-- \[
--  \frac{d}{dx} (f(x) - g(x)) = \frac{df(x)}{dx} - \frac{dg(x)}{dx}.
-- \]
-- During backpropagation, the gradient flows positively to the first operand
-- and negatively to the second operand.
differentiableSub ::
  (Subtractive b, Subtractive c) =>
  RevDiff a (b, b) (c, c) ->
  RevDiff a b c
differentiableSub (MkRevDiff (x0, x1) bpc) =
  MkRevDiff (x0 - x1) (\cy -> bpc (cy, negate cy))

-- | Differentiable version of sign change function `negate` for `RevDiff` type.
--
-- Implements the negation rule:
-- \[
--  \frac{d}{dx} (-f(x)) = -\frac{df(x)}{dx}.
-- \]
-- The gradient is simply
-- negated during backpropagation.
differentiableNegate ::
  (Subtractive a, Subtractive c) =>
  RevDiff a b c ->
  RevDiff a b c
-- The sign is flipped after the inner backpropagator has run.
differentiableNegate (MkRevDiff x bp) = MkRevDiff (negate x) (negate . bp)

-- | `Subtractive` instance for the `RevDiff` type.
--
-- `negate` is `differentiableNegate`. `(-)` pairs the operands with
-- `twoArgsToTuple` and applies `differentiableSub`.
instance
  ( Additive a,
    Subtractive a,
    Subtractive b,
    Subtractive c
  ) =>
  Subtractive (RevDiff a b c)
  where
  negate = differentiableNegate
  (-) = differentiableSub .: twoArgsToTuple

-- | Differentiable version of commutative multiplication `(*)` for the `RevDiff` type.
--
-- Implements the product rule:
-- \[
--  \frac{d}{dx} (f(x) \cdot g(x)) = f(x) \cdot \frac{d g(x)}{dx} + \frac{df(x)}{dx} \cdot g(x).
-- \]
-- Each operand receives the gradient multiplied by the value of the other operand.
differentiableMult ::
  (Multiplicative b) =>
  RevDiff a (b, b) (b, b) ->
  RevDiff a b b
differentiableMult (MkRevDiff (x0, x1) bpc) =
  -- Products are written @other * cy@, which relies on commutativity.
  MkRevDiff (x0 * x1) (\cy -> bpc (x1 * cy, x0 * cy))

-- | `Multiplicative` instance for the `RevDiff` type.
--
-- `one` is a constant. `(*)` pairs the operands with `twoArgsToTuple` and
-- applies `differentiableMult`.
instance
  (Additive a, Multiplicative b) =>
  Multiplicative (RevDiff a b b)
  where
  one = constDiff one
  (*) = differentiableMult .: twoArgsToTuple

-- | Scaling of a `RevDiff` by a plain (non-differentiable) `Integer`.
--
-- The forward value is scaled by @c@. Each incoming cotangent is scaled by
-- the same @c@ before it is passed to the original backpropagator.
instance
  (MultiplicativeAction Integer b, MultiplicativeAction Integer cb) =>
  MultiplicativeAction Integer (RevDiff ct cb b)
  where
  c *| (MkRevDiff x bp) = MkRevDiff (c *| x) (bp . (c *|))

-- | Differentiable version of multiplicative action `(*|)` for the `RevDiff` type.
--
-- Implements the product rule for scalar \( f \)
-- and, for example, vector \( g_i \):
--
-- \[
--  \frac{d}{dx} \left( f(x) \cdot g_i(x) \right) =
--  f(x) \cdot \frac{d g_i(x)}{dx} + \frac{df(x)}{dx} \cdot g_i(x).
-- \]
-- Each operand receives the gradient multiplied by the value of the other operand.
differentiableMultAction ::
  (MultiplicativeAction a b, MultiplicativeAction a cb, Convolution b cb ca) =>
  RevDiff ct (ca, cb) (a, b) ->
  RevDiff ct cb b
differentiableMultAction (MkRevDiff (x, y) bpc) =
  -- The scalar operand gets the contraction @y |*| cz@, and the acted-on
  -- operand gets the scaled cotangent @x *| cz@.
  MkRevDiff (x *| y) (\cz -> bpc (y |*| cz, x *| cz))

-- | Multiplicative action of a differentiable scalar on a differentiable value.
--
-- The operands are paired with `twoArgsToTuple` and passed to
-- `differentiableMultAction`.
instance
  (MultiplicativeAction a b, MultiplicativeAction a cb, Convolution b cb ca, Additive ct) =>
  MultiplicativeAction (RevDiff ct ca a) (RevDiff ct cb b)
  where
  (*|) = differentiableMultAction .: twoArgsToTuple

-- | Differentiable version of convolution `(|*|)` for the `RevDiff` type.
--
-- Implements the product rule for, for example, vectors
-- \( f_i \)
-- and
-- \( g_i \):
--
-- \[
--  \frac{d}{dx} \sum_i f_i(x) \cdot g_i(x) =
--  \sum_i \left( f_i(x) \cdot \frac{d g_i(x)}{dx} + \frac{d f_i(x)}{dx} \cdot g_i(x) \right).
-- \]
-- Each operand receives the gradient multiplied by the value of the other operand.
differentiableConv ::
  (Convolution a b c, Convolution cc b ca, Convolution a cc cb) =>
  RevDiff ct (ca, cb) (a, b) ->
  RevDiff ct cc c
differentiableConv (MkRevDiff (x, y) bpc) =
  MkRevDiff (x |*| y) (\cz -> bpc (cz |*| y, x |*| cz))

-- | Convolution of two differentiable values.
--
-- The operands are paired with `twoArgsToTuple` and passed to
-- `differentiableConv`.
instance
  (Convolution a b c, Convolution cc b ca, Convolution a cc cb, Additive ct) =>
  Convolution (RevDiff ct ca a) (RevDiff ct cb b) (RevDiff ct cc c)
  where
  (|*|) = differentiableConv .: twoArgsToTuple

-- | Differentiable version of division `(/)` for the `RevDiff` type.
--
-- Implements the quotient rule:
-- \[
--  \frac{d}{dx} (f(x)/g(x)) =
--  \frac{\frac{df(x)}{dx} \cdot g(x) - f(x) \cdot \frac{dg(x)}{dx}}{g^2(x)}.
-- \]
-- The numerator receives the gradient divided by the denominator. The
-- denominator receives the negated gradient scaled by the quotient divided by
-- the denominator, i.e. \( -f(x)/g^2(x) \).
differentiableDiv ::
  (Subtractive b, Divisive b) =>
  RevDiff a (b, b) (b, b) ->
  RevDiff a b b
differentiableDiv (MkRevDiff (x0, x1) bpc) =
  MkRevDiff (x0 / x1) (\cy -> bpc (cy / x1, negate $ x0 / x1 / x1 * cy))

-- | Differentiable version of `recip` for `RevDiff` type.
--
-- Implements
-- \[
--  \frac{d}{dx} \frac{1}{f(x)} = -\frac{1}{f^2(x)} \cdot \frac{df(x)}{dx}.
-- \]
-- The gradient is scaled by the negative
-- square of the reciprocal.
differentiableRecip ::
  (Divisive b, Subtractive b, IntegerPower b) =>
  RevDiff a b b ->
  RevDiff a b b
differentiableRecip (MkRevDiff x bpc) = MkRevDiff r (bpc . negate . (r ^ 2 *))
  where
    -- Computed once; reused for both the forward value and the gradient factor.
    r = recip x

-- | `Divisive` instance for the `RevDiff` type.
--
-- `recip` is `differentiableRecip`. `(/)` pairs the operands with
-- `twoArgsToTuple` and applies `differentiableDiv`.
instance
  (Additive a, Divisive b, Subtractive b, IntegerPower b) =>
  Divisive (RevDiff a b b)
  where
  recip = differentiableRecip
  (/) = differentiableDiv .: twoArgsToTuple

-- | Differentiable version of exponentiation `(**)` for the `RevDiff` type.
--
-- Implements
-- \[
--  \frac{d}{dx} f^{g(x)}(x) = f^{g(x)}(x) \cdot (\log f(x) \cdot \frac{dg(x)}{dx} + \frac{g(x)}{f(x)} \cdot \frac{df(x)}{dx}),
-- \]
-- handling both base
-- and exponent dependencies in the gradient computation.
differentiablePow ::
  (ExpField b) =>
  RevDiff a (b, b) (b, b) ->
  RevDiff a b b
differentiablePow (MkRevDiff (x, p) bpc) =
  -- Base gradient uses @p * x ** (p - 1)@ rather than @p * xp / x@, so it
  -- does not divide by the base.
  MkRevDiff xp (\cy -> bpc (p * (x ** (p - one)) * cy, log x * xp * cy))
  where
    xp = x ** p

-- | Differentiable version of `exp` for the `RevDiff` type.
--
-- Implements
-- \[
--  \frac{d}{dx} \exp{f(x)} = \exp{f(x)} \cdot \frac{df(x)}{dx}.
-- \]
-- The exponential function is its own derivative, so the forward value
-- is reused as the gradient factor.
differentiableExp ::
  (ExpField b) =>
  RevDiff a b b ->
  RevDiff a b b
differentiableExp (MkRevDiff x bp) = MkRevDiff y (bp . (y *))
  where
    y = exp x

-- | Differentiable version of natural logarithm for the `RevDiff` type.
--
-- Implements
-- \[
--  \frac{d}{dx} \log f(x) = \frac{1}{f(x)} \cdot \frac{df(x)}{dx}.
-- \]
-- The backward pass divides the incoming cotangent by \( f(x) \). For real
-- numbers this is also the derivative of \( \log |x| \) at every non-zero @x@.
--
-- Unsafety note: neither the logarithm nor the division is defined at zero.
-- What happens there depends on @b@ (e.g. infinities for floating-point
-- types, or an error for types whose division by zero throws).
--
-- ==== __Examples__
--
-- >>> import Debug.SimpleExpr (variable, SE, simplify)
--
-- >>> simplify $ simpleDerivative differentiableLog (variable "x") :: SE
-- 1/x
differentiableLog ::
  (ExpField b) =>
  RevDiff a b b ->
  RevDiff a b b
differentiableLog (MkRevDiff arg backpropArg) =
  MkRevDiff (log arg) (\cy -> backpropArg (cy / arg))

-- | Differentiable version of `logBase` for the `RevDiff` type.
--
-- The argument carries the pair @(base, x)@. The value is
-- \( \log_b x = \log x / \log b \), and both components are differentiated:
-- \[
--  \frac{\partial}{\partial b} \log_b x = -\frac{\log x}{b \, (\log b)^2},
--  \qquad
--  \frac{\partial}{\partial x} \log_b x = \frac{1}{x \log b}.
-- \]
-- Each partial is then scaled by the incoming cotangent (chain rule).
differentiableLogBase ::
  (ExpField b, IntegerPower b) =>
  RevDiff a (b, b) (b, b) ->
  RevDiff a b b
differentiableLogBase (MkRevDiff (base, x) bpc) =
  MkRevDiff ratio backward
  where
    lnX = log x
    lnBase = log base
    ratio = lnX / lnBase
    backward cy =
      bpc
        ( negate (lnX / (lnBase ^ 2) / base * cy),
          recip x / lnBase * cy
        )

-- | Differentiable version of `sqrt` for the `RevDiff` type.
--
-- Implements
-- \[
--  \frac{d}{dx} \sqrt{f(x)} = \frac{1}{2 \sqrt {f(x)}} \cdot \frac{df(x)}{dx}.
-- \]
-- The gradient is scaled by the reciprocal of twice the square root of the
-- input; the square root from the forward pass is reused for this.
differentiableSqrt ::
  (ExpField b) =>
  RevDiff a b b ->
  RevDiff a b b
differentiableSqrt (MkRevDiff x bp) =
  MkRevDiff y (\cy -> bp $ recip (two * y) * cy)
  where
    y = sqrt x

-- | `ExpField` instance for the `RevDiff` type.
--
-- Unary methods delegate directly to their @differentiable*@ counterparts.
-- The binary methods first pack both arguments into a single pair-valued
-- `RevDiff` with `twoArgsToTuple`.
instance
  (ExpField b, Additive a, Subtractive a, IntegerPower b) =>
  ExpField (RevDiff a b b)
  where
  exp = differentiableExp
  log = differentiableLog
  x ** y = differentiablePow (twoArgsToTuple x y)
  logBase base x = differentiableLogBase (twoArgsToTuple base x)
  sqrt = differentiableSqrt

-- | Differentiable version of `atan2` for the `RevDiff` type.
--
-- Computes the two-argument arctangent \( \operatorname{atan2}(y, x) \),
-- i.e. the angle of the point \( (x, y) \). It coincides with
-- \( \arctan(y / x) \) when \( x > 0 \).
--
-- The gradient accounts for both arguments:
-- \[
--  \frac{d}{dx} \operatorname{atan2}(f(x), g(x)) =
--  \frac{g(x)}{f(x)^2+g(x)^2} \cdot \frac{df(x)}{dx}
--  - \frac{f(x)}{f(x)^2+g(x)^2} \cdot \frac{dg(x)}{dx}.
-- \]
differentiableAtan2 ::
  (TrigField b, IntegerPower b) =>
  RevDiff a (b, b) (b, b) ->
  RevDiff a b b
differentiableAtan2 (MkRevDiff (y, x) bpc) =
  MkRevDiff
    (atan2 y x)
    -- First component: partial w.r.t. y; second: partial w.r.t. x.
    (\cy -> bpc (x / r2 * cy, negate $ y / r2 * cy))
  where
    -- Squared distance from the origin, shared by both partials.
    r2 = x ^ 2 + y ^ 2

-- | `AlgebraicPower` instance with an `Int` exponent for the `RevDiff` type.
--
-- Implements
-- \[
--  \frac{d}{dx} f(x)^n = n \cdot f(x)^{n-1} \cdot \frac{df(x)}{dx}.
-- \]
-- For @n = 0@ the derivative is returned directly as zero. The general
-- formula would evaluate @0 * f(x)^(-1)@, which is @NaN@ at @f(x) = 0@ for
-- floating-point types.
instance
  ( AlgebraicPower Int a,
    MultiplicativeAction Int a,
    Multiplicative a
  ) =>
  AlgebraicPower Int (RevDiff c a a)
  where
  x ^^ n = simpleDifferentiableFunc (^^ n) derivative x
    where
      derivative = case n of
        0 -> \_ -> (0 :: Int) *| one
        _ -> \x' -> n *| (x' ^^ (n - 1))

-- (fromIntegral n * integralPow (n - one))

-- | `AlgebraicPower` instance with an `Integer` exponent for the `RevDiff` type.
--
-- Implements
-- \[
--  \frac{d}{dx} f(x)^n = n \cdot f(x)^{n-1} \cdot \frac{df(x)}{dx}.
-- \]
-- For @n = 0@ the derivative is returned directly as zero. The general
-- formula would evaluate @0 * f(x)^(-1)@, which is @NaN@ at @f(x) = 0@ for
-- floating-point types.
instance
  ( AlgebraicPower Integer a,
    MultiplicativeAction Integer a,
    Multiplicative a
  ) =>
  AlgebraicPower Integer (RevDiff c a a)
  where
  x ^^ n = simpleDifferentiableFunc (^^ n) derivative x
    where
      derivative = case n of
        0 -> \_ -> (0 :: Integer) *| one
        _ -> \x' -> n *| (x' ^^ (n - 1))

-- | Differentiable version of sine function for the `RevDiff` type.
--
-- Implements
-- \[
--  \frac{d}{dx} \sin f(x) = \cos f(x) \cdot \frac{df(x)}{dx}
-- \]
-- using the standard trigonometric derivative.
--
-- ==== __Examples__
--
-- >>> call differentiableSin 0.0 :: Float
-- 0.0
-- >>> simpleDerivative differentiableSin 0.0 :: Float
-- 1.0
differentiableSin ::
  (TrigField b) =>
  RevDiff a b b ->
  RevDiff a b b
differentiableSin = simpleDifferentiableFunc sin cos

-- | Differentiable version of cosine function for the `RevDiff` type.
--
-- Implements
-- \[
--  \frac{d}{dx} \cos f(x) = -\sin f(x) \cdot \frac{df(x)}{dx}.
-- \]
--
-- ==== __Examples__
--
-- >>> call differentiableCos 0.0 :: Float
-- 1.0
-- >>> simpleDerivative differentiableCos 0.0 :: Float
-- -0.0
differentiableCos ::
  (TrigField b) =>
  RevDiff a b b ->
  RevDiff a b b
differentiableCos = simpleDifferentiableFunc cos (\t -> negate (sin t))

-- | Differentiable version of tangent function for the `RevDiff` type.
--
-- Implements
-- \[
--  \frac{d}{dx} \tan f(x) =
--  \sec^2 f(x) \cdot \frac{df(x)}{dx} = \frac{1}{\cos^2 f(x)} \cdot \frac{df(x)}{dx}.
-- \]
--
-- ==== __Examples__
--
-- >>> call differentiableTan 0.0 :: Float
-- 0.0
-- >>> simpleDerivative differentiableTan 0.0 :: Float
-- 1.0
differentiableTan ::
  (TrigField b, IntegerPower b) =>
  RevDiff a b b ->
  RevDiff a b b
differentiableTan = simpleDifferentiableFunc tan ((^ (-2)) . cos)

-- | Differentiable version of arcsine function for the `RevDiff` type.
--
-- Implements
-- \[
--  \frac{d}{dx} \arcsin f(x) = \frac{1}{\sqrt{1-f^2(x)}} \cdot \frac{df(x)}{dx}.
-- \]
--
-- ==== __Examples__
--
-- >>> call differentiableAsin 0.0 :: Float
-- 0.0
-- >>> simpleDerivative differentiableAsin 0.0 :: Float
-- 1.0
differentiableAsin ::
  (TrigField b, ExpField b, IntegerPower b) =>
  RevDiff a b b ->
  RevDiff a b b
differentiableAsin =
  simpleDifferentiableFunc asin (\t -> recip (sqrt (one - (t ^ 2))))

-- | Differentiable version of arccosine function for the `RevDiff` type.
--
-- Implements
-- \[
--  \frac{d}{dx} \arccos f(x) = -\frac{1}{\sqrt{1-f^2(x)}} \cdot \frac{df(x)}{dx},
-- \]
-- i.e. the negated arcsine derivative.
--
-- ==== __Examples__
--
-- >>> call differentiableAcos 0.0 :: Float
-- 1.5707964
-- >>> simpleDerivative differentiableAcos 0.0 :: Float
-- -1.0
differentiableAcos ::
  (TrigField b, ExpField b, IntegerPower b) =>
  RevDiff a b b ->
  RevDiff a b b
differentiableAcos =
  simpleDifferentiableFunc acos (\t -> negate (recip (sqrt (one - (t ^ 2)))))

-- | Differentiable version of arctangent function for the `RevDiff` type.
--
-- Implements
-- \[
--  \frac{d}{dx} \arctan f(x) = \frac{1}{1 + f^2(x)} \cdot \frac{df(x)}{dx}.
-- \]
--
-- ==== __Examples__
--
-- >>> call differentiableAtan 0.0 :: Float
-- 0.0
-- >>> simpleDerivative differentiableAtan 0.0 :: Float
-- 1.0
differentiableAtan ::
  (TrigField b, IntegerPower b) =>
  RevDiff a b b ->
  RevDiff a b b
differentiableAtan =
  simpleDifferentiableFunc atan (\t -> recip (one + (t ^ 2)))

-- | Differentiable version of hyperbolic sine function for the `RevDiff` type.
--
-- Implements
-- \[
--  \frac{d}{dx} \sinh f(x) = \cosh f(x) \cdot \frac{df(x)}{dx}.
-- \]
--
-- ==== __Examples__
--
-- >>> call differentiableSinh 0.0 :: Float
-- 0.0
-- >>> simpleDerivative differentiableSinh 0.0 :: Float
-- 1.0
differentiableSinh ::
  (TrigField b) =>
  RevDiff a b b ->
  RevDiff a b b
differentiableSinh = simpleDifferentiableFunc sinh sinhDerivative
  where
    -- d/dt sinh t = cosh t
    sinhDerivative = cosh

-- | Differentiable version of hyperbolic cosine function for the `RevDiff` type.
--
-- Implements
-- \[
--  \frac{d}{dx} \cosh f(x) = \sinh f(x) \cdot \frac{df(x)}{dx}.
-- \]
--
-- ==== __Examples__
--
-- >>> call differentiableCosh 0.0 :: Float
-- 1.0
-- >>> simpleDerivative differentiableCosh 0.0 :: Float
-- 0.0
differentiableCosh ::
  (TrigField b) =>
  RevDiff a b b ->
  RevDiff a b b
differentiableCosh = simpleDifferentiableFunc cosh sinh

-- | Differentiable version of hyperbolic tangent function for the `RevDiff` type.
--
-- Implements
-- \[
--  \frac{d}{dx} \tanh f(x) =
--  \operatorname{sech}^2 f(x) \cdot \frac{df}{dx} = \frac{1}{\cosh^2 f(x)} \cdot \frac{df}{dx}.
-- \]
--
-- ==== __Examples__
--
-- >>> call differentiableTanh 0.0 :: Float
-- 0.0
-- >>> simpleDerivative differentiableTanh 0.0 :: Float
-- 1.0
differentiableTanh ::
  (TrigField b, IntegerPower b) =>
  RevDiff a b b ->
  RevDiff a b b
differentiableTanh =
  simpleDifferentiableFunc tanh (\t -> cosh t ^ (-2))

-- | Differentiable version of inverse hyperbolic sine function for the `RevDiff` type.
--
-- Implements
-- \[
--  \frac{d}{dx} \operatorname{arsinh} f(x) = \frac{1}{\sqrt{1 + f^2 (x)}} \cdot \frac{df}{dx}.
-- \]
--
-- ==== __Examples__
--
-- >>> call differentiableAsinh 0.0 :: Float
-- 0.0
-- >>> simpleDerivative differentiableAsinh 0.0 :: Float
-- 1.0
differentiableAsinh ::
  (TrigField b, ExpField b, IntegerPower b) =>
  RevDiff a b b ->
  RevDiff a b b
differentiableAsinh =
  simpleDifferentiableFunc asinh (\t -> recip (sqrt (one + (t ^ 2))))

-- | Differentiable version of inverse hyperbolic cosine function for the `RevDiff` type.
--
-- Implements
-- \[
--  \frac{d}{dx} \operatorname{arcosh} f(x) = \frac{1}{\sqrt{f^2(x) - 1}} \cdot \frac{df}{dx},
-- \]
-- which is real-valued for \( f(x) > 1 \), the interior of the domain of
-- \( \operatorname{arcosh} \).
differentiableAcosh ::
  (TrigField b, ExpField b, IntegerPower b) =>
  RevDiff a b b ->
  RevDiff a b b
differentiableAcosh =
  -- Note the order @t^2 - 1@: using @1 - t^2@ would take the square root of a
  -- negative number everywhere on the domain.
  simpleDifferentiableFunc acosh (\t -> recip (sqrt ((t ^ 2) - one)))

-- | Differentiable version of inverse hyperbolic tangent function for the `RevDiff` type.
--
-- Implements
-- \[
--  \frac{d}{dx} \operatorname{artanh} f(x) = \frac{1}{1 - f^2 (x)} \cdot \frac{df}{dx}.
-- \]
--
-- ==== __Examples__
--
-- >>> call differentiableAtanh 0.0 :: Float
-- 0.0
-- >>> simpleDerivative differentiableAtanh 0.0 :: Float
-- 1.0
differentiableAtanh ::
  (TrigField b, IntegerPower b) =>
  RevDiff a b b ->
  RevDiff a b b
differentiableAtanh =
  simpleDifferentiableFunc atanh (recip . (one -) . (^ 2))

-- | `TrigField` instance for the `RevDiff` type.
--
-- Every method delegates to the matching @differentiable*@ function. The
-- two-argument `atan2` first packs its arguments with `twoArgsToTuple`.
instance
  (Additive a, Subtractive a, ExpField b, TrigField b, IntegerPower b) =>
  TrigField (RevDiff a b b)
  where
  -- The constant pi of the underlying type, lifted with constDiff.
  pi = constDiff pi

  -- Circular functions.
  sin = differentiableSin
  cos = differentiableCos
  tan = differentiableTan

  -- Inverse circular functions.
  asin = differentiableAsin
  acos = differentiableAcos
  atan = differentiableAtan
  atan2 y x = differentiableAtan2 (twoArgsToTuple y x)

  -- Hyperbolic functions.
  sinh = differentiableSinh
  cosh = differentiableCosh
  tanh = differentiableTanh

  -- Inverse hyperbolic functions.
  asinh = differentiableAsinh
  acosh = differentiableAcosh
  atanh = differentiableAtanh

-- | Differentiable version of absolute value function for the `RevDiff` type.
--
-- Implements
-- \[
--  \frac{d}{dx} \left| f(x) \right| = \operatorname{sign}(f(x)) \cdot \frac{df}{dx},
-- \]
-- where \( \operatorname{sign} \) is the signum function (`GHC.Num.signum`).
-- The derivative is undefined at zero; this implementation uses
-- @signum 0@ there, which is zero (see the examples below).
--
-- ==== __Examples__
--
-- >>> import Debug.SimpleExpr (variable, SE, simplify)
--
-- >>> simplify $ simpleDerivative differentiableAbs (variable "x") :: SE
-- sign(x)
--
-- >>> simpleDerivative differentiableAbs (10 :: Float) :: Float
-- 1.0
-- >>> simpleDerivative differentiableAbs (-10 :: Float) :: Float
-- -1.0
-- >>> simpleDerivative differentiableAbs (0 :: Float) :: Float
-- 0.0
differentiableAbs ::
  (GHCN.Num b, Multiplicative b) =>
  RevDiff a b b ->
  RevDiff a b b
differentiableAbs (MkRevDiff x bpc) =
  MkRevDiff (GHCN.abs x) (bpc . (GHCN.signum x *))

-- | Differentiable version of signum function for the `RevDiff` type.
--
-- Away from zero the signum function is locally constant, so its derivative
-- is zero; at zero it is undefined. This implementation wraps the signum of
-- the input with `constDiff` and so reports a zero derivative everywhere,
-- including at zero.
--
-- ==== __Examples__
--
-- >>> import Debug.SimpleExpr (variable, SE, simplify)
--
-- >>> simplify $ simpleDerivative differentiableSign (variable "x") :: SE
-- 0
--
-- >>> simpleDerivative differentiableSign (10 :: Float) :: Float
-- 0.0
-- >>> simpleDerivative differentiableSign (-10 :: Float) :: Float
-- 0.0
-- >>> simpleDerivative differentiableSign (0 :: Float) :: Float
-- 0.0
differentiableSign ::
  (Additive a, GHCN.Num b) =>
  RevDiff a b b ->
  RevDiff a b b
differentiableSign arg = case arg of
  -- The incoming backpropagator is discarded: the result is a constant.
  MkRevDiff x _ -> constDiff (GHCN.signum x)

-- | `RevDiff` instance for the `GHC.Num.Num` typeclass.
--
-- Addition, subtraction and multiplication are delegated to the numhask
-- `Additive`, `Subtractive` and `Multiplicative` instances of `RevDiff`.
-- Expressions built with base-library operators therefore propagate
-- gradients like those built with the numhask ones.
instance
  ( Additive a,
    Subtractive a,
    GHCN.Num b,
    Subtractive b,
    Multiplicative b
  ) =>
  GHCN.Num (RevDiff a b b)
  where
  -- The right-hand sides use the unqualified numhask operators, not
  -- GHCN.(+) / GHCN.(-) / GHCN.(*): at type RevDiff a b b the latter
  -- resolve to this very instance and would loop forever.
  (+) = (+)
  (-) = (-)
  (*) = (*)
  negate = differentiableNegate
  abs = differentiableAbs
  signum = differentiableSign
  fromInteger = constDiff . GHCN.fromInteger

-- | `RevDiff` instance of the `NumHask.Data.Integral.FromInteger` typeclass.
--
-- The integer is first converted to the value type @c@ with numhask's
-- `fromInteger` and then lifted as a constant with `constDiff`.
instance
  (FromInteger c, Additive a) =>
  FromInteger (RevDiff a b c)
  where
  -- The inner fromInteger runs at type c, not RevDiff a b c.
  fromInteger n = constDiff (fromInteger n)

-- | `RevDiff` and `Int8` instance of the `NumHask.Data.Integral.FromIntegral` typeclass.
--
-- The number is converted with `NumHask.fromIntegral` and lifted as a
-- constant via `constDiff`.
instance
  (FromIntegral c Int8, Additive a) =>
  FromIntegral (RevDiff a b c) Int8
  where
  fromIntegral n = constDiff (NumHask.fromIntegral n)

-- | `RevDiff` and `Int16` instance
-- of the `NumHask.Data.Integral.FromIntegral` typeclass.
--
-- The number is converted with `NumHask.fromIntegral` and lifted as a
-- constant via `constDiff`.
instance
  (FromIntegral c Int16, Additive a) =>
  FromIntegral (RevDiff a b c) Int16
  where
  fromIntegral n = constDiff (NumHask.fromIntegral n)

-- | `RevDiff` and `Int32` instance
-- of the `NumHask.Data.Integral.FromIntegral` typeclass.
--
-- The number is converted with `NumHask.fromIntegral` and lifted as a
-- constant via `constDiff`.
instance
  (FromIntegral c Int32, Additive a) =>
  FromIntegral (RevDiff a b c) Int32
  where
  fromIntegral n = constDiff (NumHask.fromIntegral n)

-- | `RevDiff` and `Int64` instance
-- of the `NumHask.Data.Integral.FromIntegral` typeclass.
--
-- The number is converted with `NumHask.fromIntegral` and lifted as a
-- constant via `constDiff`.
instance
  (FromIntegral c Int64, Additive a) =>
  FromIntegral (RevDiff a b c) Int64
  where
  fromIntegral n = constDiff (NumHask.fromIntegral n)

-- | `RevDiff` and `Int` instance
-- of the `NumHask.Data.Integral.FromIntegral` typeclass.
--
-- The number is converted with `NumHask.fromIntegral` and lifted as a
-- constant via `constDiff`.
instance
  (FromIntegral c Int, Additive a) =>
  FromIntegral (RevDiff a b c) Int
  where
  fromIntegral n = constDiff (NumHask.fromIntegral n)

-- | `RevDiff` and `Word8` instance
-- of the `NumHask.Data.Integral.FromIntegral` typeclass.
--
-- The number is converted with `NumHask.fromIntegral` and lifted as a
-- constant via `constDiff`.
instance
  (FromIntegral c Word8, Additive a) =>
  FromIntegral (RevDiff a b c) Word8
  where
  fromIntegral n = constDiff (NumHask.fromIntegral n)

-- | `RevDiff` and `Word16` instance
-- of the `NumHask.Data.Integral.FromIntegral` typeclass.
--
-- The number is converted with `NumHask.fromIntegral` and lifted as a
-- constant via `constDiff`.
instance
  (FromIntegral c Word16, Additive a) =>
  FromIntegral (RevDiff a b c) Word16
  where
  fromIntegral n = constDiff (NumHask.fromIntegral n)

-- | `RevDiff` and `Word32` instance
-- of the `NumHask.Data.Integral.FromIntegral` typeclass.
instance
  (FromIntegral c Word32, Additive a) =>
  FromIntegral (RevDiff a b c) Word32
  where
  -- The converted number does not depend on the differentiation variable,
  -- so it is lifted with 'constDiff' and contributes no gradient.
  fromIntegral n = constDiff (NumHask.fromIntegral n)

-- | `RevDiff` and `Word64` instance
-- of the `NumHask.Data.Integral.FromIntegral` typeclass.
instance
  (FromIntegral c Word64, Additive a) =>
  FromIntegral (RevDiff a b c) Word64
  where
  -- The converted number does not depend on the differentiation variable,
  -- so it is lifted with 'constDiff' and contributes no gradient.
  fromIntegral n = constDiff (NumHask.fromIntegral n)

-- | `RevDiff` and `Word` instance
-- of the `NumHask.Data.Integral.FromIntegral` typeclass.
instance
  (FromIntegral c Word, Additive a) =>
  FromIntegral (RevDiff a b c) Word
  where
  -- The converted number does not depend on the differentiation variable,
  -- so it is lifted with 'constDiff' and contributes no gradient.
  fromIntegral n = constDiff (NumHask.fromIntegral n)

-- | `RevDiff` and `Integer` instance
-- of the `NumHask.Data.Integral.FromIntegral` typeclass.
instance
  (FromIntegral c Integer, Additive a) =>
  FromIntegral (RevDiff a b c) Integer
  where
  -- The converted number does not depend on the differentiation variable,
  -- so it is lifted with 'constDiff' and contributes no gradient.
  fromIntegral n = constDiff (NumHask.fromIntegral n)

-- | `RevDiff` and `Natural` instance
-- of the `NumHask.Data.Integral.FromIntegral` typeclass.
instance
  (FromIntegral c Natural, Additive a) =>
  FromIntegral (RevDiff a b c) Natural
  where
  -- The converted number does not depend on the differentiation variable,
  -- so it is lifted with 'constDiff' and contributes no gradient.
  fromIntegral n = constDiff (NumHask.fromIntegral n)

-- | `RevDiff` instance for the `GHC.Real.Fractional` typeclass.
instance
  ( Additive a,
    Subtractive a,
    Subtractive b,
    Divisive b,
    GHCR.Fractional b,
    IntegerPower b
  ) =>
  GHCR.Fractional (RevDiff a b b)
  where
  -- Numerator and denominator are packed into one tuple-valued 'RevDiff'
  -- with 'twoArgsToTuple' and then fed to the differentiable division.
  num / den = differentiableDiv (twoArgsToTuple num den)

  -- Delegates to the dedicated differentiable reciprocal.
  recip = differentiableRecip

  -- Rational literals are lifted as constants with 'constDiff'.
  fromRational r = constDiff (GHCR.fromRational r)

-- | Combines two `RevDiff` instances into a single `RevDiff` instance over a tuple.
-- The inverse operation is 'tupleArg'.
twoArgsToTuple ::
  (Additive a) =>
  RevDiff a b0 c0 ->
  RevDiff a b1 c1 ->
  RevDiff a (b0, b1) (c0, c1)
-- Forward values are paired. On the backward pass the incoming tuple
-- cotangent is split component-wise, each half goes through its own
-- backprop, and the two contributions are summed.
twoArgsToTuple (MkRevDiff v0 bp0) (MkRevDiff v1 bp1) =
  MkRevDiff (v0, v1) pairBackprop
  where
    pairBackprop (ct0, ct1) = bp0 ct0 + bp1 ct1

-- | Tuple argument descriptor for differentiable functions.
-- Transforms a `RevDiff` instance of a tuple into a tuple of `RevDiff` instances.
-- This allows applying differentiable operations to both elements of the tuple.
--
-- ==== __Examples__
--
-- >>> import Debug.SimpleExpr (variable, SE, simplify)
-- >>> import Debug.DiffExpr (SymbolicFunc, unarySymbolicFunc)
--
-- >>> :{
--   f :: Multiplicative a => (a, a) -> a
--   f (x, y) = x * y
-- :}
--
-- >>> :{
--   f' :: (Distributive a, CT a ~ a) => (a, a) -> (a, a)
--   f' = customArgDerivative tupleArg f
-- :}
--
-- >>> simplify $ f' (variable "x", variable "y")
-- (y,x)
tupleArg ::
  (Additive b0, Additive b1) =>
  RevDiff a (b0, b1) (c0, c1) ->
  (RevDiff a b0 c0, RevDiff a b1 c1)
-- Each component keeps its own forward value. Its backprop passes the
-- component cotangent into the tuple backprop and fills the other slot
-- with 'zero'.
tupleArg (MkRevDiff (v0, v1) bp) = (fstDiff, sndDiff)
  where
    fstDiff = MkRevDiff v0 (\ct -> bp (ct, zero))
    sndDiff = MkRevDiff v1 (\ct -> bp (zero, ct))

-- | Tuple argument descriptor builder.
-- See [this tutorial section]
-- (Numeric-InfBackprop-Tutorial.html#g:sophisticated-45-argument-45-function-45-how-45-it-45-works)
-- for details and examples.
mkTupleArg ::
  (Additive b0, Additive b1) =>
  RevDiffArg a b0 c0 d0 ->
  RevDiffArg a b1 c1 d1 ->
  RevDiffArg a (b0, b1) (c0, c1) (d0, d1)
-- Split the tuple-valued 'RevDiff' with 'tupleArg', then apply each
-- component descriptor to its own half.
mkTupleArg f0 f1 pairDiff = cross f0 f1 (tupleArg pairDiff)

-- | Tuple instance for `AutoDifferentiableArgument` typeclass.
-- It makes it possible to differentiate functions of a tuple argument.
instance
  ( AutoDifferentiableArgument a0,
    AutoDifferentiableArgument a1,
    DerivativeRoot a0 ~ DerivativeRoot a1
  ) =>
  AutoDifferentiableArgument (a0, a1)
  where
  type DerivativeRoot (a0, a1) = DerivativeRoot a0
  type DerivativeCoarg (a0, a1) = (DerivativeCoarg a0, DerivativeCoarg a1)
  type DerivativeArg (a0, a1) = (DerivativeArg a0, DerivativeArg a1)
  autoArg ::
    RevDiff
      (DerivativeRoot a0)
      (DerivativeCoarg a0, DerivativeCoarg a1)
      (DerivativeArg a0, DerivativeArg a1) ->
    (a0, a1)
  -- Both components share one root type, as required by the context of this
  -- instance. The pair is split with 'mkTupleArg' and each half is handled
  -- by its own 'autoArg'.
  autoArg pairDiff = mkTupleArg autoArg autoArg pairDiff

-- | Tuple differentiable value builder.
-- See [this tutorial section]
-- (Numeric-InfBackprop-Tutorial.html#g:multivalued-45-function-45-how-45-it-45-works)
-- for details and examples.
mkTupleVal :: (a0 -> b0) -> (a1 -> b1) -> (a0, a1) -> (b0, b1)
-- Applies the two value descriptors component-wise via 'cross'.
mkTupleVal f0 f1 = cross f0 f1

-- | Tuple differentiable value descriptor.
-- See [this tutorial section]
-- (Numeric-InfBackprop-Tutorial.html#g:multivalued-45-function-45-how-45-it-45-works)
-- for details and examples.
tupleVal ::
  (Multiplicative b0, Multiplicative b1) =>
  (RevDiff a0 b0 c0, RevDiff a1 b1 c1) ->
  (a0, a1)
-- Each half is handled independently by 'scalarVal'; the 'Multiplicative'
-- constraints are the ones 'scalarVal' requires.
tupleVal pair = mkTupleVal scalarVal scalarVal pair

-- | Tuple instance for `AutoDifferentiableValue` typeclass.
-- It makes it possible to differentiate tuple-valued functions.
instance
  (AutoDifferentiableValue a0, AutoDifferentiableValue a1) =>
  AutoDifferentiableValue (a0, a1)
  where
  type DerivativeValue (a0, a1) = (DerivativeValue a0, DerivativeValue a1)
  autoVal :: (a0, a1) -> (DerivativeValue a0, DerivativeValue a1)
  -- Each component is converted by its own 'autoVal'.
  autoVal pair = mkTupleVal autoVal autoVal pair

-- | Differentiable operator for functions with a tuple argument
-- and any value type supported by `AutoDifferentiableValue`.
-- This function is equivalent to 'twoArgsDerivative' up to currying.
--
-- ==== __Examples__
--
-- >>> import Debug.SimpleExpr (variable, simplify, SE)
-- >>> import Debug.DiffExpr (SymbolicFunc, unarySymbolicFunc)
--
-- >>> :{
--   x = variable "x"
--   y = variable "y"
--   f :: SymbolicFunc a => a -> a
--   f = unarySymbolicFunc "f"
--   g :: SymbolicFunc a => a -> a
--   g = unarySymbolicFunc "g"
--   h :: (SymbolicFunc a, Multiplicative a) => (a, a) -> a
--   h (x, y) = f x * g y
-- :}
--
-- >>> f(x)*g(y)
-- f(x)*g(y)
--
-- >>> :{
--  h' :: (SE, SE) -> (SE, SE)
--  h' = simplify . tupleArgDerivative h
-- :}
--
-- >>> h' (x, y)
-- (f'(x)*g(y),g'(y)*f(x))
--
-- >>> :{
--  h'' :: (SE, SE) -> ((SE, SE), (SE, SE))
--  h'' = simplify . tupleArgDerivative (tupleArgDerivative h)
-- :}
--
-- >>> h'' (x, y)
-- ((f''(x)*g(y),g'(y)*f'(x)),(f'(x)*g'(y),g''(y)*f(x)))
tupleArgDerivative ::
  (Additive (CT a0), Additive (CT a1), AutoDifferentiableValue b) =>
  ((RevDiff' (a0, a1) a0, RevDiff' (a0, a1) a1) -> b) ->
  (a0, a1) ->
  DerivativeValue b
-- The tuple argument is described by 'tupleArg', which splits the
-- tuple-valued 'RevDiff' into one 'RevDiff' per component.
tupleArgDerivative f = customArgDerivative tupleArg f

-- | Differentiable operator for functions over tuple argument
-- with respect to the first argument.
--
-- ==== __Examples__
--
-- >>> import Debug.SimpleExpr (variable, simplify, SE)
-- >>> import Debug.DiffExpr (SymbolicFunc, unarySymbolicFunc)
--
-- >>> :{
--   x = variable "x"
--   y = variable "y"
--   f :: SymbolicFunc a => a -> a
--   f = unarySymbolicFunc "f"
--   g :: SymbolicFunc a => a -> a
--   g = unarySymbolicFunc "g"
--   h :: (SymbolicFunc a, Multiplicative a) => (a, a) -> a
--   h (x, y) = f x * g y
-- :}
--
-- >>> f(x)*g(y)
-- f(x)*g(y)
--
-- >>> :{
--  h' :: (SE, SE) -> SE
--  h' = simplify . tupleDerivativeOverX h
-- :}
--
-- >>> h' (x, y)
-- f'(x)*g(y)
--
-- >>> :{
--  h'' :: (SE, SE) -> SE
--  h'' = simplify . tupleDerivativeOverX (tupleDerivativeOverX h)
-- :}
--
-- >>> h'' (x, y)
-- f''(x)*g(y)
tupleDerivativeOverX ::
  (AutoDifferentiableValue b, Additive (CT a0)) =>
  ((RevDiff' a0 a0, RevDiff' a0 a1) -> b) ->
  (a0, a1) ->
  DerivativeValue b
-- Only the first component varies; the second is held fixed with
-- 'constDiff'.
tupleDerivativeOverX f (x0, x1) = scalarArgDerivative alongX x0
  where
    alongX x = f (x, constDiff x1)

-- | Differentiable operator for functions over tuple argument
-- with respect to the second argument.
--
-- ==== __Examples__
--
-- >>> import Debug.SimpleExpr (variable, simplify, SE)
-- >>> import Debug.DiffExpr (SymbolicFunc, unarySymbolicFunc)
--
-- >>> :{
--   x = variable "x"
--   y = variable "y"
--   f :: SymbolicFunc a => a -> a
--   f = unarySymbolicFunc "f"
--   g :: SymbolicFunc a => a -> a
--   g = unarySymbolicFunc "g"
--   h :: (SymbolicFunc a, Multiplicative a) => (a, a) -> a
--   h (x, y) = f x * g y
-- :}
--
-- >>> f(x)*g(y)
-- f(x)*g(y)
--
-- >>> :{
--  h' :: (SE, SE) -> SE
--  h' = simplify . tupleDerivativeOverY h
-- :}
--
-- >>> h' (x, y)
-- g'(y)*f(x)
--
-- >>> :{
--  h'' :: (SE, SE) -> SE
--  h'' = simplify . tupleDerivativeOverY (tupleDerivativeOverY h)
-- :}
--
-- >>> h'' (x, y)
-- g''(y)*f(x)
tupleDerivativeOverY ::
  (Additive (CT a1), AutoDifferentiableValue b) =>
  ((RevDiff' a1 a0, RevDiff' a1 a1) -> b) ->
  (a0, a1) ->
  DerivativeValue b
-- Only the second component varies; the first is held fixed with
-- 'constDiff'.
tupleDerivativeOverY f (x0, x1) = scalarArgDerivative alongY x1
  where
    alongY y = f (constDiff x0, y)

-- | Differentiable operator for functions of two arguments
-- and any value type supported by 'AutoDifferentiableValue'.
-- Equivalent to 'tupleArgDerivative' up to currying.
--
-- ==== __Examples__
--
-- >>> import Debug.SimpleExpr (variable, simplify, SE)
-- >>> import Debug.DiffExpr (SymbolicFunc, unarySymbolicFunc)
--
-- >>> :{
--   x = variable "x"
--   y = variable "y"
--   f :: SymbolicFunc a => a -> a
--   f = unarySymbolicFunc "f"
--   g :: SymbolicFunc a => a -> a
--   g = unarySymbolicFunc "g"
--   h :: (SymbolicFunc a, Multiplicative a) => a -> a -> a
--   h x y = f x * g y
-- :}
--
-- >>> f(x)*g(y)
-- f(x)*g(y)
--
-- >>> :{
--  h' :: SE -> SE -> (SE, SE)
--  h' = simplify . twoArgsDerivative h
-- :}
--
-- >>> h' x y
-- (f'(x)*g(y),g'(y)*f(x))
--
-- >>> :{
--  h'' :: SE -> SE -> ((SE, SE), (SE, SE))
--  h'' = simplify . twoArgsDerivative (twoArgsDerivative h)
-- :}
--
-- >>> h'' x y
-- ((f''(x)*g(y),g'(y)*f'(x)),(f'(x)*g'(y),g''(y)*f(x)))
twoArgsDerivative ::
  (Additive (CT a0), Additive (CT a1), AutoDifferentiableValue b) =>
  (RevDiff' (a0, a1) a0 -> RevDiff' (a0, a1) a1 -> b) ->
  a0 ->
  a1 ->
  DerivativeValue b
-- Both arguments are packed into one tuple and differentiated jointly.
-- 'tupleArg' splits the tuple-valued 'RevDiff' back into two components,
-- and those are passed to the curried 'f'.
twoArgsDerivative f x0 x1 =
  scalarArgDerivative (uncurry f . tupleArg) (x0, x1)

-- | Differentiable operator for functions over two arguments
-- with respect to the first argument.
-- Equivalent to `tupleDerivativeOverX` up to currying.
--
-- ==== __Examples__
--
-- >>> import Debug.SimpleExpr (variable, simplify, SE)
-- >>> import Debug.DiffExpr (SymbolicFunc, unarySymbolicFunc)
--
-- >>> :{
--   x = variable "x"
--   y = variable "y"
--   f :: SymbolicFunc a => a -> a
--   f = unarySymbolicFunc "f"
--   g :: SymbolicFunc a => a -> a
--   g = unarySymbolicFunc "g"
--   h :: (SymbolicFunc a, Multiplicative a) => a -> a -> a
--   h x y = f x * g y
-- :}
--
-- >>> f(x)*g(y)
-- f(x)*g(y)
--
-- >>> :{
--  h' :: SE -> SE -> SE
--  h' = simplify . twoArgsDerivativeOverX h
-- :}
--
-- >>> h' x y
-- f'(x)*g(y)
--
-- >>> :{
--  h'' :: SE -> SE -> SE
--  h'' = simplify . twoArgsDerivativeOverX (twoArgsDerivativeOverX h)
-- :}
--
-- >>> h'' x y
-- f''(x)*g(y)
twoArgsDerivativeOverX ::
  (Additive (CT a0), AutoDifferentiableValue b) =>
  (RevDiff' a0 a0 -> RevDiff' a0 a1 -> b) ->
  a0 ->
  a1 ->
  DerivativeValue b
-- The second argument is frozen with 'constDiff'; only the first one varies.
twoArgsDerivativeOverX f x0 x1 =
  scalarArgDerivative (flip f (constDiff x1)) x0

-- | Differentiable operator for functions over two arguments
-- with respect to the second argument.
-- Equivalent to `tupleDerivativeOverY` up to currying.
twoArgsDerivativeOverY ::
  (Additive (CT a1), AutoDifferentiableValue b) =>
  (RevDiff' a1 a0 -> RevDiff' a1 a1 -> b) ->
  a0 ->
  a1 ->
  DerivativeValue b
-- The first argument is frozen with 'constDiff'; only the second one varies.
twoArgsDerivativeOverY f x0 x1 =
  scalarArgDerivative (f (constDiff x0)) x1

-- | Differentiable operator for tuple-valued functions with any argument type
-- supported by `AutoDifferentiableArgument`.
--
-- ==== __Examples__
--
-- >>> import Debug.SimpleExpr (variable, simplify, SE)
--
-- >>> :{
--  f :: TrigField a => a -> (a, a)
--  f x = (sin x, cos x)
-- :}
--
-- >>> f (variable "x")
-- (sin(x),cos(x))
--
-- >>> :{
--  f' :: SE -> (SE, SE)
--  f' = simplify . tupleValDerivative f
-- :}
--
-- >>> f' (variable "x")
-- (cos(x),-(sin(x)))
tupleValDerivative ::
  ( AutoDifferentiableArgument a,
    Multiplicative c0,
    Multiplicative c1,
    DerivativeCoarg a ~ CT (DerivativeArg a),
    DerivativeRoot a ~ CT (DerivativeArg a)
  ) =>
  (a -> (RevDiff b0 c0 d0, RevDiff b1 c1 d1)) ->
  DerivativeArg a ->
  (b0, b1)
-- The tuple result is read out component-wise by the 'tupleVal' descriptor.
tupleValDerivative f = customValDerivative tupleVal f

-- boxedVectorValDerivative ::
--   ( AutoDifferentiableArgument a,
--     Multiplicative c,
--     DerivativeCoarg a ~ CT (DerivativeArg a),
--     DerivativeRoot a ~ CT (DerivativeArg a)
--   ) =>
--   (a -> BoxedVector n (RevDiff b c d)) ->
--   DerivativeArg a ->
--   BoxedVector n b
-- boxedVectorValDerivative = customValDerivative boxedVectorVal

-- Triple

-- | Differentiable operator for functions over triple arguments
-- with respect to the first argument.
tripleDerivativeOverX ::
  (AutoDifferentiableValue b, Additive (CT a0)) =>
  ((RevDiff' a0 a0, RevDiff' a0 a1, RevDiff' a0 a2) -> b) ->
  (a0, a1, a2) ->
  DerivativeValue b
-- Only the first component varies; the other two are held fixed with
-- 'constDiff'.
tripleDerivativeOverX f (x0, x1, x2) = scalarArgDerivative alongX x0
  where
    alongX x = f (x, constDiff x1, constDiff x2)

-- | Differentiable operator for functions over triple arguments
-- with respect to the second argument.
tripleDerivativeOverY ::
  (AutoDifferentiableValue b, Additive (CT a1)) =>
  ((RevDiff' a1 a0, RevDiff' a1 a1, RevDiff' a1 a2) -> b) ->
  (a0, a1, a2) ->
  DerivativeValue b
-- Only the second component varies; the other two are held fixed with
-- 'constDiff'.
tripleDerivativeOverY f (x0, x1, x2) = scalarArgDerivative alongY x1
  where
    alongY y = f (constDiff x0, y, constDiff x2)

-- | Differentiable operator for functions over triple arguments
-- with respect to the third argument.
tripleDerivativeOverZ ::
  (AutoDifferentiableValue b, Additive (CT a2)) =>
  ((RevDiff' a2 a0, RevDiff' a2 a1, RevDiff' a2 a2) -> b) ->
  (a0, a1, a2) ->
  DerivativeValue b
-- Only the third component varies; the other two are held fixed with
-- 'constDiff'.
tripleDerivativeOverZ f (x0, x1, x2) = scalarArgDerivative alongZ x2
  where
    alongZ z = f (constDiff x0, constDiff x1, z)

-- | Combines three `RevDiff` instances into a single `RevDiff` instance over a triple.
-- The inverse operation is 'tripleArg'.
threeArgsToTriple ::
  (Additive a) =>
  RevDiff a b0 c0 ->
  RevDiff a b1 c1 ->
  RevDiff a b2 c2 ->
  RevDiff a (b0, b1, b2) (c0, c1, c2)
-- Forward values are bundled into a triple. The backward pass splits the
-- triple cotangent and sums the three contributions.
threeArgsToTriple (MkRevDiff v0 bp0) (MkRevDiff v1 bp1) (MkRevDiff v2 bp2) =
  MkRevDiff (v0, v1, v2) tripleBackprop
  where
    -- Left-associated sum: (bp0 ct0 + bp1 ct1) + bp2 ct2.
    tripleBackprop (ct0, ct1, ct2) = bp0 ct0 + bp1 ct1 + bp2 ct2

-- | Triple argument descriptor for differentiable functions.
-- Transforms a `RevDiff` instance of a triple into a triple of `RevDiff` instances.
-- This allows applying differentiable operations to each element of the triple.
--
-- ==== __Examples__
--
-- >>> import Debug.SimpleExpr (variable, SE, simplify)
-- >>> import Debug.DiffExpr (SymbolicFunc, unarySymbolicFunc)
--
-- >>> :{
--   f :: Multiplicative a => (a, a, a) -> a
--   f (x, y, z) = x * y * z
-- :}
--
-- >>> :{
--   f' :: (Distributive a, CT a ~ a) => (a, a, a) -> (a, a, a)
--   f' = customArgDerivative tripleArg f
-- :}
--
-- >>> simplify $ f' (variable "x", variable "y", variable "z")
-- (y*z,x*z,x*y)
tripleArg ::
  (Additive b0, Additive b1, Additive b2) =>
  RevDiff a (b0, b1, b2) (c0, c1, c2) ->
  (RevDiff a b0 c0, RevDiff a b1 c1, RevDiff a b2 c2)
-- Each component keeps its own forward value. Its backprop places the
-- component cotangent in its own slot of the triple and fills the other
-- two slots with 'zero'.
tripleArg (MkRevDiff (v0, v1, v2) bp) = (diff0, diff1, diff2)
  where
    diff0 = MkRevDiff v0 (\ct -> bp (ct, zero, zero))
    diff1 = MkRevDiff v1 (\ct -> bp (zero, ct, zero))
    diff2 = MkRevDiff v2 (\ct -> bp (zero, zero, ct))

-- | Triple argument builder.
-- See [this tutorial section]
-- (Numeric-InfBackprop-Tutorial.html#g:sophisticated-45-argument-45-function-45-how-45-it-45-works)
-- for details and examples (shown there for the tuple case).
mkTripleArg ::
  (Additive b0, Additive b1, Additive b2) =>
  RevDiffArg a b0 c0 d0 ->
  RevDiffArg a b1 c1 d1 ->
  RevDiffArg a b2 c2 d2 ->
  RevDiffArg a (b0, b1, b2) (c0, c1, c2) (d0, d1, d2)
-- Split the triple-valued 'RevDiff' with 'tripleArg', then apply each
-- component descriptor to its own part.
mkTripleArg f0 f1 f2 tripleDiff = cross3 f0 f1 f2 (tripleArg tripleDiff)

-- | Triple instance for `AutoDifferentiableArgument` typeclass.
-- It makes it possible to differentiate functions of a triple argument.
instance
  ( AutoDifferentiableArgument a0,
    AutoDifferentiableArgument a1,
    AutoDifferentiableArgument a2,
    DerivativeRoot a0 ~ DerivativeRoot a1,
    DerivativeRoot a0 ~ DerivativeRoot a2
  ) =>
  AutoDifferentiableArgument (a0, a1, a2)
  where
  type DerivativeRoot (a0, a1, a2) = DerivativeRoot a0
  type DerivativeCoarg (a0, a1, a2) = (DerivativeCoarg a0, DerivativeCoarg a1, DerivativeCoarg a2)
  type DerivativeArg (a0, a1, a2) = (DerivativeArg a0, DerivativeArg a1, DerivativeArg a2)
  autoArg ::
    RevDiff
      (DerivativeRoot a0)
      (DerivativeCoarg a0, DerivativeCoarg a1, DerivativeCoarg a2)
      (DerivativeArg a0, DerivativeArg a1, DerivativeArg a2) ->
    (a0, a1, a2)
  -- All three components share one root type, as required by the context of
  -- this instance. The triple is split with 'mkTripleArg' and each part is
  -- handled by its own 'autoArg'.
  autoArg tripleDiff = mkTripleArg autoArg autoArg autoArg tripleDiff

-- | Triple differentiable value builder.
-- Applies the three given value descriptors componentwise to a triple.
-- See [this tutorial section]
-- (Numeric-InfBackprop-Tutorial.html#g:multivalued-45-function-45-how-45-it-45-works)
-- for details and examples for tuple.
mkTripleVal :: (a0 -> b0) -> (a1 -> b1) -> (a2 -> b2) -> (a0, a1, a2) -> (b0, b1, b2)
mkTripleVal = cross3

-- | Triple differentiable value descriptor.
-- Applies `scalarVal` to each `RevDiff` component of the triple.
-- See [this tutorial section]
-- (Numeric-InfBackprop-Tutorial.html#g:multivalued-45-function-45-how-45-it-45-works)
-- for details and examples.
tripleVal ::
  (Multiplicative b0, Multiplicative b1, Multiplicative b2) =>
  (RevDiff a0 b0 c0, RevDiff a1 b1 c1, RevDiff a2 b2 c2) ->
  (a0, a1, a2)
tripleVal = mkTripleVal scalarVal scalarVal scalarVal

-- | Triple instance for `AutoDifferentiableValue` typeclass.
-- Each component of the triple is handled by its own `autoVal`.
instance
  ( AutoDifferentiableValue a0,
    AutoDifferentiableValue a1,
    AutoDifferentiableValue a2
  ) =>
  AutoDifferentiableValue (a0, a1, a2)
  where
  type DerivativeValue (a0, a1, a2) = (DerivativeValue a0, DerivativeValue a1, DerivativeValue a2)
  autoVal :: (a0, a1, a2) -> (DerivativeValue a0, DerivativeValue a1, DerivativeValue a2)
  autoVal = mkTripleVal autoVal autoVal autoVal

-- | Differentiable operator for functions with a triple argument
-- and any value type supported by `AutoDifferentiableValue`.
-- The output is a triple of corresponding partial derivatives.
-- This function is equivalent to 'threeArgsDerivative' up to currying.
--
-- ==== __Examples__
--
-- >>> import Debug.SimpleExpr (variable, simplify, SE)
-- >>> import Debug.SimpleExpr.Utils.Algebra (AlgebraicPower, square, MultiplicativeAction)
-- >>> import Debug.DiffExpr (SymbolicFunc)
--
-- >>> :{
--   x = variable "x"
--   y = variable "y"
--   z = variable "z"
--   norm :: (AlgebraicPower Integer a, Additive a) => (a, a, a) -> a
--   norm (x, y, z) = square x + square y + square z
-- :}
--
-- >>> norm (x, y, z)
-- ((x^2)+(y^2))+(z^2)
--
-- >>> :{
--  norm' :: (SE, SE, SE) -> (SE, SE, SE)
--  norm' = simplify . tripleArgDerivative norm
-- :}
--
-- >>> simplify $ norm' (x, y, z)
-- (2*x,2*y,2*z)
--
-- >>> :{
--  norm'' :: (SE, SE, SE) -> ((SE, SE, SE), (SE, SE, SE), (SE, SE, SE))
--  norm'' = simplify . tripleArgDerivative (tripleArgDerivative norm)
-- :}
--
-- >>> norm'' (x, y, z)
-- ((2,0,0),(0,2,0),(0,0,2))
tripleArgDerivative ::
  ( Additive (CT a0),
    Additive (CT a1),
    Additive (CT a2),
    AutoDifferentiableValue b
  ) =>
  ( ( RevDiff' (a0, a1, a2) a0,
      RevDiff' (a0, a1, a2) a1,
      RevDiff' (a0, a1, a2) a2
    ) ->
    b
  ) ->
  (a0, a1, a2) ->
  DerivativeValue b
-- The triple argument is split into its three components by 'tripleArg'.
tripleArgDerivative = customArgDerivative tripleArg

-- | Differentiable operator for functions of three arguments
-- and any value type supported by `AutoDifferentiableValue`.
-- The output is a triple of corresponding partial derivatives.
-- This function is equivalent to 'tripleArgDerivative' up to currying.
threeArgsDerivative ::
  ( AutoDifferentiableValue b,
    Additive (CT a0),
    Additive (CT a1),
    Additive (CT a2)
  ) =>
  ( RevDiff' (a0, a1, a2) a0 ->
    RevDiff' (a0, a1, a2) a1 ->
    RevDiff' (a0, a1, a2) a2 ->
    b
  ) ->
  a0 ->
  a1 ->
  a2 ->
  DerivativeValue b
-- Differentiate with respect to the whole triple (split back into
-- components by 'tripleArg'), then curry the result.
threeArgsDerivative f = curry3 (scalarArgDerivative $ uncurry3 f . tripleArg)

-- | Differentiable operator for functions of three arguments
-- with respect to the first argument,
-- for any value type supported by `AutoDifferentiableValue`.
-- The second and third arguments are wrapped with `constDiff`,
-- so they are treated as constants.
derivative3ArgsOverX ::
  (AutoDifferentiableValue b, Additive (CT a0)) =>
  (RevDiff' a0 a0 -> RevDiff' a0 a1 -> RevDiff' a0 a2 -> b) ->
  a0 ->
  a1 ->
  a2 ->
  DerivativeValue b
derivative3ArgsOverX f x0 x1 x2 =
  scalarArgDerivative
    (\x0' -> f x0' (constDiff x1) (constDiff x2))
    x0

-- | Differentiable operator for functions of three arguments
-- with respect to the second argument,
-- for any value type supported by `AutoDifferentiableValue`.
-- The first and third arguments are wrapped with `constDiff`,
-- so they are treated as constants.
derivative3ArgsOverY ::
  (AutoDifferentiableValue b, Additive (CT a1)) =>
  (RevDiff' a1 a0 -> RevDiff' a1 a1 -> RevDiff' a1 a2 -> b) ->
  a0 ->
  a1 ->
  a2 ->
  DerivativeValue b
derivative3ArgsOverY f x0 x1 x2 =
  scalarArgDerivative
    (\x1' -> f (constDiff x0) x1' (constDiff x2))
    x1

-- | Differentiable operator for functions of three arguments
-- with respect to the third argument,
-- for any value type supported by `AutoDifferentiableValue`.
-- The first and second arguments are wrapped with `constDiff`,
-- so they are treated as constants.
derivative3ArgsOverZ ::
  (AutoDifferentiableValue b, Additive (CT a2)) =>
  (RevDiff' a2 a0 -> RevDiff' a2 a1 -> RevDiff' a2 a2 -> b) ->
  a0 ->
  a1 ->
  a2 ->
  DerivativeValue b
-- Written in the same pointful shape as 'derivative3ArgsOverX' and
-- 'derivative3ArgsOverY' for consistency.
derivative3ArgsOverZ f x0 x1 x2 =
  scalarArgDerivative
    (\x2' -> f (constDiff x0) (constDiff x1) x2')
    x2

-- | Differentiable operator for functions with a triple value and any
-- argument type supported by `AutoDifferentiableArgument`.
-- The output is a triple of the derivatives of the three value components.
--
-- ==== __Examples__
--
-- >>> import Debug.SimpleExpr (variable, simplify, SE)
--
-- >>> :{
--  f :: (Multiplicative a, IntegerPower a) => a -> (a, a, a)
--  f x = (one, x^1, x^2)
-- :}
--
-- >>> f (variable "x")
-- (1,x^1,x^2)
--
-- >>> :{
--  f' :: SE -> (SE, SE, SE)
--  f' = simplify . tripleValDerivative f
-- :}
--
-- >>> f' (variable "x")
-- (0,1,2*x)
tripleValDerivative ::
  ( AutoDifferentiableArgument a,
    Multiplicative c0,
    Multiplicative c1,
    Multiplicative c2,
    DerivativeCoarg a ~ CT (DerivativeArg a),
    DerivativeRoot a ~ CT (DerivativeArg a)
  ) =>
  (a -> (RevDiff b0 c0 d0, RevDiff b1 c1 d1, RevDiff b2 c2 d2)) ->
  DerivativeArg a ->
  (b0, b1, b2)
tripleValDerivative = customValDerivative tripleVal

-- BoxedVector

-- | `BoxedVector` differentiable value builder.
-- Applies the given value descriptor to every element of the vector.
-- See [this tutorial section]
-- (Numeric-InfBackprop-Tutorial.html#g:multivalued-45-function-45-how-45-it-45-works)
-- for details and examples.
mkBoxedVectorVal :: (a -> b) -> BoxedVector n a -> BoxedVector n b
mkBoxedVectorVal = fmap

-- | `BoxedVector` instance for `AutoDifferentiableValue` typeclass.
-- Every element is handled by the element type's `autoVal`.
instance
  (AutoDifferentiableValue a) =>
  AutoDifferentiableValue (BoxedVector n a)
  where
  type DerivativeValue (BoxedVector n a) = BoxedVector n (DerivativeValue a)
  autoVal :: BoxedVector n a -> BoxedVector n (DerivativeValue a)
  autoVal = mkBoxedVectorVal autoVal

-- | Boxed vector differentiable value descriptor.
-- Applies `scalarVal` to every `RevDiff` element of the vector.
-- See [this tutorial section]
-- (Numeric-InfBackprop-Tutorial.html#g:multivalued-45-function-45-how-45-it-45-works)
-- for details and examples.
--
-- ==== __Examples__
--
-- >>> import Debug.SimpleExpr (variable, SE, simplify)
-- >>> import Debug.DiffExpr (unarySymbolicFunc, SymbolicFunc)
--
-- >>> :{
--   v :: SymbolicFunc a => a -> BoxedVector 3 a
--   v t = DVGS.fromTuple (
--      unarySymbolicFunc "v_x" t,
--      unarySymbolicFunc "v_y" t,
--      unarySymbolicFunc "v_z" t
--    )
-- :}
--
-- >>> t = variable "t"
-- >>> v t
-- Vector [v_x(t),v_y(t),v_z(t)]
--
-- >>> v' = simplify . customValDerivative boxedVectorVal v :: SE -> BoxedVector 3 SE
-- >>> v' t
-- Vector [v_x'(t),v_y'(t),v_z'(t)]
boxedVectorVal ::
  (Multiplicative b) =>
  BoxedVector n (RevDiff a b c) ->
  BoxedVector n a
boxedVectorVal = mkBoxedVectorVal scalarVal

-- | Differentiable operator for functions with a `BoxedVector` value
-- and any argument type supported by `AutoDifferentiableArgument`.
-- The output is a `BoxedVector` of the corresponding derivatives.
--
-- ==== __Examples__
--
-- >>> import Debug.SimpleExpr (variable, SE, simplify)
-- >>> import Debug.DiffExpr (unarySymbolicFunc, SymbolicFunc)
--
-- >>> :{
--   v :: SymbolicFunc a => a -> BoxedVector 3 a
--   v t = DVGS.fromTuple (
--      unarySymbolicFunc "v_x" t,
--      unarySymbolicFunc "v_y" t,
--      unarySymbolicFunc "v_z" t
--    )
-- :}
--
-- >>> t = variable "t"
-- >>> v t
-- Vector [v_x(t),v_y(t),v_z(t)]
--
-- >>> v' = simplify . boxedVectorValDerivative v :: SE -> BoxedVector 3 SE
-- >>> v' t
-- Vector [v_x'(t),v_y'(t),v_z'(t)]
boxedVectorValDerivative ::
  ( AutoDifferentiableArgument a,
    Multiplicative c,
    DerivativeCoarg a ~ CT (DerivativeArg a),
    DerivativeRoot a ~ CT (DerivativeArg a)
  ) =>
  (a -> BoxedVector n (RevDiff b c d)) ->
  DerivativeArg a ->
  BoxedVector n b
boxedVectorValDerivative = customValDerivative boxedVectorVal

-- | Boxed vector argument descriptor for differentiable functions.
-- Transforms a `RevDiff` instance of a boxed vector into a boxed vector
-- of `RevDiff` instances.
-- This allows applying differentiable operations to each element of the boxed vector.
--
-- The @k@-th element's value is the @k@-th element of the original value.
-- Its backpropagation feeds @boxedVectorBasis k zero@ of the incoming
-- cotangent into the original backpropagation. This is presumably the
-- vector that holds the cotangent at position @k@ and `zero` elsewhere.
--
-- ==== __Examples__
--
-- >>> import Debug.SimpleExpr (variable, SE, simplify)
-- >>> import Debug.DiffExpr (SymbolicFunc, unarySymbolicFunc)
--
-- >>> :{
--   f :: Additive a => BoxedVector 3 a -> a
--   f = boxedVectorSum
-- :}
--
-- >>> :{
--   f' :: (Distributive a, CT a ~ a) => BoxedVector 3 a -> BoxedVector 3 a
--   f' = customArgDerivative boxedVectorArg f
-- :}
--
-- >>> simplify $ f' (DVGS.fromTuple (variable "x", variable "y", variable "z"))
-- Vector [1,1,1]
boxedVectorArg ::
  (Additive b, KnownNat n) =>
  RevDiff a (BoxedVector n b) (BoxedVector n c) ->
  BoxedVector n (RevDiff a b c)
boxedVectorArg (MkRevDiff vec bpc) = DVGS.generate $ \k ->
  MkRevDiff (DVGS.index vec k) (bpc . boxedVectorBasis k zero)

-- unpackBoxedVector ::
--   (Additive a, KnownNat n) =>
--   BoxedVector n (RevDiff a b c) ->
--   RevDiff a (BoxedVector n b) (BoxedVector n c)
-- unpackBoxedVector array =
--   MkRevDiff'
--     (fmap value array)
--     (boxedVectorSum . (fmap backprop array <*>))

-- | `BoxedVector` argument descriptor builder.
-- Splits the vector argument with `boxedVectorArg` and then applies
-- the given element descriptor to every element.
mkBoxedVectorArg ::
  (Additive b, KnownNat n) =>
  RevDiffArg a b c d ->
  RevDiffArg a (BoxedVector n b) (BoxedVector n c) (BoxedVector n d)
mkBoxedVectorArg f = fmap f . boxedVectorArg

-- | `BoxedVector` instance for `AutoDifferentiableArgument` typeclass.
-- Every element is handled by the element type's `autoArg`.
instance
  ( AutoDifferentiableArgument a,
    KnownNat n
  ) =>
  AutoDifferentiableArgument (BoxedVector n a)
  where
  type DerivativeRoot (BoxedVector n a) = DerivativeRoot a
  type DerivativeCoarg (BoxedVector n a) = BoxedVector n (DerivativeCoarg a)
  type DerivativeArg (BoxedVector n a) = BoxedVector n (DerivativeArg a)
  autoArg :: RevDiff (DerivativeRoot a) (BoxedVector n (DerivativeCoarg a)) (BoxedVector n (DerivativeArg a)) -> BoxedVector n a
  autoArg = mkBoxedVectorArg autoArg

-- | Differentiable operator for functions with a boxed vector argument
-- and any value type supported by `AutoDifferentiableValue`.
-- The output is a boxed vector of corresponding partial derivatives (i.e. the gradient).
--
-- ==== __Examples__
--
-- >>> import Debug.SimpleExpr (variable, simplify, SE)
-- >>> import Debug.DiffExpr (SymbolicFunc)
-- >>> import Numeric.InfBackprop.Utils.SizedVector (BoxedVector, boxedVectorSum)
-- >>> import Debug.SimpleExpr.Utils.Algebra (AlgebraicPower, (^))
--
-- >>> :{
--   x = variable "x"
--   y = variable "y"
--   z = variable "z"
--   r = DVGS.fromTuple (x, y, z) :: BoxedVector 3 SE
--   norm2 :: (AlgebraicPower Integer a, Additive a) => BoxedVector 3 a -> a
--   norm2 v = boxedVectorSum (v^2)
-- :}
--
-- >>> simplify $ norm2 r
-- ((x^2)+(y^2))+(z^2)
--
-- >>> :{
--  norm2' :: BoxedVector 3 SE -> BoxedVector 3 SE
--  norm2' = simplify . boxedVectorArgDerivative norm2
-- :}
--
-- >>> norm2' r
-- Vector [2*x,2*y,2*z]
--
-- >>> :{
--  norm2'' :: BoxedVector 3 SE -> BoxedVector 3 (BoxedVector 3 SE)
--  norm2'' = simplify . boxedVectorArgDerivative (boxedVectorArgDerivative norm2)
-- :}
--
-- >>> norm2'' r
-- Vector [Vector [2,0,0],Vector [0,2,0],Vector [0,0,2]]
boxedVectorArgDerivative ::
  (KnownNat n, AutoDifferentiableValue b, Additive (CT a)) =>
  (BoxedVector n (RevDiff' (BoxedVector n a) a) -> b) ->
  BoxedVector n a ->
  DerivativeValue b
boxedVectorArgDerivative = customArgDerivative boxedVectorArg

-- instance (HasSum (BoxedVector n c) d, KnownNat n) =>
--   HasSum (RevDiff a (BoxedVector n b) (BoxedVector n c)) (RevDiff a b d) where
--     sum (MkRevDiff vec bp) = MkRevDiff' (sum vec) (bp . DVGS.replicate)

-- ** Stream

-- | Stream differentiable value builder.
-- Applies the given value descriptor to every element of the stream.
-- See [this tutorial section]
-- (Numeric-InfBackprop-Tutorial.html#g:multivalued-45-function-45-how-45-it-45-works)
-- for details and examples.
mkStreamVal :: (a -> b) -> Stream a -> Stream b
mkStreamVal = fmap

-- | Stream value structure for differentiable functions.
-- Applies `scalarVal` to every `RevDiff` element of the stream.
--
-- ==== __Examples__
--
-- >>> import GHC.Base ((<>))
-- >>> import Data.Stream (Stream, fromList, take)
-- >>> import Debug.SimpleExpr (variable, SE, simplify)
-- >>> import Debug.DiffExpr (unarySymbolicFunc, SymbolicFunc)
--
-- >>> :{
--   s :: SymbolicFunc a => a -> Stream a
--   s t = fromList [unarySymbolicFunc ("s_" <> show n) t | n <- [0..]]
-- :}
--
-- >>> t = variable "t"
-- >>> take 5 (s t)
-- [s_0(t),s_1(t),s_2(t),s_3(t),s_4(t)]
--
-- >>> :{
--   s' :: SE -> Stream SE
--   s' = simplify . customValDerivative streamVal s
-- :}
--
-- >>> take 5 (s' t)
-- [s_0'(t),s_1'(t),s_2'(t),s_3'(t),s_4'(t)]
streamVal ::
  (Multiplicative b) =>
  Stream (RevDiff a b c) ->
  Stream a
streamVal = mkStreamVal scalarVal

-- | `Stream` instance for `AutoDifferentiableValue` typeclass.
-- Every element is handled by the element type's `autoVal`.
instance
  (AutoDifferentiableValue a) =>
  AutoDifferentiableValue (Stream a)
  where
  type DerivativeValue (Stream a) = Stream (DerivativeValue a)
  autoVal :: Stream a -> Stream (DerivativeValue a)
  autoVal = mkStreamVal autoVal

-- | Derivative operator for a function from any argument type supported by
-- `AutoDifferentiableArgument` to a `Stream`.
-- The output is a stream of the derivatives of the stream elements.
streamValDerivative ::
  ( AutoDifferentiableArgument a,
    Multiplicative c,
    DerivativeCoarg a ~ CT (DerivativeArg a),
    DerivativeRoot a ~ CT (DerivativeArg a)
  ) =>
  (a -> Stream (RevDiff b c d)) ->
  DerivativeArg a ->
  Stream b
streamValDerivative = customValDerivative streamVal

-- | Stream argument descriptor for differentiable functions.
-- Transforms a `RevDiff` instance of a stream into a stream of `RevDiff` instances.
-- This allows applying differentiable operations to each element of the stream.
--
-- The head's cotangent is wrapped with `singleton` before it is passed to the
-- original backpropagation. The tail is processed recursively, and a `zero`
-- is prepended (`cons zero`) to each of its cotangents. As a result, the
-- cotangent of each element presumably ends up at that element's position.
--
-- ==== __Examples__
--
-- >>> import Debug.SimpleExpr (variable, SE, simplify)
-- >>> import GHC.Base ((<>))
--
-- >>> :{
--   f :: Additive a => Stream a -> a
--   f = NumHask.sum . Data.Stream.take 4 :: Additive a => Data.Stream.Stream a -> a
-- :}
--
-- >>> :{
--   f' :: (Distributive a, CT a ~ a) => Stream a -> FiniteSupportStream a
--   f' = customArgDerivative streamArg f
-- :}
--
-- >>> s = Data.Stream.fromList [variable ("s_" <> show n) | n <- [0 :: Int ..]] :: Data.Stream.Stream SE
-- >>> simplify $ f' s
-- [1,1,1,1,0,0,0,...
streamArg ::
  (Additive b) =>
  RevDiff a (FiniteSupportStream b) (Stream c) ->
  Stream (RevDiff a b c)
streamArg (MkRevDiff x bpc) =
  DS.Cons
    (MkRevDiff x_head bpc_head)
    (streamArg (MkRevDiff x_tail bpc_tail))
  where
    x_head = DS.head x
    x_tail = DS.tail x
    bpc_head = bpc . singleton
    bpc_tail = bpc . cons zero

-- | Stream argument builder.
-- See [this tutorial section]
-- (Numeric-InfBackprop-Tutorial.html#g:sophisticated-45-argument-45-function-45-how-45-it-45-works)
-- for details and examples for the tuple.
mkStreamArg ::
  (Additive b) =>
  (RevDiff a b c -> d) ->
  RevDiff a (FiniteSupportStream b) (Stream c) ->
  Stream d
mkStreamArg f x = fmap f (streamArg x)

-- | `Stream` instance for `AutoDifferentiableArgument` typeclass.
instance
  (AutoDifferentiableArgument a) =>
  AutoDifferentiableArgument (Stream a)
  where
  type DerivativeRoot (Stream a) = DerivativeRoot a
  type DerivativeCoarg (Stream a) = FiniteSupportStream (DerivativeCoarg a)
  type DerivativeArg (Stream a) = Stream (DerivativeArg a)

  -- Split the differentiable stream element-wise with `streamArg`, then let
  -- the element instance rebuild each user-facing value.
  autoArg ::
    RevDiff
      (DerivativeRoot a)
      (FiniteSupportStream (DerivativeCoarg a))
      (Stream (DerivativeArg a)) ->
    Stream a
  autoArg = fmap autoArg . streamArg

-- | Differentiable operator for functions with `Stream` argument
-- and any supported by `AutoDifferentiableValue` value type.
-- The output is a finite support stream of corresponding partial derivatives (i.e. the gradient).
--
-- ==== __Examples__
--
-- >>> import GHC.Base ((<>))
-- >>> import Debug.SimpleExpr (variable, simplify, SE)
-- >>> import Debug.DiffExpr (SymbolicFunc)
-- >>> import Data.Stream (Stream, fromList, take)
--
-- >>> s = fromList [variable ("s_" <> show n) | n <- [0 :: Int ..]] :: Stream SE
--
-- >>> take4Sum = NumHask.sum . take 4 :: Additive a => Stream a -> a
-- >>> simplify $ take4Sum s :: SE
-- s_0+(s_1+(s_2+s_3))
--
-- >>> :{
--  take4Sum' :: (Distributive a, CT a ~ a) =>
--    Stream a -> FiniteSupportStream (CT a)
--  take4Sum' = streamArgDerivative take4Sum
-- :}
--
-- >>> simplify $ take4Sum' s
-- [1,1,1,1,0,0,0,...
--
-- >>> :{
--  take4Sum'' :: (Distributive a, CT a ~ a) =>
--    Stream a -> FiniteSupportStream (FiniteSupportStream (CT a))
--  take4Sum'' = streamArgDerivative (streamArgDerivative take4Sum)
-- :}
--
-- >>> simplify $ take4Sum'' s
-- [[0,0,0,...,[0,0,0,...,[0,0,0,...,...
streamArgDerivative ::
  (AutoDifferentiableValue b, Additive (CT a)) =>
  (Stream (RevDiff' (Stream a) a) -> b) ->
  Stream a ->
  DerivativeValue b
streamArgDerivative f = customArgDerivative streamArg f

-- FiniteSupportStream

-- | Finite support stream differentiable value builder
-- See [this tutorial section]
-- (Numeric-InfBackprop-Tutorial.html#g:multivalued-45-function-45-how-45-it-45-works)
-- for details and examples.
-- It is expected that the argument function is linear or at least maps zero to zero.
mkFiniteSupportStreamVal :: (a -> b) -> FiniteSupportStream a -> FiniteSupportStream b
mkFiniteSupportStreamVal f = unsafeMap f

-- | Finite support stream value structure for differentiable functions.
--
-- ==== __Examples__
--
-- >>> import Debug.SimpleExpr (variable, SE, simplify)
-- >>> import Debug.DiffExpr (unarySymbolicFunc, SymbolicFunc)
-- >>> import Data.FiniteSupportStream (unsafeFromList, FiniteSupportStream)
--
-- >>> :{
--  fss :: (Multiplicative a, IntegerPower a) =>
--    a -> FiniteSupportStream a
--  fss t = unsafeFromList [t^3, t^2, t, one]
-- :}
--
-- >>> t = variable "t"
-- >>> fss t
-- [t^3,t^2,t,1,0,0,0,...
--
-- >>> :{
--   fss' :: SE -> FiniteSupportStream SE
--   fss' = simplify . customValDerivative finiteSupportStreamVal fss
-- :}
--
-- >>> (fss' t)
-- [3*(t^2),2*t,1,0,0,0,...
finiteSupportStreamVal ::
  (Multiplicative b) =>
  FiniteSupportStream (RevDiff a b c) ->
  FiniteSupportStream a
finiteSupportStreamVal xs = mkFiniteSupportStreamVal scalarVal xs

-- | `FiniteSupportStream` instance for `AutoDifferentiableValue` typeclass.
instance
  (AutoDifferentiableValue a) =>
  AutoDifferentiableValue (FiniteSupportStream a)
  where
  type DerivativeValue (FiniteSupportStream a) = FiniteSupportStream (DerivativeValue a)

  -- Apply the element-level `autoVal` to every entry of the stream.
  autoVal :: FiniteSupportStream a -> FiniteSupportStream (DerivativeValue a)
  autoVal xs = mkFiniteSupportStreamVal autoVal xs

-- | Derivative operator for a function from any supported argument type to
-- a `FiniteSupportStream` instance.
--
-- ==== __Examples__
--
-- >>> import Debug.SimpleExpr (variable, SE, simplify)
-- >>> import Debug.DiffExpr (unarySymbolicFunc, SymbolicFunc)
-- >>> import Data.FiniteSupportStream (unsafeFromList, FiniteSupportStream)
--
-- >>> :{
--  fss :: (Multiplicative a, IntegerPower a) =>
--    a -> FiniteSupportStream a
--  fss t = unsafeFromList [t^3, t^2, t, one]
-- :}
--
-- >>> t = variable "t"
-- >>> fss t
-- [t^3,t^2,t,1,0,0,0,...
--
-- >>> :{
--   fss' :: SE -> FiniteSupportStream SE
--   fss' = simplify . finiteSupportStreamValDerivative fss
-- :}
--
-- >>> fss' t
-- [3*(t^2),2*t,1,0,0,0,...
finiteSupportStreamValDerivative ::
  ( AutoDifferentiableArgument a,
    Multiplicative c,
    DerivativeCoarg a ~ CT (DerivativeArg a),
    DerivativeRoot a ~ CT (DerivativeArg a)
  ) =>
  (a -> FiniteSupportStream (RevDiff b c d)) ->
  DerivativeArg a ->
  FiniteSupportStream b
finiteSupportStreamValDerivative f = customValDerivative finiteSupportStreamVal f

-- | Finite support stream argument descriptor for differentiable functions.
-- Transforms a `RevDiff` instance of a finite support stream into
-- a finite support stream of `RevDiff` instances.
--
-- ==== __Examples__
--
-- >>> import Debug.SimpleExpr (variable, SE, simplify)
-- >>> import Data.FiniteSupportStream (unsafeFromList, toVector)
--
-- >>> :{
--   f :: Additive a => FiniteSupportStream a -> a
--   f = NumHask.sum . toVector
-- :}
--
-- >>> f (unsafeFromList [1, 2, 3])
-- 6
--
-- >>> :{
--   f' :: (Distributive a, CT a ~ a) => FiniteSupportStream a -> Stream a
--   f' = customArgDerivative finiteSupportStreamArg f
-- :}
--
-- >>> Data.Stream.take 5 $ f' (unsafeFromList [1, 2, 3])
-- [1,1,1,0,0]
finiteSupportStreamArg ::
  (Additive b) =>
  RevDiff a (Stream b) (FiniteSupportStream c) ->
  FiniteSupportStream (RevDiff a b c)
finiteSupportStreamArg (MkRevDiff (MkFiniteSupportStream xs) pullback) =
  MkFiniteSupportStream (DV.imap element xs)
  where
    -- The i-th entry of the support vector. Its backprop embeds the incoming
    -- cotangent at index i of an otherwise-zero stream and then calls the
    -- stream-level backprop.
    element i x = MkRevDiff x (pullback . basisStream i)
    -- Infinite stream equal to `y` at position `i` and `zero` elsewhere.
    basisStream i y = go 0
      where
        go n = DS.Cons (if n == i then y else zero) (go (n + 1))

-- cons
--   (MkRevDiff' x_head bpc_head)
--   (finiteSupportStreamArg (MkRevDiff' x_tail bpc_tail))
-- where
--   x_head = trace "taking head" $ head x
--   x_tail = trace "taking tail" $ tail x
--   bpc_head = trace "taking bpc_head" $ bpc . DS.fromList . (: [])
--   bpc_tail = trace "taking bpc_tail" $ bpc . DS.Cons zero

-- | Finite support stream argument descriptor builder.
-- See [this tutorial section]
-- (Numeric-InfBackprop-Tutorial.html#g:multivalued-45-function-45-how-45-it-45-works)
-- for details and examples.
-- It is expected that the argument function is linear or at least maps zero to zero.
mkFiniteSupportStreamArg ::
  (Additive b) =>
  (RevDiff a b c -> d) ->
  RevDiff a (Stream b) (FiniteSupportStream c) ->
  FiniteSupportStream d
mkFiniteSupportStreamArg f x = unsafeMap f (finiteSupportStreamArg x)

-- | `FiniteSupportStream` instance for `AutoDifferentiableArgument` typeclass.
instance
  (AutoDifferentiableArgument a) =>
  AutoDifferentiableArgument (FiniteSupportStream a)
  where
  type DerivativeRoot (FiniteSupportStream a) = DerivativeRoot a
  type DerivativeCoarg (FiniteSupportStream a) = Stream (DerivativeCoarg a)
  type DerivativeArg (FiniteSupportStream a) = FiniteSupportStream (DerivativeArg a)

  -- Previously `undefined`, which crashed any automatic derivative taken with
  -- respect to a `FiniteSupportStream` argument. This now mirrors the `Stream`
  -- and `Maybe` instances: split the differentiable finite support stream
  -- element-wise with `finiteSupportStreamArg`, then rebuild each element with
  -- the element instance's `autoArg`.
  autoArg ::
    RevDiff
      (DerivativeRoot a)
      (Stream (DerivativeCoarg a))
      (FiniteSupportStream (DerivativeArg a)) ->
    FiniteSupportStream a
  autoArg = mkFiniteSupportStreamArg autoArg

-- | Differentiable operator for functions that take a `FiniteSupportStream` argument
-- and return any value type supported by `AutoDifferentiableValue`.
-- The output is a stream of corresponding partial derivatives,
-- computing the gradient of the function with respect to each stream element.
-- See also
-- ["Tangent and Cotangent Spaces" tutorial section](Numeric-InfBackprop-Tutorial.html#g:how-45-it-45-works-45-tangent-45-space)
-- for the connection between streams and finite support streams.
--
-- ==== __Examples__
--
-- >>> import Debug.SimpleExpr (variable, simplify, SE)
-- >>> import Debug.DiffExpr (SymbolicFunc)
-- >>> import Data.Stream (Stream, take)
-- >>> import Data.FiniteSupportStream (FiniteSupportStream, unsafeFromList, toVector)
-- >>> import NumHask (sum)
--
-- Define a finite support stream with support length 4 containing 4 symbolic variables.
--
-- >>> s = unsafeFromList [variable "s_0", variable "s_1", variable "s_2", variable "s_3"] :: FiniteSupportStream SE
-- >>> s
-- [s_0,s_1,s_2,s_3,0,0,0,...
--
-- Now we'll define a function that sums all elements of a finite support stream.
--
-- >>> finiteSupportStreamSum = sum . toVector :: Additive a => FiniteSupportStream a -> a
-- >>> simplify $ finiteSupportStreamSum s :: SE
-- s_0+(s_1+(s_2+s_3))
--
-- We compute the gradient
-- of this function.
--
-- >>> :{
--  finiteSupportStreamSum' :: (Distributive a, CT a ~ a) =>
--    FiniteSupportStream a -> Stream (CT a)
--  finiteSupportStreamSum' = finiteSupportStreamArgDerivative finiteSupportStreamSum
-- :}
--
-- Let's compute the gradient at point @s@. It is an infinite stream and we take first 7 elements:
--
-- >>> take 7 $ simplify $ finiteSupportStreamSum' s
-- [1,1,1,1,0,0,0]
--
-- As expected,
-- the gradient is a stream with 1's in the first four positions (corresponding
-- to our four variables and the fixed support length 4) and 0's elsewhere.
--
-- We can compute the second derivative (Hessian matrix), which is a stream of streams
-- in our case.
--
-- >>> :{
--  finiteSupportStreamSum'' :: (Distributive a, CT a ~ a) =>
--    FiniteSupportStream a -> Stream (Stream (CT a))
--  finiteSupportStreamSum'' = finiteSupportStreamArgDerivative (finiteSupportStreamArgDerivative finiteSupportStreamSum)
-- :}
--
-- All second derivatives should be zero. We take the first 7 rows and 4 columns of the infinite Hessian matrix:
--
-- >>> take 7 $ fmap (take 4) $ simplify $ finiteSupportStreamSum'' s
-- [[0,0,0,0],[0,0,0,0],[0,0,0,0],[0,0,0,0],[0,0,0,0],[0,0,0,0],[0,0,0,0]]
finiteSupportStreamArgDerivative ::
  (AutoDifferentiableValue b, Additive (CT a)) =>
  (FiniteSupportStream (RevDiff' (FiniteSupportStream a) a) -> b) ->
  FiniteSupportStream a ->
  DerivativeValue b
finiteSupportStreamArgDerivative f = customArgDerivative finiteSupportStreamArg f

-- | Maybe differentiable value builder.
-- Creates a mapping function for Maybe types.
-- See [this tutorial section]
-- (Numeric-InfBackprop-Tutorial.html#g:multivalued-45-function-45-how-45-it-45-works)
-- for details and examples.
mkMaybeVal :: (a -> b) -> Maybe a -> Maybe b
mkMaybeVal f = fmap f

-- | `Maybe` value structure for differentiable functions.
-- Extracts the derivative with respect to the original function for `Maybe` types.
--
-- ==== __Examples__
--
-- >>> :{
--  class SafeRecip a where
--    safeRecip :: a -> Maybe a
--  instance SafeRecip Float where
--    safeRecip x = if x == 0.0 then Nothing else Just (recip x)
--  instance (SafeRecip b, Subtractive b, Multiplicative b, IntegerPower b) =>
--    SafeRecip (RevDiff a b b) where
--      safeRecip (MkRevDiff v bp) =
--        fmap (\r -> MkRevDiff r (bp . negate . (r^2 *))) (safeRecip v)
-- :}
--
-- >>> safeRecip (2.0 :: Float) :: Maybe Float
-- Just 0.5
-- >>> safeRecip (0.0 :: Float) :: Maybe Float
-- Nothing
--
-- >>> customValDerivative maybeVal safeRecip (2.0 :: Float)
-- Just (-0.25)
-- >>> customValDerivative maybeVal safeRecip (0.0 :: Float)
-- Nothing
maybeVal ::
  (Multiplicative b) =>
  Maybe (RevDiff a b c) ->
  Maybe a
maybeVal mx = mkMaybeVal scalarVal mx

-- | `Maybe` instance of `AutoDifferentiableValue`.
instance
  (AutoDifferentiableValue a) =>
  AutoDifferentiableValue (Maybe a)
  where
  type DerivativeValue (Maybe a) = Maybe (DerivativeValue a)

  -- Apply the inner `autoVal` when a value is present; `Nothing` stays `Nothing`.
  autoVal :: Maybe a -> Maybe (DerivativeValue a)
  autoVal mx = mkMaybeVal autoVal mx

-- Argument descriptor for differentiable functions with optional argument.
-- Transforms a `RevDiff` instance of an optional type into
-- an optional of `RevDiff` instances.
-- This allows applying differentiable operations to the optional value.

-- | Argument descriptor for differentiable functions with optional (`Maybe`) values.
--
-- Transforms a `RevDiff` instance containing an optional type into an optional
-- `RevDiff` instance. This transformation enables applying differentiable
-- operations to values that may or may not be present, while preserving
-- gradient flow when values exist.
--
-- When the wrapped value is `Just x`, the function extracts the value and
-- wraps it in a new `RevDiff` instance with appropriately transformed
-- backpropagation. When the value is `Nothing`, the result is `Nothing`,
-- effectively short-circuiting the computation.
--
-- ==== __Examples__
--
-- >>> :{
--  f :: Additive a => Maybe a -> a
--  f (Just x) = x
--  f Nothing = zero
-- :}
--
-- >>> customArgDerivative maybeArg f (Just 3 :: Maybe Float) :: Maybe Float
-- Just 1.0
maybeArg :: RevDiff a (Maybe b) (Maybe c) -> Maybe (RevDiff a b c)
maybeArg (MkRevDiff mx pullback) = fmap wrap mx
  where
    -- The inner cotangent is re-wrapped in `Just` before reaching the
    -- original (Maybe-typed) backprop.
    wrap x = MkRevDiff x (pullback . Just)

-- | Maybe argument builder.
-- Applies a function to `Maybe` value obtained from a `RevDiff`.
mkMaybeArg ::
  (RevDiff a b c -> d) -> RevDiff a (Maybe b) (Maybe c) -> Maybe d
mkMaybeArg f x = fmap f (maybeArg x)

-- | `Maybe` instance of `AutoDifferentiableArgument`.
instance
  (AutoDifferentiableArgument a) =>
  AutoDifferentiableArgument (Maybe a)
  where
  type DerivativeRoot (Maybe a) = DerivativeRoot a
  type DerivativeCoarg (Maybe a) = Maybe (DerivativeCoarg a)
  type DerivativeArg (Maybe a) = Maybe (DerivativeArg a)

  -- Unwrap the optional differentiable value with `maybeArg`, then rebuild
  -- the inner value with the element instance's `autoArg`.
  autoArg ::
    RevDiff
      (DerivativeRoot a)
      (Maybe (DerivativeCoarg a))
      (Maybe (DerivativeArg a)) ->
    Maybe a
  autoArg = fmap autoArg . maybeArg

-- | Differentiable operator for functions that take a `Maybe` (a value or none) argument
-- and return any value type supported by `AutoDifferentiableValue`.
-- The output is `Maybe` of corresponding derivatives over the inner type.
--
-- ==== __Examples__
--
-- >>> import Debug.SimpleExpr (variable, simplify, SE)
-- >>> import Debug.DiffExpr (SymbolicFunc)
-- >>> import qualified GHC.Num as GHCN
--
-- >>> :{
--  maybeF :: TrigField a => Maybe a -> a
--  maybeF (Just x) = sin x
--  maybeF Nothing = zero
-- :}
--
-- >>> maybeF (Just 0.0 :: Maybe Float)
-- 0.0
--
-- >>> maybeF (Nothing :: Maybe Float)
-- 0.0
--
-- >>> maybeArgDerivative maybeF (Just 0.0 :: Maybe Float)
-- Just 1.0
--
-- >>> maybeArgDerivative maybeF (Nothing :: Maybe Float)
-- Just 0.0
maybeArgDerivative ::
  (AutoDifferentiableValue b) =>
  (Maybe (RevDiff' (Maybe a) a) -> b) ->
  Maybe a ->
  DerivativeValue b
maybeArgDerivative f = customArgDerivative maybeArg f

-- | Derivative operator for functions with `Maybe` values.
-- This allows computing derivatives of functions that return `Maybe` values as output,
-- handling the case when the value is `Nothing` appropriately.
--
-- ==== __Examples__
--
-- >>> :{
--  class SafeRecip a where
--    safeRecip :: a -> Maybe a
--  instance SafeRecip Float where
--    safeRecip x = if x == 0.0 then Nothing else Just (recip x)
--  instance (SafeRecip b, Subtractive b, Multiplicative b, IntegerPower b) =>
--    SafeRecip (RevDiff a b b) where
--      safeRecip (MkRevDiff v bp) =
--        fmap (\r -> MkRevDiff r (bp . negate . (r^2 *))) (safeRecip v)
-- :}
--
-- >>> safeRecip (2.0 :: Float) :: Maybe Float
-- Just 0.5
-- >>> safeRecip (0.0 :: Float) :: Maybe Float
-- Nothing
--
-- >>> maybeValDerivative safeRecip (2.0 :: Float)
-- Just (-0.25)
-- >>> maybeValDerivative safeRecip (0.0 :: Float)
-- Nothing
maybeValDerivative ::
  ( AutoDifferentiableArgument a,
    Multiplicative c,
    DerivativeCoarg a ~ CT (DerivativeArg a),
    DerivativeRoot a ~ CT (DerivativeArg a)
  ) =>
  (a -> Maybe (RevDiff b c d)) ->
  DerivativeArg a ->
  Maybe b
maybeValDerivative f = customValDerivative maybeVal f