{-# LANGUAGE DeriveGeneric #-}
{-# LANGUAGE TypeOperators #-}
{-# OPTIONS_GHC -fno-warn-unused-imports #-}
module Numeric.InfBackprop.Core
(
Tangent,
Dual,
Cotangent,
CT,
RevDiff (MkRevDiff, value, backprop),
RevDiff',
DifferentiableFunc,
initDiff,
call,
derivativeOp,
toLensOps,
constDiff,
StopDiff (stopDiff),
HasConstant (constant),
simpleDifferentiableFunc,
toLens,
fromLens,
fromProfunctors,
toProfunctors,
fromVanLaarhoven,
toVanLaarhoven,
AutoDifferentiableArgument,
DerivativeRoot,
DerivativeCoarg,
DerivativeArg,
AutoDifferentiableValue,
DerivativeValue,
autoArg,
autoVal,
sameTypeDerivative,
simpleDerivative,
simpleValueAndDerivative,
customArgDerivative,
customValDerivative,
customArgValDerivative,
differentiableSum,
differentiableSub,
differentiableNegate,
differentiableMult,
differentiableDiv,
differentiableRecip,
differentiableMultAction,
differentiableConv,
differentiablePow,
differentiableExp,
differentiableLog,
differentiableLogBase,
differentiableSqrt,
differentiableSin,
differentiableCos,
differentiableTan,
differentiableSinh,
differentiableCosh,
differentiableTanh,
differentiableAsin,
differentiableAcos,
differentiableAtan,
differentiableAtan2,
differentiableAsinh,
differentiableAcosh,
differentiableAtanh,
scalarArg,
scalarVal,
scalarArgDerivative,
scalarValDerivative,
mkTupleArg,
tupleArg,
tupleArgDerivative,
tupleDerivativeOverX,
tupleDerivativeOverY,
twoArgsDerivative,
twoArgsDerivativeOverX,
twoArgsDerivativeOverY,
mkTupleVal,
tupleVal,
tupleValDerivative,
threeArgsToTriple,
tripleArg,
mkTripleArg,
tripleArgDerivative,
tripleDerivativeOverX,
tripleDerivativeOverY,
tripleDerivativeOverZ,
threeArgsDerivative,
derivative3ArgsOverX,
derivative3ArgsOverY,
derivative3ArgsOverZ,
mkTripleVal,
tripleVal,
tripleValDerivative,
boxedVectorArg,
mkBoxedVectorArg,
boxedVectorArgDerivative,
boxedVectorVal,
mkBoxedVectorVal,
boxedVectorValDerivative,
streamArg,
mkStreamArg,
streamArgDerivative,
streamVal,
mkStreamVal,
streamValDerivative,
finiteSupportStreamArg,
mkFiniteSupportStreamArg,
finiteSupportStreamArgDerivative,
finiteSupportStreamVal,
mkFiniteSupportStreamVal,
finiteSupportStreamValDerivative,
maybeArg,
mkMaybeArg,
maybeArgDerivative,
maybeVal,
mkMaybeVal,
maybeValDerivative,
)
where
import Control.Applicative ((<$>), (<*>))
import Control.Comonad.Identity (Identity (Identity, runIdentity))
import Control.ExtendableMap (ExtandableMap, extendMap)
import qualified Control.Lens as CL
import Control.Monad.ST (runST)
import Data.Bifunctor (first)
import Data.Coerce (coerce)
import Data.Composition ((.:))
import Data.Finite (Finite)
import Data.FiniteSupportStream (FiniteSupportStream (MkFiniteSupportStream, toVector), cons, empty, head, singleton, tail, unsafeMap)
import Data.Function (on)
import Data.Functor.Compose (Compose (Compose, getCompose))
import Data.Functor.Const (Const (Const, getConst))
import Data.Int (Int16, Int32, Int64, Int8)
import Data.List.NonEmpty (xor)
import Data.Primitive (Prim)
import Data.Profunctor (Profunctor (dimap))
import Data.Profunctor.Strong (Costrong (unfirst, unsecond))
import Data.Proxy (Proxy (Proxy))
import Data.Stream (Stream)
import qualified Data.Stream as DS
import Data.Tuple (curry, fst, snd, uncurry)
import Data.Tuple.Extra ((***))
import Data.Type.Equality (type (~))
import qualified Data.Vector as DV
import qualified Data.Vector.Fixed.Boxed as DVFB
import Data.Vector.Fusion.Util (Box (Box, unBox))
import qualified Data.Vector.Generic as DVG
import Data.Vector.Generic.Base
( Vector
( basicLength,
basicUnsafeCopy,
basicUnsafeFreeze,
basicUnsafeIndexM,
basicUnsafeSlice,
basicUnsafeThaw,
elemseq
),
)
import qualified Data.Vector.Generic.Base as DVGB
import qualified Data.Vector.Generic.Mutable as DVGM
import qualified Data.Vector.Generic.Mutable.Base as DVGBM
import qualified Data.Vector.Generic.Sized as DVGS
import qualified Data.Vector.Generic.Sized.Internal as DVGSI
import qualified Data.Vector.Primitive as DVP
import qualified Data.Vector.Unboxed as DVU
import qualified Data.Vector.Unboxed.Mutable as DVUM
import Data.Word (Word, Word16, Word32, Word64, Word8)
import Debug.SimpleExpr (SimpleExpr, SimpleExprF)
import Debug.SimpleExpr.Expr (SE, number)
import Debug.SimpleExpr.Utils.Algebra
( AlgebraicPower ((^^)),
Convolution ((|*|)),
IntegerPower,
MultiplicativeAction ((*|)),
(^),
)
import Debug.SimpleExpr.Utils.Traced (Traced (MkTraced))
import Debug.Trace (trace)
import Foreign (oneBits)
import GHC.Base
( Applicative,
Eq ((==)),
Float,
Functor,
Int,
Maybe (Just, Nothing),
Ord (compare, max, min, (<), (<=), (>), (>=)),
Type,
const,
flip,
fmap,
id,
pure,
return,
undefined,
($),
(++),
(.),
(<*>),
)
import GHC.Generics (C, Generic, type (:.:) (unComp1))
import GHC.Integer (Integer)
import GHC.Natural (Natural)
import qualified GHC.Num as GHCN
import GHC.Real (Integral, fromIntegral, realToFrac, toInteger)
import qualified GHC.Real as GHCR
import GHC.Show (Show (show))
import GHC.TypeLits (KnownChar)
import GHC.TypeNats (KnownNat, Nat)
import GHC.Types (Int)
import NumHask
( Additive,
AdditiveAction,
Complex,
Distributive,
Divisive,
ExpField,
Field,
FromInteger (fromInteger),
FromIntegral,
Multiplicative,
Subtractive,
TrigField,
acos,
acosh,
asin,
asinh,
atan,
atan2,
atanh,
cos,
cosh,
exp,
fromIntegral,
log,
logBase,
negate,
one,
pi,
recip,
sin,
sinh,
sqrt,
tan,
tanh,
two,
zero,
(*),
(**),
(+),
(-),
(/),
)
import NumHask.Data.Integral (FromInteger)
import Numeric.InfBackprop.Instances.NumHask ()
import Numeric.InfBackprop.Utils.SizedVector (BoxedVector, boxedVectorBasis, boxedVectorSum)
import Numeric.InfBackprop.Utils.Tuple (cross, cross3, curry3, fork, fork3, uncurry3)
import Optics (Lens, Lens', getting, lens, set, simple, view, (%))
-- | Type-level map from a type to its tangent type.
--
-- The instances below map the scalar types to themselves and map
-- products and containers component-wise / element-wise.
type family Tangent (a :: Type) :: Type

-- Scalars: each is its own tangent.
type instance Tangent Float = Float

type instance Tangent GHCN.Integer = GHCN.Integer

type instance Tangent SimpleExpr = SimpleExpr

-- Products: tangent is taken component-wise.
type instance Tangent (x, y) = (Tangent x, Tangent y)

type instance Tangent (x, y, z) = (Tangent x, Tangent y, Tangent z)

-- Containers and wrappers: tangent is taken element-wise.
type instance Tangent [e] = [Tangent e]

type instance Tangent (DVFB.Vec n e) = DVFB.Vec n (Tangent e)

type instance Tangent (DVGS.Vector v n e) = DVGS.Vector v n (Tangent e)

type instance Tangent (Stream e) = Stream (Tangent e)

type instance Tangent (FiniteSupportStream e) = FiniteSupportStream (Tangent e)

type instance Tangent (Maybe e) = Maybe (Tangent e)

type instance Tangent (Traced e) = Traced (Tangent e)

type instance Tangent (Complex e) = Complex (Tangent e)
-- | Type-level map from a type to its dual type.
--
-- Scalars are self-dual and most containers are dualised element-wise.
-- The one exception is the stream pair: the dual of a 'Stream' is a
-- 'FiniteSupportStream', and the dual of a 'FiniteSupportStream' is a
-- 'Stream'.
type family Dual (x :: Type) :: Type

-- Scalars: self-dual.
type instance Dual Float = Float

type instance Dual GHCN.Integer = GHCN.Integer

type instance Dual SimpleExpr = SimpleExpr

-- Products: dualised component-wise.
type instance Dual (p, q) = (Dual p, Dual q)

type instance Dual (p, q, r) = (Dual p, Dual q, Dual r)

-- Element-wise containers and wrappers.
type instance Dual [e] = [Dual e]

type instance Dual (DVFB.Vec n e) = DVFB.Vec n (Dual e)

type instance Dual (DVGS.Vector v n e) = DVGS.Vector v n (Dual e)

type instance Dual (SimpleExprF e) = SimpleExprF (Dual e)

type instance Dual (Maybe e) = Maybe (Dual e)

type instance Dual (Traced e) = Traced (Dual e)

type instance Dual (Complex e) = Complex (Dual e)

-- Streams: the two stream types are swapped under 'Dual'.
type instance Dual (Stream e) = FiniteSupportStream (Dual e)

type instance Dual (FiniteSupportStream e) = Stream (Dual e)
-- | Cotangent type of @t@: the dual of its tangent type. The
-- backpropagation maps in this module (e.g. in 'derivativeOp') act on
-- cotangents.
type Cotangent t = Dual (Tangent t)

-- | Short alias for 'Cotangent'.
type CT t = Cotangent t
-- | Reverse-mode differentiable value.
--
-- A 'RevDiff' pairs a forward value of type @c@ with a backpropagation
-- function that maps an incoming cotangent of type @b@ to a cotangent of
-- type @a@ at the root of the computation.
--
-- Fields are deliberately lazy: the 'Costrong' instance projects the value
-- lazily and relies on fields not being forced.
data RevDiff a b c = MkRevDiff
  { -- | The forward (primal) value.
    value :: c,
    -- | Backpropagation map from the cotangent @b@ to the root cotangent @a@.
    backprop :: b -> a
  }
  deriving (Generic)
-- | 'RevDiff' whose root slot is the cotangent of @root@ and whose value
-- has type @v@ with backprop input @CT v@.
type RevDiff' root v = RevDiff (CT root) (CT v) v

-- The tangent and dual of a 'RevDiff' keep the root slot unchanged and map
-- the backprop-input and value slots.
type instance Tangent (RevDiff r cv v) = RevDiff r (Tangent cv) (Tangent v)

type instance Dual (RevDiff r cv v) = RevDiff r (Dual cv) (Dual v)
-- | Evaluate a differentiable function forward only.
--
-- The argument is seeded with 'initDiff', the function is applied, and the
-- resulting 'value' is returned. The backpropagation map is discarded.
call :: (RevDiff' a a -> RevDiff' a b) -> a -> b
call f = value . f . initDiff
-- | Backpropagation operator of a differentiable function at a point.
--
-- @derivativeOp f x@ seeds @x@ with 'initDiff', applies @f@, and returns
-- the resulting 'backprop' map, which sends a cotangent of the output
-- (@CT b@) to a cotangent of the input (@CT a@).
derivativeOp :: (RevDiff' a a -> RevDiff' a b) -> a -> CT b -> CT a
derivativeOp f = backprop . f . initDiff
-- | Split a differentiable function at a point into its lens-like parts.
--
-- Returns the forward value together with the backpropagation map obtained
-- by applying @f@ to @'initDiff' x@.
toLensOps :: (RevDiff ca ca a -> RevDiff ca cb b) -> a -> (b, cb -> ca)
toLensOps f x = (y, bp)
  where
    MkRevDiff y bp = f (initDiff x)
-- | Lift a scalar function with a known derivative to 'RevDiff'.
--
-- @simpleDifferentiableFunc f f'@ applies @f@ to the value. Backpropagation
-- follows the chain rule: an incoming cotangent @cy@ is multiplied by
-- @f' x@ and passed on to the argument's backprop.
simpleDifferentiableFunc ::
  (Multiplicative b) =>
  -- | The function @f@.
  (b -> b) ->
  -- | Its derivative @f'@.
  (b -> b) ->
  RevDiff a b b ->
  RevDiff a b b
simpleDifferentiableFunc f f' (MkRevDiff x bpc) =
  MkRevDiff (f x) (\cy -> bpc (f' x * cy))
-- | Seed a value as the root of a reverse-mode computation.
--
-- The backpropagation map is the identity, so the cotangent reaching the
-- root is returned unchanged.
initDiff :: a -> RevDiff b b a
initDiff x = MkRevDiff x id
-- | Convert a differentiable function into an optics 'Lens'.
--
-- The lens getter is the forward 'value'. The setter is the
-- backpropagation map at the given point.
toLens :: (RevDiff b b a -> RevDiff b d c) -> Lens a b c d
toLens f = lens (value . run) (backprop . run)
  where
    -- Seed the argument and run the function forward.
    run = f . initDiff
-- | Build a differentiable function from an optics 'Lens'.
--
-- The forward value is read with the lens getter. Backpropagation writes
-- the incoming cotangent with 'set' at the current point and passes the
-- resulting @CT a@ to the argument's backprop.
fromLens :: Lens a (CT a) b (CT b) -> RevDiff' a a -> RevDiff' a b
fromLens l (MkRevDiff x bp) =
  MkRevDiff (view (getting l) x) (\cy -> bp (set l cy x))
-- | 'dimap' post-composes the value with @g@ and pre-composes the
-- backpropagation map with @f@.
instance Profunctor (RevDiff t) where
  dimap :: (a -> b) -> (c -> d) -> RevDiff t b c -> RevDiff t a d
  dimap f g (MkRevDiff v bp) = MkRevDiff (g v) (bp . f)
-- | 'unfirst' keeps the first component of the value and feeds the second
-- component of the value into the second slot of the backprop input.
-- 'unsecond' does the same with the roles of the components swapped.
--
-- The pair is projected with 'fst' / 'snd' rather than pattern-matched so
-- the value stays lazy. Presumably this matters when the fed-back
-- component is defined in terms of itself.
instance Costrong (RevDiff t) where
  unfirst :: RevDiff t (a, d) (b, d) -> RevDiff t a b
  unfirst (MkRevDiff v bp) = MkRevDiff (fst v) (bp . (,snd v))

  unsecond :: RevDiff t (d, a) (d, b) -> RevDiff t a b
  unsecond (MkRevDiff v bp) = MkRevDiff (snd v) (bp . (fst v,))
-- | A differentiable function from @a@ to @b@, polymorphic in the root
-- slot @r@ of the 'RevDiff' it acts on.
type DifferentiableFunc a b =
  forall r. RevDiff r (CT a) a -> RevDiff r (CT b) b
-- | Build a 'DifferentiableFunc' from a function that works for any
-- 'Costrong' profunctor, by specialising it to 'RevDiff'.
--
-- The definition is eta-expanded rather than @= id@. That way it
-- typechecks under GHC's simplified subsumption as well as with
-- DeepSubsumption.
fromProfunctors ::
  (forall p. (Costrong p) => p (CT a) a -> p (CT b) b) -> DifferentiableFunc a b
fromProfunctors f = f
-- | Turn a 'DifferentiableFunc' into a function on any 'Costrong'
-- profunctor.
--
-- The input is 'dimap'ped to a profunctor over pairs:
--
-- * the covariant side is mapped to @(x, call f x)@;
-- * the contravariant side uses @derivativeOp f@ on @(x, cy)@.
--
-- 'unsecond' then ties the point @x@ back through.
toProfunctors ::
  (Costrong p) =>
  DifferentiableFunc a b ->
  p (CT a) a ->
  p (CT b) b
toProfunctors f = unsecond . dimap (uncurry pullback) (fork id forward)
  where
    -- Forward evaluation of f.
    forward = call f
    -- Backpropagation map of f at a given point.
    pullback = derivativeOp f
-- | Build a 'DifferentiableFunc' from a van Laarhoven lens.
--
-- The lens is run in the functor @Compose ((,) b) ((->) (CT b))@. That
-- yields both the focused value @y@ and a map from @CT b@ to @CT a@, which
-- is then composed with the incoming backprop.
fromVanLaarhoven ::
  forall a b.
  (forall f. (Functor f) => (b -> f (CT b)) -> a -> f (CT a)) ->
  DifferentiableFunc a b
fromVanLaarhoven vll (MkRevDiff x bpx) = MkRevDiff y (bpx . bp)
  where
    -- Capture the focus and the setter in one pass of the lens.
    (y, bp) = getCompose (vll (\focus -> Compose (focus, id)) x)
-- | Turn a 'DifferentiableFunc' into a van Laarhoven lens.
--
-- The handler @f@ is applied to the forward value. The backpropagation map
-- is then 'fmap'ped over its result.
toVanLaarhoven ::
  (Functor f) =>
  DifferentiableFunc a b ->
  (b -> f (CT b)) ->
  a ->
  f (CT a)
toVanLaarhoven g f x = fmap bp (f y)
  where
    MkRevDiff y bp = g (initDiff x)
-- | Wrap a value as a constant with respect to differentiation.
--
-- The backpropagation map ignores its input and always returns 'zero'.
constDiff :: (Additive a) => c -> RevDiff a b c
constDiff x = MkRevDiff x (const zero)
-- | Derivative of a function at a point, obtained by backpropagating the
-- unit cotangent 'one' through the result.
simpleDerivative ::
  forall a b.
  (Multiplicative (CT b)) =>
  (RevDiff' a a -> RevDiff' a b) ->
  a ->
  CT a
simpleDerivative f x = backprop (f (initDiff x)) one
-- | 'simpleDerivative' specialised to functions whose argument and result
-- types coincide.
sameTypeDerivative ::
  (Multiplicative (CT a)) =>
  (RevDiff (CT a) (CT a) a -> RevDiff (CT a) (CT a) a) ->
  a ->
  CT a
sameTypeDerivative = simpleDerivative
-- | Forward value and derivative of a function at a point.
--
-- The function is evaluated once. The derivative is obtained by
-- backpropagating the unit cotangent 'one'.
simpleValueAndDerivative ::
  forall a b.
  (Multiplicative (CT b)) =>
  (RevDiff' a a -> RevDiff' a b) ->
  a ->
  (b, CT a)
simpleValueAndDerivative f x = (value out, backprop out one)
  where
    -- Shared result of the single forward pass.
    out = f (initDiff x)
-- | General derivative combinator with custom argument and value
-- adapters.
--
-- The point is seeded with 'initDiff' and passed through @argTerm@. The
-- adapted argument goes to @f@, and @valTerm@ extracts the final result.
customArgValDerivative ::
  -- | Argument adapter (@argTerm@).
  (RevDiff (CT a) (CT a) a -> b) ->
  -- | Value adapter (@valTerm@).
  (c -> d) ->
  -- | The function being differentiated.
  (b -> c) ->
  a ->
  d
customArgValDerivative argTerm valTerm f =
  valTerm . f . argTerm . initDiff
-- | A function that consumes a 'RevDiff' with the given slots and returns
-- a result of type @res@.
type RevDiffArg root coarg val res = RevDiff root coarg val -> res
-- | Types that can be built automatically from a seeded 'RevDiff'
-- argument.
--
-- 'autoArg' converts a 'RevDiff' with the associated root, co-argument and
-- argument types into @a@.
class
  (Additive (DerivativeRoot a), Additive (DerivativeCoarg a)) =>
  AutoDifferentiableArgument a
  where
  -- | Root (cotangent) slot of the underlying 'RevDiff'.
  type DerivativeRoot a :: Type

  -- | Backprop-input slot of the underlying 'RevDiff'.
  type DerivativeCoarg a :: Type

  -- | Value slot of the underlying 'RevDiff'.
  type DerivativeArg a :: Type

  -- | Convert a seeded 'RevDiff' into the argument type.
  autoArg :: RevDiff (DerivativeRoot a) (DerivativeCoarg a) (DerivativeArg a) -> a
-- | A plain 'RevDiff' is its own automatic argument. Its slots map
-- directly to the associated types, and 'autoArg' is the identity.
instance
  (Additive a, Additive b) =>
  AutoDifferentiableArgument (RevDiff a b c)
  where
  type DerivativeRoot (RevDiff a b c) = a
  type DerivativeCoarg (RevDiff a b c) = b
  type DerivativeArg (RevDiff a b c) = c
  autoArg = id
-- | Results of a differentiable function from which a derivative value can
-- be extracted automatically.
class AutoDifferentiableValue a where
  -- | Type of the extracted derivative.
  type DerivativeValue a :: Type

  -- | Extract the derivative from a result.
  autoVal :: a -> DerivativeValue a
-- | Extract the derivative of a scalar-valued result by backpropagating the
-- unit cotangent 'one'. The forward value is ignored.
scalarVal ::
  (Multiplicative b) =>
  RevDiff a b c ->
  a
scalarVal (MkRevDiff _ bp) = bp one
-- | A 'RevDiff' result is treated as a scalar. Its derivative is obtained
-- with 'scalarVal'.
instance
  (Multiplicative b) =>
  AutoDifferentiableValue (RevDiff a b c)
  where
  type DerivativeValue (RevDiff a b c) = a
  autoVal :: RevDiff a b c -> a
  autoVal = scalarVal
-- | Derivative with a custom argument adapter and an automatically derived
-- value extractor ('autoVal').
customArgDerivative ::
  (AutoDifferentiableValue c) =>
  (RevDiff (CT a) (CT a) a -> b) ->
  (b -> c) ->
  a ->
  DerivativeValue c
customArgDerivative arg = customArgValDerivative arg autoVal
-- | Derivative with an automatically derived argument adapter ('autoArg')
-- and a custom value extractor.
customValDerivative ::
  ( DerivativeRoot b ~ CT (DerivativeArg b),
    DerivativeCoarg b ~ CT (DerivativeArg b),
    AutoDifferentiableArgument b
  ) =>
  (c -> d) ->
  (b -> c) ->
  DerivativeArg b ->
  d
customValDerivative = customArgValDerivative autoArg
-- | Argument adapter for a scalar argument: the seeded 'RevDiff' is used
-- unchanged.
scalarArg :: RevDiff a b c -> RevDiff a b c
scalarArg = id
-- | Derivative of a function of a scalar argument. The argument is passed
-- through unchanged, and the result is extracted with 'autoVal'.
scalarArgDerivative ::
  (AutoDifferentiableValue c) =>
  (RevDiff' a a -> c) ->
  a ->
  DerivativeValue c
scalarArgDerivative = customArgValDerivative id autoVal
-- | Derivative of a scalar-valued function. The argument is built with
-- 'autoArg', and the derivative is extracted with 'scalarVal'.
scalarValDerivative ::
  ( DerivativeRoot b ~ CT a,
    DerivativeCoarg b ~ CT a,
    DerivativeArg b ~ a,
    Multiplicative (CT c),
    AutoDifferentiableArgument b
  ) =>
  (b -> RevDiff d (CT c) c) ->
  a ->
  d
scalarValDerivative = customArgValDerivative autoArg scalarVal
-- | Renders as @MkRevDiff <value> <backprop>@. This requires a 'Show'
-- instance for the backpropagation function type.
instance (Show (b -> a), Show c) => Show (RevDiff a b c) where
  show (MkRevDiff x bp) = "MkRevDiff " ++ show x ++ " " ++ show bp
-- | Conversion of a plain value @a@ into a (possibly differentiable) type
-- @b@ without propagating gradients through it.
class StopDiff a b where
  stopDiff :: a -> b
-- | Base case: a value is converted to its own type unchanged.
--
-- NOTE(review): this head overlaps with @StopDiff a (RevDiff b c d)@ when
-- @a@ is itself a 'RevDiff'. Confirm whether overlap pragmas are needed.
instance StopDiff a a where
  stopDiff = id
-- | Recursive case: convert to the inner value type, then wrap it as a
-- constant ('constDiff') whose backprop always yields 'zero'.
instance
  (StopDiff a d, Additive b) =>
  StopDiff a (RevDiff b c d)
  where
  stopDiff = constDiff . stopDiff
-- | Replaces a value of type @a@ by @x@, keeping the 'RevDiff' nesting.
--
-- @constant (Proxy \@a) x v@ replaces the value of type @a@ found (possibly
-- under several 'RevDiff' layers) in @v@ by @x@. It keeps the 'RevDiff'
-- nesting structure of @v@ but discards every backpropagation function of @v@.
class HasConstant a b c d where
  constant :: Proxy a -> b -> c -> d

-- | Base case: @v@ itself has type @a@ and is simply replaced.
instance HasConstant a b a b where
  constant _ replacement _ = replacement

-- | Recursive case.
--
-- Descend into the value carried by the 'RevDiff', replace it there, and
-- re-wrap the result with 'constDiff'. The original backpropagation function
-- is ignored.
instance
  forall a b c d e f t.
  (HasConstant a b c d, Additive t, e ~ CT c, f ~ CT d) =>
  HasConstant a b (RevDiff t e c) (RevDiff t f d)
  where
  constant _ replacement (MkRevDiff inner _) =
    constDiff (constant (Proxy @a) replacement inner)
-- | Differentiable sum of the two components of a pair.
--
-- Forward pass: @x0 + x1@. Backward pass: both partial derivatives are one,
-- so the incoming cotangent is handed unchanged to each component.
differentiableSum ::
  (Additive c) =>
  RevDiff a (b, b) (c, c) ->
  RevDiff a b c
differentiableSum (MkRevDiff (u, w) pairBackprop) =
  MkRevDiff (u + w) (pairBackprop . dup)
  where
    dup cy = (cy, cy)
-- | Addition of differentiable values.
--
-- 'zero' is built with 'constDiff'. '+' pairs its operands with
-- 'twoArgsToTuple' and then applies 'differentiableSum'.
instance
  (Additive a, Additive c) =>
  Additive (RevDiff a b c)
  where
  zero = constDiff zero
  x + y = differentiableSum (twoArgsToTuple x y)
-- | Differentiable difference of the two components of a pair.
--
-- Forward pass: @x0 - x1@. Backward pass: the cotangent goes unchanged to
-- the minuend and negated to the subtrahend.
differentiableSub ::
  (Subtractive b, Subtractive c) =>
  RevDiff a (b, b) (c, c) ->
  RevDiff a b c
differentiableSub (MkRevDiff (u, w) pairBackprop) =
  MkRevDiff (u - w) backward
  where
    backward cy = pairBackprop (cy, negate cy)
-- | Differentiable negation.
--
-- Negates the value, and negates whatever the wrapped backpropagation
-- function returns.
differentiableNegate ::
  (Subtractive a, Subtractive c) =>
  RevDiff a b c ->
  RevDiff a b c
differentiableNegate (MkRevDiff v bp) =
  MkRevDiff (negate v) (\cy -> negate (bp cy))
-- | Negation and subtraction of differentiable values, via
-- 'differentiableNegate' and 'differentiableSub'.
instance
  ( Additive a,
    Subtractive a,
    Subtractive b,
    Subtractive c
  ) =>
  Subtractive (RevDiff a b c)
  where
  negate = differentiableNegate
  x - y = differentiableSub (twoArgsToTuple x y)
-- | Differentiable product of the two components of a pair.
--
-- Forward pass: @x0 * x1@. Backward pass: @(x1 * cy, x0 * cy)@.
differentiableMult ::
  (Multiplicative b) =>
  RevDiff a (b, b) (b, b) ->
  RevDiff a b b
differentiableMult (MkRevDiff (u, w) pairBackprop) =
  MkRevDiff (u * w) backward
  where
    backward cy = pairBackprop (w * cy, u * cy)
-- | Multiplication of differentiable values.
--
-- 'one' is built with 'constDiff'. '*' delegates to 'differentiableMult'.
instance
  (Additive a, Multiplicative b) =>
  Multiplicative (RevDiff a b b)
  where
  one = constDiff one
  x * y = differentiableMult (twoArgsToTuple x y)
-- | Scaling by a plain (non-differentiable) 'Integer'.
--
-- Both the value and the incoming cotangent are scaled by the same factor.
instance
  (MultiplicativeAction Integer b, MultiplicativeAction Integer cb) =>
  MultiplicativeAction Integer (RevDiff ct cb b)
  where
  k *| (MkRevDiff v bp) = MkRevDiff (k *| v) (\cy -> bp (k *| cy))
-- | Differentiable multiplicative action @x *| y@ with both operands
-- differentiable.
--
-- Backward pass:
--
-- * The cotangent of @x@ is @y |*| cz@, a 'Convolution' of the value @y@
--   with the incoming cotangent @cz@.
-- * The cotangent of @y@ is @x *| cz@.
differentiableMultAction ::
  (MultiplicativeAction a b, MultiplicativeAction a cb, Convolution b cb ca) =>
  RevDiff ct (ca, cb) (a, b) ->
  RevDiff ct cb b
differentiableMultAction (MkRevDiff (s, v) pairBackprop) =
  MkRevDiff (s *| v) backward
  where
    backward cz = pairBackprop (v |*| cz, s *| cz)
-- | Multiplicative action where both operands are differentiable.
--
-- Delegates to 'differentiableMultAction'.
instance
  (MultiplicativeAction a b, MultiplicativeAction a cb, Convolution b cb ca, Additive ct) =>
  MultiplicativeAction (RevDiff ct ca a) (RevDiff ct cb b)
  where
  s *| v = differentiableMultAction (twoArgsToTuple s v)
-- | Differentiable convolution @x |*| y@ of the two components of a pair.
--
-- Backward pass: the cotangent of @x@ is @cz |*| y@ and the cotangent of
-- @y@ is @x |*| cz@, where @cz@ is the incoming cotangent.
differentiableConv ::
  (Convolution a b c, Convolution cc b ca, Convolution a cc cb) =>
  RevDiff ct (ca, cb) (a, b) ->
  RevDiff ct cc c
differentiableConv (MkRevDiff (x, y) pairBackprop) =
  MkRevDiff (x |*| y) backward
  where
    backward cz = pairBackprop (cz |*| y, x |*| cz)
-- | Convolution of differentiable values, via 'differentiableConv'.
instance
  (Convolution a b c, Convolution cc b ca, Convolution a cc cb, Additive ct) =>
  Convolution (RevDiff ct ca a) (RevDiff ct cb b) (RevDiff ct cc c)
  where
  x |*| y = differentiableConv (twoArgsToTuple x y)
-- | Differentiable quotient of the two components of a pair.
--
-- Forward pass: @x0 / x1@. Backward pass, each multiplied by the cotangent:
--
-- * @d/dx0 = 1 / x1@
-- * @d/dx1 = -x0 / x1^2@
differentiableDiv ::
  (Subtractive b, Divisive b) =>
  RevDiff a (b, b) (b, b) ->
  RevDiff a b b
differentiableDiv (MkRevDiff (num, den) pairBackprop) =
  MkRevDiff (num / den) backward
  where
    backward cy = pairBackprop (cy / den, negate (num / den / den * cy))
-- | Differentiable reciprocal.
--
-- Forward pass: @r = 1 / x@. Backward pass: @-(r^2) * cy@.
differentiableRecip ::
  (Divisive b, Subtractive b, IntegerPower b) =>
  RevDiff a b b ->
  RevDiff a b b
differentiableRecip (MkRevDiff v bp) = MkRevDiff r backward
  where
    r = recip v
    -- d(1/x)/dx = -(1/x)^2, reusing the already computed reciprocal
    backward cy = bp (negate (r ^ 2 * cy))
-- | Reciprocal and division of differentiable values, via
-- 'differentiableRecip' and 'differentiableDiv'.
instance
  (Additive a, Divisive b, Subtractive b, IntegerPower b) =>
  Divisive (RevDiff a b b)
  where
  recip = differentiableRecip
  x / y = differentiableDiv (twoArgsToTuple x y)
-- | Differentiable real power @x ** p@ of the two components of a pair.
--
-- Backward pass, each multiplied by the cotangent:
--
-- * @d/dx = p * x ** (p - 1)@
-- * @d/dp = log x * x ** p@
differentiablePow ::
  (ExpField b) =>
  RevDiff a (b, b) (b, b) ->
  RevDiff a b b
differentiablePow (MkRevDiff (base, expo) pairBackprop) =
  MkRevDiff powered backward
  where
    powered = base ** expo
    backward cy =
      pairBackprop
        ( expo * (base ** (expo - one)) * cy,
          log base * powered * cy
        )
-- | Differentiable exponential.
--
-- The derivative equals the value itself, so @exp x@ is computed once and
-- reused in the backward pass.
differentiableExp ::
  (ExpField b) =>
  RevDiff a b b ->
  RevDiff a b b
differentiableExp (MkRevDiff v bp) = MkRevDiff ev (\cy -> bp (ev * cy))
  where
    ev = exp v
-- | Differentiable natural logarithm.
--
-- Backward pass: the cotangent is divided by @x@.
differentiableLog ::
  (ExpField b) =>
  RevDiff a b b ->
  RevDiff a b b
differentiableLog (MkRevDiff v bp) = MkRevDiff (log v) (\cy -> bp (cy / v))
-- | Differentiable @logBase b x = log x / log b@ over the pair @(b, x)@.
--
-- Backward pass, each multiplied by the cotangent:
--
-- * @d/db = -log x / (log b)^2 / b@
-- * @d/dx = 1 / (x * log b)@
differentiableLogBase ::
  (ExpField b, IntegerPower b) =>
  RevDiff a (b, b) (b, b) ->
  RevDiff a b b
differentiableLogBase (MkRevDiff (b, x) pairBackprop) =
  MkRevDiff (logX / logB) backward
  where
    logX = log x
    logB = log b
    backward cy =
      pairBackprop
        ( negate (logX / (logB ^ 2) / b * cy),
          recip x / logB * cy
        )
-- | Differentiable square root.
--
-- Backward pass: @cy / (2 * sqrt x)@, reusing the computed root.
differentiableSqrt ::
  (ExpField b) =>
  RevDiff a b b ->
  RevDiff a b b
differentiableSqrt (MkRevDiff v bp) =
  MkRevDiff root (\cy -> bp (recip (two * root) * cy))
  where
    root = sqrt v
-- | Exponential-field operations on differentiable values.
--
-- Every method delegates to the corresponding @differentiable*@ primitive.
instance
  (ExpField b, Additive a, Subtractive a, IntegerPower b) =>
  ExpField (RevDiff a b b)
  where
  exp = differentiableExp
  log = differentiableLog
  x ** y = differentiablePow (twoArgsToTuple x y)
  logBase b x = differentiableLogBase (twoArgsToTuple b x)
  sqrt = differentiableSqrt
-- | Differentiable @atan2 y x@ over the pair @(y, x)@.
--
-- With @r2 = x^2 + y^2@, the backward pass (each multiplied by the
-- cotangent) is:
--
-- * @d/dy = x / r2@
-- * @d/dx = -y / r2@
differentiableAtan2 ::
  (TrigField b, IntegerPower b) =>
  RevDiff a (b, b) (b, b) ->
  RevDiff a b b
differentiableAtan2 (MkRevDiff (y, x) pairBackprop) =
  MkRevDiff (atan2 y x) backward
  where
    r2 = x ^ 2 + y ^ 2
    backward cy = pairBackprop (x / r2 * cy, negate (y / r2 * cy))
-- | Integer power @x ^^ n@ of a differentiable value, with an 'Int' exponent.
--
-- The function and its derivative @n * x^(n-1)@ are lifted with
-- 'simpleDifferentiableFunc'.
instance
  ( AlgebraicPower Int a,
    MultiplicativeAction Int a,
    Multiplicative a
  ) =>
  AlgebraicPower Int (RevDiff c a a)
  where
  x ^^ n = simpleDifferentiableFunc (^^ n) derivative x
    where
      -- For n == 0 the derivative is identically zero. Evaluating
      -- n *| (x' ^^ (-1)) instead may yield NaN (0 * infinity) or fail at
      -- x' = 0, depending on how ^^ treats negative exponents. n *| one is
      -- zero without needing any extra constraint.
      derivative x'
        | n == 0 = n *| one
        | otherwise = n *| (x' ^^ (n - 1))
-- | Integer power @x ^^ n@ of a differentiable value, with an 'Integer'
-- exponent.
--
-- The function and its derivative @n * x^(n-1)@ are lifted with
-- 'simpleDifferentiableFunc'.
instance
  ( AlgebraicPower Integer a,
    MultiplicativeAction Integer a,
    Multiplicative a
  ) =>
  AlgebraicPower Integer (RevDiff c a a)
  where
  x ^^ n = simpleDifferentiableFunc (^^ n) derivative x
    where
      -- For n == 0 the derivative is identically zero. Evaluating
      -- n *| (x' ^^ (-1)) instead may yield NaN (0 * infinity) or fail at
      -- x' = 0, depending on how ^^ treats negative exponents. n *| one is
      -- zero without needing any extra constraint.
      derivative x'
        | n == 0 = n *| one
        | otherwise = n *| (x' ^^ (n - 1))
-- | Differentiable 'sin', lifted with 'simpleDifferentiableFunc'.
--
-- Its derivative is 'cos'.
differentiableSin ::
  (TrigField b) =>
  RevDiff a b b ->
  RevDiff a b b
differentiableSin = simpleDifferentiableFunc sin (\x -> cos x)
-- | Differentiable 'cos', lifted with 'simpleDifferentiableFunc'.
--
-- Its derivative is @-sin x@.
differentiableCos ::
  (TrigField b) =>
  RevDiff a b b ->
  RevDiff a b b
differentiableCos = simpleDifferentiableFunc cos (\x -> negate (sin x))
-- | Differentiable 'tan', lifted with 'simpleDifferentiableFunc'.
--
-- Its derivative is @1 / cos^2 x@, written as @cos x ^ (-2)@.
differentiableTan ::
  (TrigField b, IntegerPower b) =>
  RevDiff a b b ->
  RevDiff a b b
differentiableTan = simpleDifferentiableFunc tan (\x -> cos x ^ (-2))
-- | Differentiable 'asin', lifted with 'simpleDifferentiableFunc'.
--
-- Its derivative is @1 / sqrt (1 - x^2)@.
differentiableAsin ::
  (TrigField b, ExpField b, IntegerPower b) =>
  RevDiff a b b ->
  RevDiff a b b
differentiableAsin =
  simpleDifferentiableFunc asin (\x -> recip (sqrt (one - x ^ 2)))
-- | Differentiable 'acos', lifted with 'simpleDifferentiableFunc'.
--
-- Its derivative is @-1 / sqrt (1 - x^2)@.
differentiableAcos ::
  (TrigField b, ExpField b, IntegerPower b) =>
  RevDiff a b b ->
  RevDiff a b b
differentiableAcos =
  simpleDifferentiableFunc acos (\x -> negate (recip (sqrt (one - x ^ 2))))
-- | Differentiable 'atan', lifted with 'simpleDifferentiableFunc'.
--
-- Its derivative is @1 / (1 + x^2)@.
differentiableAtan ::
  (TrigField b, IntegerPower b) =>
  RevDiff a b b ->
  RevDiff a b b
differentiableAtan =
  simpleDifferentiableFunc atan (\x -> recip (one + x ^ 2))
-- | Differentiable 'sinh', lifted with 'simpleDifferentiableFunc'.
--
-- Its derivative is 'cosh'.
differentiableSinh ::
  (TrigField b) =>
  RevDiff a b b ->
  RevDiff a b b
differentiableSinh = simpleDifferentiableFunc sinh (\x -> cosh x)
-- | Differentiable 'cosh', lifted with 'simpleDifferentiableFunc'.
--
-- Its derivative is 'sinh'.
differentiableCosh ::
  (TrigField b) =>
  RevDiff a b b ->
  RevDiff a b b
differentiableCosh = simpleDifferentiableFunc cosh (\x -> sinh x)
-- | Differentiable 'tanh', lifted with 'simpleDifferentiableFunc'.
--
-- Its derivative is @1 / cosh^2 x@, written as @cosh x ^ (-2)@.
differentiableTanh ::
  (TrigField b, IntegerPower b) =>
  RevDiff a b b ->
  RevDiff a b b
differentiableTanh = simpleDifferentiableFunc tanh (\x -> cosh x ^ (-2))
-- | Differentiable 'asinh', lifted with 'simpleDifferentiableFunc'.
--
-- Its derivative is @1 / sqrt (1 + x^2)@.
differentiableAsinh ::
  (TrigField b, ExpField b, IntegerPower b) =>
  RevDiff a b b ->
  RevDiff a b b
differentiableAsinh =
  simpleDifferentiableFunc asinh (\x -> recip (sqrt (one + x ^ 2)))
-- | Differentiable 'acosh', lifted with 'simpleDifferentiableFunc'.
--
-- Its derivative is @1 / sqrt (x^2 - 1)@, defined on @x > 1@.
differentiableAcosh ::
  (TrigField b, ExpField b, IntegerPower b) =>
  RevDiff a b b ->
  RevDiff a b b
differentiableAcosh =
  -- Note the order x^2 - 1, not 1 - x^2. The latter is the asin derivative
  -- and is NaN everywhere on the real domain of acosh.
  simpleDifferentiableFunc acosh (\x -> recip (sqrt (x ^ 2 - one)))
-- | Differentiable 'atanh', lifted with 'simpleDifferentiableFunc'.
--
-- Its derivative is @1 / (1 - x^2)@.
differentiableAtanh ::
  (TrigField b, IntegerPower b) =>
  RevDiff a b b ->
  RevDiff a b b
differentiableAtanh =
  simpleDifferentiableFunc atanh (\x -> recip (one - x ^ 2))
-- | Trigonometric and hyperbolic functions on differentiable values.
--
-- 'pi' is a constant built with 'constDiff'. Every other method delegates to
-- the corresponding @differentiable*@ primitive.
instance
  (Additive a, Subtractive a, ExpField b, TrigField b, IntegerPower b) =>
  TrigField (RevDiff a b b)
  where
  pi = constDiff pi
  sin = differentiableSin
  cos = differentiableCos
  tan = differentiableTan
  asin = differentiableAsin
  acos = differentiableAcos
  atan = differentiableAtan
  atan2 y x = differentiableAtan2 (twoArgsToTuple y x)
  sinh = differentiableSinh
  cosh = differentiableCosh
  tanh = differentiableTanh
  asinh = differentiableAsinh
  acosh = differentiableAcosh
  atanh = differentiableAtanh
-- | Differentiable absolute value, using base's 'GHCN.abs' and 'GHCN.signum'.
--
-- Backward pass: the cotangent is multiplied by @signum x@.
differentiableAbs ::
  (GHCN.Num b, Multiplicative b) =>
  RevDiff a b b ->
  RevDiff a b b
differentiableAbs (MkRevDiff v bp) =
  MkRevDiff (GHCN.abs v) (\cy -> bp (GHCN.signum v * cy))
-- | Differentiable sign function.
--
-- The result is built with 'constDiff' from @GHCN.signum x@. The incoming
-- backpropagation function is ignored.
differentiableSign ::
  (Additive a, GHCN.Num b) =>
  RevDiff a b b ->
  RevDiff a b b
differentiableSign (MkRevDiff v _) = constDiff (GHCN.signum v)
-- | base 'GHCN.Num' instance, for interoperability with code written against
-- base's numeric classes.
--
-- Arithmetic goes through the same differentiable primitives used by the
-- numhask instances: 'differentiableSum', 'differentiableSub',
-- 'differentiableMult', 'differentiableNegate', 'differentiableAbs' and
-- 'differentiableSign'. Literals become constants via 'constDiff'.
instance
  ( Additive a,
    Subtractive a,
    GHCN.Num b,
    Subtractive b,
    Multiplicative b
  ) =>
  GHCN.Num (RevDiff a b b)
  where
  -- The primitives are called explicitly. Writing (+) = (GHCN.+) here
  -- resolves to this very instance and loops forever.
  x + y = differentiableSum (twoArgsToTuple x y)
  x - y = differentiableSub (twoArgsToTuple x y)
  x * y = differentiableMult (twoArgsToTuple x y)
  negate = differentiableNegate
  abs = differentiableAbs
  signum = differentiableSign
  fromInteger = constDiff . GHCN.fromInteger
-- | Integer literals are converted in the value type @c@ and wrapped with
-- 'constDiff', so they do not depend on the differentiation argument.
instance
  (FromInteger c, Additive a) =>
  FromInteger (RevDiff a b c)
  where
  fromInteger n = constDiff (fromInteger n)
-- | Converts an 'Int8' into a constant 'RevDiff' value.
instance
  (FromIntegral c Int8, Additive a) =>
  FromIntegral (RevDiff a b c) Int8
  where
  fromIntegral n = constDiff (NumHask.fromIntegral n)
-- | Converts an 'Int16' into a constant 'RevDiff' value.
instance
  (FromIntegral c Int16, Additive a) =>
  FromIntegral (RevDiff a b c) Int16
  where
  fromIntegral n = constDiff (NumHask.fromIntegral n)
-- | Converts an 'Int32' into a constant 'RevDiff' value.
instance
  (FromIntegral c Int32, Additive a) =>
  FromIntegral (RevDiff a b c) Int32
  where
  fromIntegral n = constDiff (NumHask.fromIntegral n)
-- | Converts an 'Int64' into a constant 'RevDiff' value.
instance
  (FromIntegral c Int64, Additive a) =>
  FromIntegral (RevDiff a b c) Int64
  where
  fromIntegral n = constDiff (NumHask.fromIntegral n)
-- | Converts an 'Int' into a constant 'RevDiff' value.
instance
  (FromIntegral c Int, Additive a) =>
  FromIntegral (RevDiff a b c) Int
  where
  fromIntegral n = constDiff (NumHask.fromIntegral n)
-- | Converts a 'Word8' into a constant 'RevDiff' value.
instance
  (FromIntegral c Word8, Additive a) =>
  FromIntegral (RevDiff a b c) Word8
  where
  fromIntegral n = constDiff (NumHask.fromIntegral n)
-- | Converts a 'Word16' into a constant 'RevDiff' value.
instance
  (FromIntegral c Word16, Additive a) =>
  FromIntegral (RevDiff a b c) Word16
  where
  fromIntegral n = constDiff (NumHask.fromIntegral n)
-- | Converts a 'Word32' into a constant 'RevDiff' value.
instance
  (FromIntegral c Word32, Additive a) =>
  FromIntegral (RevDiff a b c) Word32
  where
  fromIntegral n = constDiff (NumHask.fromIntegral n)
-- | Converts a 'Word64' into a constant 'RevDiff' value.
instance
  (FromIntegral c Word64, Additive a) =>
  FromIntegral (RevDiff a b c) Word64
  where
  fromIntegral n = constDiff (NumHask.fromIntegral n)
-- | Converts a 'Word' into a constant 'RevDiff' value.
instance
  (FromIntegral c Word, Additive a) =>
  FromIntegral (RevDiff a b c) Word
  where
  fromIntegral n = constDiff (NumHask.fromIntegral n)
-- | Converts an 'Integer' into a constant 'RevDiff' value.
instance
  (FromIntegral c Integer, Additive a) =>
  FromIntegral (RevDiff a b c) Integer
  where
  fromIntegral n = constDiff (NumHask.fromIntegral n)
-- | Converts a 'Natural' into a constant 'RevDiff' value.
instance
  (FromIntegral c Natural, Additive a) =>
  FromIntegral (RevDiff a b c) Natural
  where
  fromIntegral n = constDiff (NumHask.fromIntegral n)
-- | GHC 'GHCR.Fractional' instance for 'RevDiff' values whose cotangent
-- type equals their value type @b@.
--
-- Division packs both operands with 'twoArgsToTuple' and applies
-- 'differentiableDiv'. Reciprocals use 'differentiableRecip'. Rational
-- literals become constants via 'constDiff'.
instance
  ( Additive a,
    Subtractive a,
    Subtractive b,
    Divisive b,
    GHCR.Fractional b,
    IntegerPower b
  ) =>
  GHCR.Fractional (RevDiff a b b)
  where
  numerator / denominator =
    differentiableDiv (twoArgsToTuple numerator denominator)
  recip = differentiableRecip
  fromRational q = constDiff (GHCR.fromRational q)
-- | Combine two differentiable values that share the root cotangent type
-- @a@ into one differentiable pair.
--
-- The forward value is the pair of both values. A pair cotangent
-- @(cy0, cy1)@ is backpropagated by sending each component through its
-- own backpropagator and adding the two results in @a@.
twoArgsToTuple ::
  (Additive a) =>
  RevDiff a b0 c0 ->
  RevDiff a b1 c1 ->
  RevDiff a (b0, b1) (c0, c1)
twoArgsToTuple (MkRevDiff v0 back0) (MkRevDiff v1 back1) =
  MkRevDiff (v0, v1) combined
  where
    combined (cy0, cy1) = back0 cy0 + back1 cy1
-- | Split a differentiable pair into its two components.
--
-- Each component keeps its own value. Its backpropagator places the
-- component cotangent into a pair, fills the other slot with 'zero', and
-- passes that pair to the original backpropagator.
tupleArg ::
  (Additive b0, Additive b1) =>
  RevDiff a (b0, b1) (c0, c1) ->
  (RevDiff a b0 c0, RevDiff a b1 c1)
tupleArg (MkRevDiff (v0, v1) pairBack) = (part0, part1)
  where
    part0 = MkRevDiff v0 (\cy -> pairBack (cy, zero))
    part1 = MkRevDiff v1 (\cy -> pairBack (zero, cy))
-- | Build a pair-argument adapter from one adapter per component.
--
-- The incoming differentiable pair is split with 'tupleArg', and 'cross'
-- applies each adapter to its half.
mkTupleArg ::
  (Additive b0, Additive b1) =>
  RevDiffArg a b0 c0 d0 ->
  RevDiffArg a b1 c1 d1 ->
  RevDiffArg a (b0, b1) (c0, c1) (d0, d1)
mkTupleArg f0 f1 pairDiff = cross f0 f1 (tupleArg pairDiff)
-- | A pair of auto-differentiable arguments is auto-differentiable when
-- both components share the same derivative root.
--
-- Coargument and argument types are the component types paired up.
-- 'autoArg' splits the pair with 'mkTupleArg' and converts each half with
-- its component's own 'autoArg'.
instance
  ( AutoDifferentiableArgument a0,
    AutoDifferentiableArgument a1,
    DerivativeRoot a0 ~ DerivativeRoot a1
  ) =>
  AutoDifferentiableArgument (a0, a1)
  where
  type DerivativeRoot (a0, a1) = DerivativeRoot a0
  type DerivativeCoarg (a0, a1) = (DerivativeCoarg a0, DerivativeCoarg a1)
  type DerivativeArg (a0, a1) = (DerivativeArg a0, DerivativeArg a1)
  autoArg pairDiff = mkTupleArg autoArg autoArg pairDiff
-- | Map one function over each component of a pair (same as 'cross').
mkTupleVal :: (a0 -> b0) -> (a1 -> b1) -> (a0, a1) -> (b0, b1)
mkTupleVal f0 f1 = cross f0 f1
-- | Extract the derivative value from each component of a pair of
-- differentiable values, using 'scalarVal' on both.
tupleVal ::
  (Multiplicative b0, Multiplicative b1) =>
  (RevDiff a0 b0 c0, RevDiff a1 b1 c1) ->
  (a0, a1)
tupleVal pair = mkTupleVal scalarVal scalarVal pair
-- | A pair of auto-differentiable values is auto-differentiable
-- component-wise: each half is converted with its own 'autoVal'.
instance
  (AutoDifferentiableValue a0, AutoDifferentiableValue a1) =>
  AutoDifferentiableValue (a0, a1)
  where
  type DerivativeValue (a0, a1) = (DerivativeValue a0, DerivativeValue a1)
  autoVal pair = mkTupleVal autoVal autoVal pair
-- | Derivative of a function of a pair.
--
-- The function receives the pair already split by 'tupleArg' into two
-- differentiable components, both rooted at the whole pair.
tupleArgDerivative ::
  (Additive (CT a0), Additive (CT a1), AutoDifferentiableValue b) =>
  ((RevDiff' (a0, a1) a0, RevDiff' (a0, a1) a1) -> b) ->
  (a0, a1) ->
  DerivativeValue b
tupleArgDerivative f = customArgDerivative tupleArg f
-- | Partial derivative of a pair-argument function with respect to the
-- first component. The second component enters as a 'constDiff' constant.
tupleDerivativeOverX ::
  (AutoDifferentiableValue b, Additive (CT a0)) =>
  ((RevDiff' a0 a0, RevDiff' a0 a1) -> b) ->
  (a0, a1) ->
  DerivativeValue b
tupleDerivativeOverX f (x0, x1) = scalarArgDerivative overFirst x0
  where
    overFirst d0 = f (d0, constDiff x1)
-- | Partial derivative of a pair-argument function with respect to the
-- second component. The first component enters as a 'constDiff' constant.
tupleDerivativeOverY ::
  (Additive (CT a1), AutoDifferentiableValue b) =>
  ((RevDiff' a1 a0, RevDiff' a1 a1) -> b) ->
  (a0, a1) ->
  DerivativeValue b
tupleDerivativeOverY f (x0, x1) = scalarArgDerivative overSecond x1
  where
    overSecond d1 = f (constDiff x0, d1)
-- | Derivative of a curried two-argument function with respect to both
-- arguments jointly.
--
-- The two arguments are treated as one pair. The function sees the
-- components obtained from 'tupleArg'.
twoArgsDerivative ::
  (Additive (CT a0), Additive (CT a1), AutoDifferentiableValue b) =>
  (RevDiff' (a0, a1) a0 -> RevDiff' (a0, a1) a1 -> b) ->
  a0 ->
  a1 ->
  DerivativeValue b
twoArgsDerivative f x0 x1 = scalarArgDerivative onPair (x0, x1)
  where
    onPair pairDiff = uncurry f (tupleArg pairDiff)
-- | Partial derivative of a curried two-argument function with respect to
-- its first argument. The second argument is held as a 'constDiff'
-- constant.
twoArgsDerivativeOverX ::
  (Additive (CT a0), AutoDifferentiableValue b) =>
  (RevDiff' a0 a0 -> RevDiff' a0 a1 -> b) ->
  a0 ->
  a1 ->
  DerivativeValue b
twoArgsDerivativeOverX f x0 x1 = scalarArgDerivative withFixedY x0
  where
    withFixedY d0 = f d0 (constDiff x1)
-- | Partial derivative of a curried two-argument function with respect to
-- its second argument. The first argument is held as a 'constDiff'
-- constant.
twoArgsDerivativeOverY ::
  (Additive (CT a1), AutoDifferentiableValue b) =>
  (RevDiff' a1 a0 -> RevDiff' a1 a1 -> b) ->
  a0 ->
  a1 ->
  DerivativeValue b
twoArgsDerivativeOverY f x0 = scalarArgDerivative (f (constDiff x0))
-- | Derivative of a function returning a pair of differentiable values.
--
-- Each returned component is turned into its derivative with 'tupleVal'
-- (that is, 'scalarVal' on both halves).
tupleValDerivative ::
  ( AutoDifferentiableArgument a,
    Multiplicative c0,
    Multiplicative c1,
    DerivativeCoarg a ~ CT (DerivativeArg a),
    DerivativeRoot a ~ CT (DerivativeArg a)
  ) =>
  (a -> (RevDiff b0 c0 d0, RevDiff b1 c1 d1)) ->
  DerivativeArg a ->
  (b0, b1)
tupleValDerivative f = customValDerivative tupleVal f
-- | Partial derivative of a triple-argument function with respect to the
-- first component. The other two enter as 'constDiff' constants.
tripleDerivativeOverX ::
  (AutoDifferentiableValue b, Additive (CT a0)) =>
  ((RevDiff' a0 a0, RevDiff' a0 a1, RevDiff' a0 a2) -> b) ->
  (a0, a1, a2) ->
  DerivativeValue b
tripleDerivativeOverX f (x0, x1, x2) = scalarArgDerivative overFirst x0
  where
    overFirst d0 = f (d0, constDiff x1, constDiff x2)
-- | Partial derivative of a triple-argument function with respect to the
-- second component. The other two enter as 'constDiff' constants.
tripleDerivativeOverY ::
  (AutoDifferentiableValue b, Additive (CT a1)) =>
  ((RevDiff' a1 a0, RevDiff' a1 a1, RevDiff' a1 a2) -> b) ->
  (a0, a1, a2) ->
  DerivativeValue b
tripleDerivativeOverY f (x0, x1, x2) = scalarArgDerivative overSecond x1
  where
    overSecond d1 = f (constDiff x0, d1, constDiff x2)
-- | Partial derivative of a triple-argument function with respect to the
-- third component. The other two enter as 'constDiff' constants.
tripleDerivativeOverZ ::
  (AutoDifferentiableValue b, Additive (CT a2)) =>
  ((RevDiff' a2 a0, RevDiff' a2 a1, RevDiff' a2 a2) -> b) ->
  (a0, a1, a2) ->
  DerivativeValue b
tripleDerivativeOverZ f (x0, x1, x2) = scalarArgDerivative overThird x2
  where
    overThird d2 = f (constDiff x0, constDiff x1, d2)
-- | Combine three differentiable values that share the root cotangent type
-- @a@ into one differentiable triple.
--
-- A triple cotangent is backpropagated component-wise and the three
-- contributions are summed left to right in @a@.
threeArgsToTriple ::
  (Additive a) =>
  RevDiff a b0 c0 ->
  RevDiff a b1 c1 ->
  RevDiff a b2 c2 ->
  RevDiff a (b0, b1, b2) (c0, c1, c2)
threeArgsToTriple (MkRevDiff v0 back0) (MkRevDiff v1 back1) (MkRevDiff v2 back2) =
  MkRevDiff (v0, v1, v2) combined
  where
    combined (cy0, cy1, cy2) = back0 cy0 + back1 cy1 + back2 cy2
-- | Split a differentiable triple into its three components.
--
-- Each component's backpropagator places its cotangent in the matching
-- slot of a triple, fills the other slots with 'zero', and passes that
-- triple to the original backpropagator.
tripleArg ::
  (Additive b0, Additive b1, Additive b2) =>
  RevDiff a (b0, b1, b2) (c0, c1, c2) ->
  (RevDiff a b0 c0, RevDiff a b1 c1, RevDiff a b2 c2)
tripleArg (MkRevDiff (v0, v1, v2) tripleBack) = (part0, part1, part2)
  where
    part0 = MkRevDiff v0 (\cy -> tripleBack (cy, zero, zero))
    part1 = MkRevDiff v1 (\cy -> tripleBack (zero, cy, zero))
    part2 = MkRevDiff v2 (\cy -> tripleBack (zero, zero, cy))
-- | Build a triple-argument adapter from one adapter per component.
--
-- The incoming triple is split with 'tripleArg', and 'cross3' applies each
-- adapter to its part.
mkTripleArg ::
  (Additive b0, Additive b1, Additive b2) =>
  RevDiffArg a b0 c0 d0 ->
  RevDiffArg a b1 c1 d1 ->
  RevDiffArg a b2 c2 d2 ->
  RevDiffArg a (b0, b1, b2) (c0, c1, c2) (d0, d1, d2)
mkTripleArg f0 f1 f2 tripleDiff = cross3 f0 f1 f2 (tripleArg tripleDiff)
-- | A triple of auto-differentiable arguments is auto-differentiable when
-- all three components share the same derivative root.
--
-- 'autoArg' splits the triple with 'mkTripleArg' and converts each part
-- with its component's own 'autoArg'.
instance
  ( AutoDifferentiableArgument a0,
    AutoDifferentiableArgument a1,
    AutoDifferentiableArgument a2,
    DerivativeRoot a0 ~ DerivativeRoot a1,
    DerivativeRoot a0 ~ DerivativeRoot a2
  ) =>
  AutoDifferentiableArgument (a0, a1, a2)
  where
  type DerivativeRoot (a0, a1, a2) = DerivativeRoot a0
  type DerivativeCoarg (a0, a1, a2) = (DerivativeCoarg a0, DerivativeCoarg a1, DerivativeCoarg a2)
  type DerivativeArg (a0, a1, a2) = (DerivativeArg a0, DerivativeArg a1, DerivativeArg a2)
  autoArg tripleDiff = mkTripleArg autoArg autoArg autoArg tripleDiff
-- | Map one function over each component of a triple (same as 'cross3').
mkTripleVal :: (a0 -> b0) -> (a1 -> b1) -> (a2 -> b2) -> (a0, a1, a2) -> (b0, b1, b2)
mkTripleVal f0 f1 f2 = cross3 f0 f1 f2
-- | Extract the derivative value from each component of a triple of
-- differentiable values, using 'scalarVal' on all three.
tripleVal ::
  (Multiplicative b0, Multiplicative b1, Multiplicative b2) =>
  (RevDiff a0 b0 c0, RevDiff a1 b1 c1, RevDiff a2 b2 c2) ->
  (a0, a1, a2)
tripleVal triple = mkTripleVal scalarVal scalarVal scalarVal triple
-- | A triple of auto-differentiable values is auto-differentiable
-- component-wise: each part is converted with its own 'autoVal'.
instance
  ( AutoDifferentiableValue a0,
    AutoDifferentiableValue a1,
    AutoDifferentiableValue a2
  ) =>
  AutoDifferentiableValue (a0, a1, a2)
  where
  type DerivativeValue (a0, a1, a2) = (DerivativeValue a0, DerivativeValue a1, DerivativeValue a2)
  autoVal triple = mkTripleVal autoVal autoVal autoVal triple
-- | Derivative of a function whose argument is a triple.
--
-- The function receives the triple already split into three
-- differentiable components (via 'tripleArg'); the result is post-processed
-- by 'customArgDerivative'.
tripleArgDerivative ::
  ( Additive (CT a0),
    Additive (CT a1),
    Additive (CT a2),
    AutoDifferentiableValue b
  ) =>
  ( ( RevDiff' (a0, a1, a2) a0,
      RevDiff' (a0, a1, a2) a1,
      RevDiff' (a0, a1, a2) a2
    ) ->
    b
  ) ->
  (a0, a1, a2) ->
  DerivativeValue b
tripleArgDerivative = customArgDerivative tripleArg
-- | Derivative of a curried three-argument function with respect to all
-- three arguments at once.
--
-- The arguments are bundled into a triple, differentiated as a single
-- argument with 'scalarArgDerivative', and the triple is split back into
-- components with 'tripleArg' before being passed to @f@.
threeArgsDerivative ::
  ( AutoDifferentiableValue b,
    Additive (CT a0),
    Additive (CT a1),
    Additive (CT a2)
  ) =>
  ( RevDiff' (a0, a1, a2) a0 ->
    RevDiff' (a0, a1, a2) a1 ->
    RevDiff' (a0, a1, a2) a2 ->
    b
  ) ->
  a0 ->
  a1 ->
  a2 ->
  DerivativeValue b
threeArgsDerivative f = curry3 (scalarArgDerivative onTriple)
  where
    onTriple = uncurry3 f . tripleArg
-- | Derivative of a three-argument function with respect to its first
-- argument only; the second and third are held constant via 'constDiff'.
derivative3ArgsOverX ::
  (AutoDifferentiableValue b, Additive (CT a0)) =>
  (RevDiff' a0 a0 -> RevDiff' a0 a1 -> RevDiff' a0 a2 -> b) ->
  a0 ->
  a1 ->
  a2 ->
  DerivativeValue b
derivative3ArgsOverX f x0 x1 x2 = scalarArgDerivative overX x0
  where
    overX x0' = f x0' (constDiff x1) (constDiff x2)
-- | Derivative of a three-argument function with respect to its second
-- argument only; the first and third are held constant via 'constDiff'.
derivative3ArgsOverY ::
  (AutoDifferentiableValue b, Additive (CT a1)) =>
  (RevDiff' a1 a0 -> RevDiff' a1 a1 -> RevDiff' a1 a2 -> b) ->
  a0 ->
  a1 ->
  a2 ->
  DerivativeValue b
derivative3ArgsOverY f x0 x1 x2 = scalarArgDerivative overY x1
  where
    overY x1' = f (constDiff x0) x1' (constDiff x2)
-- | Derivative of a three-argument function with respect to its third
-- argument only; the first and second are held constant via 'constDiff'.
derivative3ArgsOverZ ::
  (AutoDifferentiableValue b, Additive (CT a2)) =>
  (RevDiff' a2 a0 -> RevDiff' a2 a1 -> RevDiff' a2 a2 -> b) ->
  a0 ->
  a1 ->
  a2 ->
  DerivativeValue b
derivative3ArgsOverZ f x0 x1 x2 = scalarArgDerivative overZ x2
  where
    overZ = f (constDiff x0) (constDiff x1)
-- | Derivative of a function returning a triple of 'RevDiff's.
--
-- Each output component is extracted with 'scalarVal' (through
-- 'tripleVal') inside 'customValDerivative'.
tripleValDerivative ::
  ( AutoDifferentiableArgument a,
    Multiplicative c0,
    Multiplicative c1,
    Multiplicative c2,
    DerivativeCoarg a ~ CT (DerivativeArg a),
    DerivativeRoot a ~ CT (DerivativeArg a)
  ) =>
  (a -> (RevDiff b0 c0 d0, RevDiff b1 c1 d1, RevDiff b2 c2 d2)) ->
  DerivativeArg a ->
  (b0, b1, b2)
tripleValDerivative = customValDerivative tripleVal
-- | Map a function over every element of a sized boxed vector
-- (the 'Functor' instance's 'fmap').
mkBoxedVectorVal ::
  (a -> b) ->
  BoxedVector n a ->
  BoxedVector n b
mkBoxedVectorVal = fmap
-- | Boxed vectors of differentiable values are handled elementwise.
instance
  (AutoDifferentiableValue a) =>
  AutoDifferentiableValue (BoxedVector n a)
  where
  type DerivativeValue (BoxedVector n a) = BoxedVector n (DerivativeValue a)

  -- Every element is converted by the element type's 'autoVal'.
  autoVal :: BoxedVector n a -> BoxedVector n (DerivativeValue a)
  autoVal = mkBoxedVectorVal autoVal
-- | Apply 'scalarVal' to every element of a boxed vector of 'RevDiff's.
boxedVectorVal ::
  (Multiplicative b) =>
  BoxedVector n (RevDiff a b c) ->
  BoxedVector n a
boxedVectorVal = mkBoxedVectorVal scalarVal
-- | Derivative of a function returning a boxed vector of 'RevDiff's.
--
-- The output elements are extracted with 'boxedVectorVal' inside
-- 'customValDerivative'.
boxedVectorValDerivative ::
  ( AutoDifferentiableArgument a,
    Multiplicative c,
    DerivativeCoarg a ~ CT (DerivativeArg a),
    DerivativeRoot a ~ CT (DerivativeArg a)
  ) =>
  (a -> BoxedVector n (RevDiff b c d)) ->
  DerivativeArg a ->
  BoxedVector n b
boxedVectorValDerivative = customValDerivative boxedVectorVal
-- | Split a differentiable boxed vector into a boxed vector of
-- differentiable elements.
--
-- Element @k@ carries the @k@-th component of the value. Its backpropagator
-- first embeds an element cotangent into a whole vector with
-- @'boxedVectorBasis' k 'zero'@ (presumably @zero@ everywhere except slot
-- @k@), then applies the original vector backpropagator.
boxedVectorArg ::
  (Additive b, KnownNat n) =>
  RevDiff a (BoxedVector n b) (BoxedVector n c) ->
  BoxedVector n (RevDiff a b c)
boxedVectorArg (MkRevDiff vs bp) = DVGS.generate component
  where
    component k =
      MkRevDiff (DVGS.index vs k) (bp . boxedVectorBasis k zero)
-- | Lift an element-level argument conversion to boxed vectors.
--
-- The vector 'RevDiff' is split with 'boxedVectorArg' and @f@ is applied
-- to each element.
mkBoxedVectorArg ::
  (Additive b, KnownNat n) =>
  RevDiffArg a b c d ->
  RevDiffArg a (BoxedVector n b) (BoxedVector n c) (BoxedVector n d)
mkBoxedVectorArg f rd = f <$> boxedVectorArg rd
-- | Boxed vectors of differentiable arguments are handled elementwise.
instance
  ( AutoDifferentiableArgument a,
    KnownNat n
  ) =>
  AutoDifferentiableArgument (BoxedVector n a)
  where
  type DerivativeRoot (BoxedVector n a) = DerivativeRoot a
  type DerivativeCoarg (BoxedVector n a) = BoxedVector n (DerivativeCoarg a)
  type DerivativeArg (BoxedVector n a) = BoxedVector n (DerivativeArg a)

  -- Split the vector into elements and convert each with the element
  -- type's own 'autoArg'.
  autoArg ::
    RevDiff
      (DerivativeRoot a)
      (BoxedVector n (DerivativeCoarg a))
      (BoxedVector n (DerivativeArg a)) ->
    BoxedVector n a
  autoArg = mkBoxedVectorArg autoArg
-- | Derivative of a function of a boxed vector.
--
-- The function receives the vector already split into differentiable
-- elements (via 'boxedVectorArg').
boxedVectorArgDerivative ::
  (KnownNat n, AutoDifferentiableValue b, Additive (CT a)) =>
  (BoxedVector n (RevDiff' (BoxedVector n a) a) -> b) ->
  BoxedVector n a ->
  DerivativeValue b
boxedVectorArgDerivative = customArgDerivative boxedVectorArg
-- | Map a function over every element of a 'Stream'
-- (the 'Functor' instance's 'fmap').
mkStreamVal ::
  (a -> b) ->
  Stream a ->
  Stream b
mkStreamVal = fmap
-- | Apply 'scalarVal' to every element of a stream of 'RevDiff's.
streamVal ::
  (Multiplicative b) =>
  Stream (RevDiff a b c) ->
  Stream a
streamVal = mkStreamVal scalarVal
-- | Streams of differentiable values are handled elementwise.
instance
  (AutoDifferentiableValue a) =>
  AutoDifferentiableValue (Stream a)
  where
  type DerivativeValue (Stream a) = Stream (DerivativeValue a)

  -- Every element is converted by the element type's 'autoVal'.
  autoVal :: Stream a -> Stream (DerivativeValue a)
  autoVal = mkStreamVal autoVal
-- | Derivative of a function returning a stream of 'RevDiff's.
--
-- The output elements are extracted with 'streamVal' inside
-- 'customValDerivative'.
streamValDerivative ::
  ( AutoDifferentiableArgument a,
    Multiplicative c,
    DerivativeCoarg a ~ CT (DerivativeArg a),
    DerivativeRoot a ~ CT (DerivativeArg a)
  ) =>
  (a -> Stream (RevDiff b c d)) ->
  DerivativeArg a ->
  Stream b
streamValDerivative = customValDerivative streamVal
-- | Split a differentiable 'Stream' into a stream of differentiable
-- elements. The result is built lazily and corecursively.
--
-- * The head element backpropagates through 'singleton'.
-- * The rest of the stream is split recursively. Its backpropagator
--   prepends a 'zero' with 'cons' before calling the original one, which
--   re-aligns tail cotangents with their positions in the full stream.
streamArg ::
  (Additive b) =>
  RevDiff a (FiniteSupportStream b) (Stream c) ->
  Stream (RevDiff a b c)
streamArg (MkRevDiff xs bp) = DS.Cons headDiff (streamArg tailDiff)
  where
    headDiff = MkRevDiff (DS.head xs) (bp . singleton)
    tailDiff = MkRevDiff (DS.tail xs) (bp . cons zero)
-- | Lift an element-level argument conversion to streams.
--
-- The stream 'RevDiff' is split with 'streamArg' and @f@ is applied to
-- each element.
mkStreamArg ::
  (Additive b) =>
  (RevDiff a b c -> d) ->
  RevDiff a (FiniteSupportStream b) (Stream c) ->
  Stream d
mkStreamArg f rd = f <$> streamArg rd
-- | Streams of differentiable arguments are handled elementwise; the
-- cotangent of a 'Stream' is a 'FiniteSupportStream'.
instance
  (AutoDifferentiableArgument a) =>
  AutoDifferentiableArgument (Stream a)
  where
  type DerivativeRoot (Stream a) = DerivativeRoot a
  type DerivativeCoarg (Stream a) = FiniteSupportStream (DerivativeCoarg a)
  type DerivativeArg (Stream a) = Stream (DerivativeArg a)

  -- Split the stream into elements and convert each with the element
  -- type's own 'autoArg'.
  autoArg ::
    RevDiff
      (DerivativeRoot a)
      (FiniteSupportStream (DerivativeCoarg a))
      (Stream (DerivativeArg a)) ->
    Stream a
  autoArg = mkStreamArg autoArg
-- | Derivative of a function of a 'Stream'.
--
-- The function receives the stream already split into differentiable
-- elements (via 'streamArg').
streamArgDerivative ::
  (AutoDifferentiableValue b, Additive (CT a)) =>
  (Stream (RevDiff' (Stream a) a) -> b) ->
  Stream a ->
  DerivativeValue b
streamArgDerivative = customArgDerivative streamArg
-- | Map a function over the stored elements of a 'FiniteSupportStream'
-- using 'unsafeMap'.
--
-- NOTE(review): the "unsafe" presumably means the result is not
-- re-normalised (e.g. @f zero@ need not be @zero@). Confirm against the
-- 'unsafeMap' definition.
mkFiniteSupportStreamVal ::
  (a -> b) ->
  FiniteSupportStream a ->
  FiniteSupportStream b
mkFiniteSupportStreamVal = unsafeMap
-- | Apply 'scalarVal' to every stored element of a finite-support stream
-- of 'RevDiff's.
finiteSupportStreamVal ::
  (Multiplicative b) =>
  FiniteSupportStream (RevDiff a b c) ->
  FiniteSupportStream a
finiteSupportStreamVal = mkFiniteSupportStreamVal scalarVal
-- | Finite-support streams of differentiable values are handled
-- elementwise.
instance
  (AutoDifferentiableValue a) =>
  AutoDifferentiableValue (FiniteSupportStream a)
  where
  type DerivativeValue (FiniteSupportStream a) = FiniteSupportStream (DerivativeValue a)

  -- Every stored element is converted by the element type's 'autoVal'.
  autoVal :: FiniteSupportStream a -> FiniteSupportStream (DerivativeValue a)
  autoVal = mkFiniteSupportStreamVal autoVal
-- | Derivative of a function returning a finite-support stream of
-- 'RevDiff's.
--
-- The output elements are extracted with 'finiteSupportStreamVal' inside
-- 'customValDerivative'.
finiteSupportStreamValDerivative ::
  ( AutoDifferentiableArgument a,
    Multiplicative c,
    DerivativeCoarg a ~ CT (DerivativeArg a),
    DerivativeRoot a ~ CT (DerivativeArg a)
  ) =>
  (a -> FiniteSupportStream (RevDiff b c d)) ->
  DerivativeArg a ->
  FiniteSupportStream b
finiteSupportStreamValDerivative = customValDerivative finiteSupportStreamVal
-- | Split a differentiable 'FiniteSupportStream' into a finite-support
-- stream of differentiable elements.
--
-- The element stored at index @i@ keeps its value. Its backpropagator
-- turns an element cotangent @y@ into an infinite 'Stream' that holds
-- @y@ at position @i@ and 'zero' at every other position, then applies
-- the original backpropagator.
finiteSupportStreamArg ::
  (Additive b) =>
  RevDiff a (Stream b) (FiniteSupportStream c) ->
  FiniteSupportStream (RevDiff a b c)
finiteSupportStreamArg (MkRevDiff (MkFiniteSupportStream xs) bp) =
  MkFiniteSupportStream (DV.imap component xs)
  where
    component idx x = MkRevDiff x (bp . basisStream idx)
    -- Infinite stream: @y@ at position @idx@, 'zero' everywhere else.
    basisStream idx y = fromPosition 0
      where
        fromPosition pos =
          DS.Cons (if idx == pos then y else zero) (fromPosition (pos + 1))
-- | Lift an element-level argument conversion to finite-support streams.
--
-- The container 'RevDiff' is split with 'finiteSupportStreamArg' and @f@
-- is mapped over the stored elements with 'unsafeMap'.
mkFiniteSupportStreamArg ::
  (Additive b) =>
  (RevDiff a b c -> d) ->
  RevDiff a (Stream b) (FiniteSupportStream c) ->
  FiniteSupportStream d
mkFiniteSupportStreamArg f rd = unsafeMap f (finiteSupportStreamArg rd)
-- | Finite-support streams of differentiable arguments are handled
-- elementwise; the cotangent of a 'FiniteSupportStream' is a 'Stream'.
instance
  (AutoDifferentiableArgument a) =>
  AutoDifferentiableArgument (FiniteSupportStream a)
  where
  type DerivativeRoot (FiniteSupportStream a) = DerivativeRoot a
  type DerivativeCoarg (FiniteSupportStream a) = Stream (DerivativeCoarg a)
  type DerivativeArg (FiniteSupportStream a) = FiniteSupportStream (DerivativeArg a)

  -- Split the container 'RevDiff' into per-element 'RevDiff's and convert
  -- each one with the element type's own 'autoArg', exactly as the 'Stream',
  -- 'BoxedVector' and 'Maybe' instances do. (This method used to be
  -- 'undefined', which crashed any auto-differentiation over a
  -- 'FiniteSupportStream' argument.) 'mkFiniteSupportStreamArg' needs
  -- @Additive (DerivativeCoarg a)@, the same constraint 'mkStreamArg' needs
  -- in the 'Stream' instance under an identical instance context.
  autoArg ::
    RevDiff
      (DerivativeRoot a)
      (Stream (DerivativeCoarg a))
      (FiniteSupportStream (DerivativeArg a)) ->
    FiniteSupportStream a
  autoArg = mkFiniteSupportStreamArg autoArg
-- | Derivative of a function of a 'FiniteSupportStream'.
--
-- The function receives the stream already split into differentiable
-- elements (via 'finiteSupportStreamArg').
finiteSupportStreamArgDerivative ::
  (AutoDifferentiableValue b, Additive (CT a)) =>
  (FiniteSupportStream (RevDiff' (FiniteSupportStream a) a) -> b) ->
  FiniteSupportStream a ->
  DerivativeValue b
finiteSupportStreamArgDerivative = customArgDerivative finiteSupportStreamArg
-- | Map a function over a 'Maybe' (the 'Functor' instance's 'fmap').
mkMaybeVal ::
  (a -> b) ->
  Maybe a ->
  Maybe b
mkMaybeVal = fmap
-- | Apply 'scalarVal' inside an optional 'RevDiff'.
maybeVal ::
  (Multiplicative b) =>
  Maybe (RevDiff a b c) ->
  Maybe a
maybeVal = mkMaybeVal scalarVal
-- | Optional differentiable values: 'Nothing' stays 'Nothing', and a
-- 'Just' payload is converted by the payload type's instance.
instance
  (AutoDifferentiableValue a) =>
  AutoDifferentiableValue (Maybe a)
  where
  type DerivativeValue (Maybe a) = Maybe (DerivativeValue a)

  autoVal :: Maybe a -> Maybe (DerivativeValue a)
  autoVal = mkMaybeVal autoVal
-- | Push an optional value out of a 'RevDiff'.
--
-- * If the carried value is 'Nothing', the result is 'Nothing'.
-- * Otherwise the payload gets a backpropagator that wraps its cotangent
--   in 'Just' before calling the original backpropagator.
maybeArg :: RevDiff a (Maybe b) (Maybe c) -> Maybe (RevDiff a b c)
maybeArg (MkRevDiff mx bp) = fmap (\x -> MkRevDiff x (bp . Just)) mx
-- | Lift a payload-level argument conversion to 'Maybe'.
--
-- The 'RevDiff' is split with 'maybeArg' and @f@ is applied to the
-- payload, if any.
mkMaybeArg ::
  (RevDiff a b c -> d) -> RevDiff a (Maybe b) (Maybe c) -> Maybe d
mkMaybeArg f rd = f <$> maybeArg rd
-- | Optional differentiable arguments: the cotangent of a @Maybe a@ is a
-- @Maybe@ of the payload's cotangent.
instance
  (AutoDifferentiableArgument a) =>
  AutoDifferentiableArgument (Maybe a)
  where
  type DerivativeRoot (Maybe a) = DerivativeRoot a
  type DerivativeCoarg (Maybe a) = Maybe (DerivativeCoarg a)
  type DerivativeArg (Maybe a) = Maybe (DerivativeArg a)

  -- Unwrap via 'maybeArg' and convert the payload with its own 'autoArg'.
  autoArg ::
    RevDiff
      (DerivativeRoot a)
      (Maybe (DerivativeCoarg a))
      (Maybe (DerivativeArg a)) ->
    Maybe a
  autoArg = mkMaybeArg autoArg
-- | Derivative of a function whose argument is a 'Maybe'.
--
-- The function receives the argument already split by 'maybeArg', i.e. as
-- 'Nothing' or as @'Just'@ a differentiable element. This is
-- 'customArgDerivative' specialised with 'maybeArg' as the argument adapter.
maybeArgDerivative ::
  (AutoDifferentiableValue b) =>
  (Maybe (RevDiff' (Maybe a) a) -> b) ->
  Maybe a ->
  DerivativeValue b
maybeArgDerivative = customArgDerivative maybeArg
-- | Derivative of a function whose result is a 'Maybe' of a differentiable
-- value.
--
-- This is 'customValDerivative' specialised with 'maybeVal', which converts
-- the 'Maybe' 'RevDiff' result into a 'Maybe' of its derivative.
-- The 'Multiplicative' constraint on @c@ is the one 'maybeVal' requires.
maybeValDerivative ::
  ( AutoDifferentiableArgument a,
    Multiplicative c,
    DerivativeCoarg a ~ CT (DerivativeArg a),
    DerivativeRoot a ~ CT (DerivativeArg a)
  ) =>
  (a -> Maybe (RevDiff b c d)) ->
  DerivativeArg a ->
  Maybe b
maybeValDerivative = customValDerivative maybeVal