{-# LANGUAGE AllowAmbiguousTypes #-} {-# LANGUAGE DataKinds #-} {-# LANGUAGE FlexibleContexts #-} {-# LANGUAGE FlexibleInstances #-} {-# LANGUAGE MultiParamTypeClasses #-} {-# LANGUAGE ScopedTypeVariables #-} {-# LANGUAGE TypeFamilyDependencies #-} {-# LANGUAGE NoImplicitPrelude #-} {-# OPTIONS -Wno-unused-imports #-} {-# OPTIONS_HADDOCK show-extensions #-} -- | -- Module : Numeric.InfBackprop.Tutorial -- Copyright : (C) 2023-2025 Alexey Tochin -- License : BSD3 (see the file LICENSE) -- Maintainer : Alexey Tochin <Alexey.Tochin@gmail.com> -- -- Tutorial for the -- [inf-backprop](https://hackage.haskell.org/package/inf-backprop) package. module Numeric.InfBackprop.Tutorial ( -- * Quick Start -- ** Basic Examples #quick-start-simple-derivative# -- $quick-start-simple-derivative -- ** Derivatives for Symbolic Expressions #quick-start-derivatives-for-symbolic-expressions# -- $quick-start-derivatives-for-symbolic-expressions -- ** Symbolic Expressions Visualization #squick-start-ymbolic-expressions-visualization# -- $quick-start-symbolic-expressions-visualization -- ** Gradient over a Two-Argument Function #quick-start-function-of-two-argument-functions# -- $quick-start-gradient-of-two-argument-functions -- ** Siskind-Pearlmutter Example #quick-start-siskind-pearlmutter-example# -- $quick-start-siskind-pearlmutter-example -- * How it Works -- $how-it-works -- ** The Backpropagation Derivative #how-it-works-backpropagation# -- $how-it-works-backpropagation -- ** Core Type: `RevDiff` #how-it-works-core-type-RevDiff# -- $how-it-works-core-type-RevDiff -- ** Functions Overloading #how-it-works-functions-overloading# -- $how-it-works-functions-overloading -- ** Tangent and Cotangent Spaces #how-it-works-tangent-space# -- $how-it-works-tangent-space -- * Differentiation for Structured Types #differentiation-for-structured-types# -- $differentiation-for-structured-types -- ** Structured Value Type #differentiation-for-structured-types-structured-value# -- 
$differentiation-for-structured-types-structured-value -- *** Basic Examples: Structured Value Types #differentiation-for-structured-types-structured-value-basic-examples# -- $differentiation-for-structured-types-structured-value-basic-examples -- *** Custom Derivative: Structured Value Types #differentiation-for-structured-types-structured-value-custom-derivative# -- $differentiation-for-structured-types-structured-value-custom-derivative -- *** How it Works: Structured Value Types #differentiation-for-structured-types-structured-value-how-it-works# -- $differentiation-for-structured-types-structured-value-how-it-works -- *** Defining Custom Value Type #differentiation-for-structured-types-structured-value-defining-custom-value-type# -- $differentiation-for-structured-types-structured-value-defining-custom-value-type -- ** Structured Argument Type #differentiation-for-structured-types-structured-argument-type# -- $differentiation-for-structured-types-structured-argument-type -- *** Basic Examples: Structured Argument Types #differentiation-for-structured-types-structured-argument-type-basic-examples# -- $differentiation-for-structured-types-structured-argument-type-basic-examples -- *** Custom Derivative: Structured Argument Types #differentiation-for-structured-types-structured-argument-type-custom-gradient# -- $differentiation-for-structured-types-structured-argument-type-custom-gradient -- *** How it Works: Structured Argument Types #differentiation-for-structured-types-structured-argument-type-how-it-works# -- $differentiation-for-structured-types-structured-argument-type-how-it-works -- *** Defining Custom Argument Type #differentiation-for-structured-types-structured-argument-type-defining-custom-type# -- $differentiation-for-structured-types-structured-argument-type-defining-custom-type -- * Performance Remarks #performance-remarks# -- $performance-remarks -- ** Subexpression Elimination #performance-remarks-subexpression-elimination# -- 
$performance-remarks-subexpression-elimination -- ** Forward Step Results Reusage #forward-step-results-reusage# -- $performance-remarks-forward-step-results-reusage -- * What is Next #what-is-next# -- $what-is-next ) where import Control.Category ((>>>)) import Control.Lens (set, view) import Data.FiniteSupportStream (FiniteSupportStream (toVector), head, singleton) import qualified Data.List as DL import Data.Proxy (Proxy (Proxy)) import Data.Stream (Stream, fromList, head, take) import qualified Data.Stream as DS import qualified Data.Stream as Data import Data.Tuple (fst, snd, uncurry) import Data.Type.Equality (type (~)) import qualified Data.Vector as DV import qualified Data.Vector.Fixed as DVF import Data.Vector.Generic.Sized (Vector, foldl') import qualified Data.Vector.Generic.Sized as DVGS import Debug.DiffExpr ( BinarySymbolicFunc, SymbolicFunc, TSE, TracedSimpleExpr, binarySymbolicFunc, unarySymbolicFunc, ) import Debug.SimpleExpr (SE, SimpleExpr, number, simplify, simplifyExpr, variable) import Debug.SimpleExpr.Utils.Traced (Traced (MkTraced), addTraceUnary) import Debug.Trace (trace) import GHC.Base (Float, Int, const, fmap, foldr, id, undefined, ($), (.), (<>)) import GHC.Integer (Integer) import GHC.Natural (Natural, minusNatural) import GHC.Show (Show (show)) import GHC.TypeNats (KnownNat) import NumHask ( Additive, Distributive, Divisive, ExpField, FromInteger (fromInteger), Multiplicative, Ring, Subtractive, TrigField, cos, cosh, exp, log, negate, one, sin, sinh, zero, (*), (+), (-), (/), ) import qualified NumHask as NH import Numeric.InfBackprop ( CT, Cotangent, Dual, RevDiff (MkRevDiff), RevDiff', Tangent, autoArg, autoVal, backprop, boxedVectorArg, boxedVectorArgDerivative, boxedVectorVal, constDiff, customArgDerivative, customArgValDerivative, fromProfunctors, fromVanLaarhoven, initDiff, mkBoxedVectorArg, mkBoxedVectorVal, mkStreamArg, mkStreamVal, mkTupleArg, mkTupleVal, scalarArg, scalarArgDerivative, scalarVal, scalarValDerivative, 
simpleDerivative, simpleDifferentiableFunc, simpleValueAndDerivative, stopDiff, streamArg, streamArgDerivative, toProfunctors, toVanLaarhoven, tupleArg, tupleArgDerivative, tupleVal, twoArgsDerivative, twoArgsDerivativeOverY, value, ) import Numeric.InfBackprop.Instances.NumHask () import Numeric.InfBackprop.Utils.SizedVector (BoxedVector) import Numeric.InfBackprop.Utils.Tuple (cross) -- $quick-start-simple-derivative -- -- >>> import GHC.Base (Float, fmap, ($)) -- -- In this section, we'll explore how automatic differentiation transforms ordinary -- mathematical functions into their derivatives, handling everything from basic -- polynomials to complex compositions without requiring manual derivation. -- -- We'll start by exploring automatic differentiation -- through the familiar square function: -- -- \[ -- f(x) := x^2 -- \] -- -- To work with our automatic differentiation system, we need operations that can -- handle not just numbers, but also the dual numbers that carry derivative information. -- The polymorphic multiplication operator from -- [numhask](https://hackage.haskell.org/package/numhask) -- provides this flexibility: -- -- >>> import NumHask (Multiplicative, (*), (+), log) -- -- The operator `(*)` has the following type signature: -- -- > (*) :: Multiplicative a => a -> a -> a -- -- This polymorphic operator allows us to write functions -- that work seamlessly with both regular -- numbers and the extended number types used in automatic differentiation. -- -- Let's define our square function and see it in action with regular `Float` values: -- -- >>> f x = x * x -- >>> fmap f [-3, -2, -1, 0, 1, 2, 3] :: [Float] -- [9.0,4.0,1.0,0.0,1.0,4.0,9.0] -- -- Now comes the remarkable part: computing the derivative automatically. The -- `simpleDerivative` function applies the chain rule automatically, transforming -- any function built from differentiable primitives into its derivative function. 
-- -- We know from calculus that the derivative of \(x^2\) should be: -- -- \[ -- f'(x) = 2 \cdot x -- \] -- -- Let's verify this using automatic differentiation: -- -- >>> import Numeric.InfBackprop (simpleDerivative) -- -- >>> f' = simpleDerivative f :: Float -> Float -- >>> fmap f' [-3, -2, -1, 0, 1, 2, 3] -- [-6.0,-4.0,-2.0,0.0,2.0,4.0,6.0] -- -- Notice how each result equals \(2x\), perfectly confirming our analytical derivative -- \(f'(x) = 2x\). The values \(-6, -4, -2, 0, 2, 4, 6\) correspond exactly to \(2\) times -- each input value. -- -- You must provide a type annotation (such as @Float -> Float@) for the derivative -- function. This ensures correct type inference by the compiler and specifies which -- numeric type you want to work with. -- -- Computing higher-order derivatives follows the same pattern. Since composing -- `simpleDerivative` twice gives us the second derivative, this demonstrates how -- automatic differentiation naturally handles higher-order derivatives through -- function composition. -- -- For our square function, the second derivative should be the constant \(2\): -- -- \[ -- f''(x) = 2 -- \] -- -- >>> f'' = simpleDerivative $ simpleDerivative f :: Float -> Float -- >>> fmap f'' [-3, -2, -1, 0, 1, 2, 3] -- [2.0,2.0,2.0,2.0,2.0,2.0,2.0] -- -- Perfect! The constant value @2.0@ across all inputs confirms that our second -- derivative is indeed the constant function \(f''(x) = 2\). -- -- This approach scales naturally to arbitrarily complex functions. 
Let's explore -- how automatic differentiation handles function composition by examining a more -- intricate example involving logarithms and polynomial terms: -- -- \[ -- g(x) := \log (x^2 + x^3) -- \] -- -- We'll use integer powers from the -- "Debug.SimpleExpr.Utils.Algebra" -- module: -- -- >>> import Debug.SimpleExpr.Utils.Algebra ((^)) -- -- >>> g x = log (x ^ 2 + x ^ 3) -- >>> g' = simpleDerivative g :: Float -> Float -- >>> g 1 :: Float -- 0.6931472 -- >>> g' 1 :: Float -- 2.5 -- -- We can verify this result analytically. The derivative of \(\log(x^2 + x^3)\) using -- the chain rule is: -- -- \[ -- g'(x) = \frac{d}{dx}[\log(x^2 + x^3)] = \frac{1}{x^2 + x^3} \cdot \frac{d}{dx}[x^2 + x^3] = \frac{2x + 3x^2}{x^2 + x^3} -- \] -- -- At \(x = 1\): -- -- \[ -- g'(1) = \frac{2 \cdot 1 + 3 \cdot 1^2}{1^2 + 1^3} = \frac{2 + 3}{1 + 1} = \frac{5}{2} = 2.5 -- \] -- -- The automatic differentiation result matches our analytical calculation perfectly, -- demonstrating how the system correctly applies the chain rule even for complex -- composite functions. -- $quick-start-derivatives-for-symbolic-expressions -- -- >>> import NumHask ((*), sin, cos) -- >>> import Debug.SimpleExpr (variable, simplify, SimpleExpr, SE) -- >>> import Numeric.InfBackprop (simpleDerivative) -- -- In many cases, it is more convenient to illustrate differentiation using -- symbolic expressions rather than concrete numeric values. -- Unlike numeric results, symbolic expressions allow us to -- inspect, transform, and optimize derivatives algebraically. -- -- We use the -- [simple-expr](https://hackage.haskell.org/package/simple-expr) -- package to construct and manipulate symbolic expressions. 
-- -- For example, consider the function: -- -- \[ -- f(x) := \sin(x^2) -- \] -- -- We can define it symbolically as follows: -- -- >>> import Debug.SimpleExpr.Utils.Algebra (AlgebraicPower, (^)) -- -- >>> x = variable "x" :: SimpleExpr -- >>> f x = sin (x ^ 2) -- >>> f x :: SimpleExpr -- sin(x^2) -- -- where 'SimpleExpr' is a symbolic expression type from -- [simple-expr](https://hackage.haskell.org/package/simple-expr). -- -- Computing the symbolic derivative -- -- \[ -- f'(x) = 2x \cdot \cos(x^2) -- \] -- -- is equally straightforward: -- -- >>> f' = simpleDerivative f -- >>> simplify $ f' x :: SimpleExpr -- (2*x)*cos(x^2) -- -- The `simplify` function from -- [simple-expr](https://hackage.haskell.org/package/simple-expr) -- removes redundant terms such as -- @*1@ and @+0@ -- and presents the result in a more readable algebraic form. -- -- Below, we use the @SE@ type synonym for @SimpleExpr@. -- -- Note that we continue to use generic definitions of functions like 'cos', -- as well as operators such as '(*)', from the -- [numhask](https://hackage.haskell.org/package/numhask) -- package. -- $quick-start-symbolic-expressions-visualization -- -- The -- [simple-expr](https://hackage.haskell.org/package/simple-expr) -- package includes visualization tools that help illustrate the process of symbolic -- differentiation. 
-- -- >>> import Debug.SimpleExpr (SimpleExpr, variable, simplify, plotExpr, plotDGraphPng) -- >>> import Debug.DiffExpr (unarySymbolicFunc) -- >>> import Numeric.InfBackprop (simpleDerivative) -- -- As a warm-up, consider a simple composition of two symbolic functions: -- -- \[ -- x \mapsto g(f(x)) -- \] -- -- This can be represented as: -- -- >>> x = variable "x" :: SimpleExpr -- >>> f = unarySymbolicFunc "f" :: SimpleExpr -> SimpleExpr -- >>> g = unarySymbolicFunc "g" :: SimpleExpr -> SimpleExpr -- >>> g (f x) :: SimpleExpr -- g(f(x)) -- -- You can visualize this expression using: -- -- > plotExpr $ g (f x) -- --  -- -- To visualize the first derivative of this composition, use: -- -- > plotExpr $ simplify $ simpleDerivative (g . f) x -- --  -- -- Visualizing the second derivative is just as easy: -- -- > plotExpr $ simplify $ simpleDerivative (simpleDerivative (g . f)) x -- --  -- $quick-start-gradient-of-two-argument-functions -- -- In this section, we focus on computing partial derivatives of functions -- with two arguments. -- -- As a starting point, consider a symbolic function @h@ that takes two arguments. -- (See -- [Derivatives for Symbolic Expressions](#g:quick-45-start-45-derivatives-45-for-45-symbolic-45-expressions) -- .) -- -- >>> x = variable "x" -- >>> y = variable "y" -- >>> h = binarySymbolicFunc "h" :: BinarySymbolicFunc a => a -> a -> a -- >>> h x y -- h(x,y) -- -- To compute partial derivatives, we use the `twoArgsDerivative` operator, -- which has a somewhat advanced type signature (see Section ??? for details). 
-- In practice, its usage is straightforward: -- -- >>> :{ -- h' :: SE -> SE -> (SE, SE) -- h' x y = simplify $ twoArgsDerivative h x y -- :} -- -- This returns a pair of partial derivatives: -- -- >>> h' x y -- (h'_1(x,y),h'_2(x,y)) -- -- We can also compute the second-order derivatives by nesting `twoArgsDerivative`: -- -- >>> :{ -- h'' :: SE -> SE -> ((SE, SE), (SE, SE)) -- h'' x y = simplify $ twoArgsDerivative (twoArgsDerivative h) x y -- :} -- -- >>> h'' x y -- ((h'_1'_1(x,y),h'_1'_2(x,y)),(h'_2'_1(x,y),h'_2'_2(x,y))) -- -- In this example, @h\'_1\'_2@ refers to the second partial derivative of @h@ -- with respect to @x@ and then @y@, and so on. -- -- Note that `twoArgsDerivative` is polymorphic over the return type of the function, -- but it works only for functions that take exactly /two arguments/. -- -- In contrast, the @customArgValDerivative autoArg@ operator -- (see -- [Structured Argument Type](#g:differentiation-45-for-45-structured-45-types-45-structured-45-argument-45-type)) -- can handle functions of arbitrary arity, but it is /not/ polymorphic -- over the return type of the function. -- $quick-start-siskind-pearlmutter-example -- -- We are now ready to revisit a classic example of higher-order automatic differentiation -- from the paper by Siskind and Pearlmutter: -- [Siskind & Pearlmutter (2005), "Perturbation Confusion and Referential Transparency"](https://engineering.purdue.edu/~qobi/papers/ifl2005.pdf) -- The expression of interest is: -- -- \[ -- \left. -- \frac{\partial}{\partial x} -- \left( -- x -- \left( -- \left. 
-- \frac{\partial}{\partial y} -- \left( -- x + y -- \right) -- \right|_{y=1} -- \right) -- \right) -- \right|_{x=1} -- = 1 -- \] -- -- To implement this, we begin by applying the partial derivative operator -- `twoArgsDerivativeOverY`, which differentiates a binary function -- with respect to its /second/ argument: -- -- For example, to compute -- -- \[ -- \frac{\partial}{\partial y} -- (x \cdot y) -- = x -- \] -- -- we write: -- -- >>> x = variable "x" -- >>> y = variable "y" -- >>> simplify $ twoArgsDerivativeOverY (*) x y :: SE -- x -- -- To evaluate this derivative at @y = 1@, we can use the `stopDiff` function, -- which performs symbolic substitution. For instance, -- -- > stopDiff $ number 1 -- -- effectively replaces @y@ with @1@ in the expression. -- -- So the expression -- -- \[ -- \left. -- \frac{\partial}{\partial y} -- (x \cdot y) -- \right|_{y=1} -- = x -- \] -- -- is implemented as: -- -- >>> simplify $ twoArgsDerivativeOverY (*) x 1 :: SE -- x -- -- Now we can wrap the entire expression in a derivative with respect to @x@: -- -- \[ -- \frac{d}{dx} -- \left. -- \frac{\partial}{\partial y} -- (x \cdot y) -- \right|_{y=1} -- = 1 -- \] -- -- This becomes: -- -- >>> :{ -- simplify $ -- (simpleDerivative $ \x_ -> twoArgsDerivativeOverY (*) x_ 1) x :: SE -- :} -- 1 -- -- The same logic works not just for symbolic expressions (`SE`), -- but also for concrete numeric types such as `Float`: -- -- >>> :{ -- simpleDerivative -- (\x -> x * twoArgsDerivativeOverY (+) x 1) -- (2024 :: Float) -- :} -- 1.0 -- -- Note: when working with numeric types like `Float`, -- the variable @x@ must be assigned a concrete `Float` value. -- $how-it-works-backpropagation -- -- To clarify the concept of backpropagation, consider the following example. 
-- -- Let @h@, @f@, and @g@ be three simple functions of type: -- -- \[ -- \mathbb{R} \rightarrow \mathbb{R} -- \] -- -- Now consider their composition: -- -- \[ -- x \mapsto g(f(h(x))) -- \] -- -- The first derivative of this composition, using the chain rule, is: -- -- \[ -- x \mapsto h'(x) \cdot f'(h(x)) \cdot g'(f(h(x))) -- \] -- -- This composition and its derivative can be illustrated -- using the following computation graph: -- --  -- -- The top path (from left to right) represents the /forward pass/, -- where values are computed through the function chain. -- The bottom path (from right to left) represents the /backward pass/, -- where derivatives are propagated. -- -- According to the backpropagation strategy, -- the derivative is computed in reverse order, as follows: -- -- 1. Evaluate @h(x)@. -- -- 2. Compute @f(h(x))@. -- -- 3. Compute @g(f(h(x)))@. -- -- 4. Compute the top derivative: @g'(f(h(x)))@. -- -- 5. Compute the next derivative: @f'(h(x))@. -- -- 6. Multiply: @g'(f(h(x))) * f'(h(x))@. -- -- 7. Compute the base derivative: @h'(x)@. -- -- 8. Multiply the result from step 6 by @h'(x)@. -- -- The product of these three derivatives gives the full derivative of the composition. -- -- Note: While it is possible to compute this derivative in forward order -- (i.e., from left to right) or -- any other order, -- the backpropagation strategy is more efficient -- for deep machine learning applications. -- Forward-mode differentiation is beyond the scope of this library. -- -- Generalizing this approach to longer function chains or functions from and to vector spaces -- is straightforward and follows the same principles. -- $how-it-works-core-type-RevDiff -- -- All the derivative computations from the previous example — -- specifically for the function @f@ — -- can be conceptually divided into two phases: -- -- 1. /Forward step/: Compute the value @f(h(x))@. -- -- 2. 
/Backward step/: -- Compute the derivative @f'(h(x))@, and multiply it by the previously -- obtained derivative @g'(f(h(x)))@. -- -- Note that the value @h(x)@ is used in both the forward and backward steps. -- -- The corresponding diagram can be visualized as: -- --  -- -- A differentiable function from type @a@ to type @b@ can be represented -- as a pair of functions: -- a /forward/ function and a /backward/ (derivative propagation) function: -- -- @ -- newtype DifferentiableFunc a b = MkDifferentiableFunc { -- forward :: a -> b, -- backward :: a -> CT b -> CT a -- } -- @ -- -- The meaning of the `CT` type family (short for /cotangent/) -- will be discussed in the -- [next section](#g:how-45-it-45-works-45-tangent-45-space). -- For now, you may assume @CT a ~ a@. -- -- From a categorical perspective, a @DifferentiableFunc@ behaves like a lens: -- -- > DifferentiableFunc a b ≈ Lens a (CT a) b (CT b) -- -- where @forward@ corresponds to `view`, and @backward@ corresponds to `set`. -- -- In principle, one could define a category of differentiable functions using lenses, -- replacing standard function composition `(.)` with lens composition `(% or >>>)`. -- However, this comes at a cost: we lose the ability to use familiar syntax such as -- function application @y = f x@. -- -- To preserve the familiar function syntax — e.g., keeping definitions like -- @sin :: a -> a@ -- and supporting ordinary function application — we follow an approach inspired by the -- [ad](https://hackage.haskell.org/package/ad) -- and -- [backprop](https://hackage.haskell.org/package/backprop) libraries. 
-- See also, for example, -- [this article](https://arxiv.org/pdf/1804.00746). -- -- Fixing a type @t@ (the type of the original input, whose cotangent @CT t@ is -- the final output of the backward pass), we can reinterpret -- a lens-like function -- -- > dFunc :: DifferentiableFunc a b -- -- as a transformation on differentiable values: -- -- > lensToMap :: DifferentiableFunc a b -> DifferentiableFunc t a -> DifferentiableFunc t b -- > lensToMap dFunc = (dFunc <<<) -- lens composition -- -- So, @lensToMap dFunc@ becomes a plain Haskell function: -- -- > DifferentiableFunc t a -> DifferentiableFunc t b -- -- Mathematically, this is a /hom-functor/ from -- the category of law-breaking lenses. -- -- Next, note that the type -- -- > DifferentiableFunc t a -- -- is isomorphic to: -- -- > t -> (a, CT a -> CT t) -- -- However, the actual value of type @t@ is not used in the composition with -- @DifferentiableFunc a b@. -- Therefore, we can drop the @t@ parameter and reduce the transformation: -- -- > DifferentiableFunc t a -> DifferentiableFunc t b -- -- to a plain function: -- -- > (a, CT a -> CT t) -> (b, CT b -> CT t) -- -- This motivates the definition of the core type: -- -- @ -- data RevDiff' t a = MkRevDiff -- { value :: a -- , backprop :: CT a -> CT t -- } -- @ -- -- For example, suppose we have a function: -- -- > f :: Float -> Float -- function f -- > f' :: Float -> Float -- derivative of f -- -- Then the differentiable version of @f@ can be defined as: -- -- @ -- differentiableF :: RevDiff' t Float -> RevDiff' t Float -- differentiableF (MkRevDiff x backprop) = -- MkRevDiff (f x) (\cx -> backprop ((f' x) * cx)) -- @ -- -- To evaluate the function at a point @x@: -- -- > y = value $ differentiableF (MkRevDiff x id) -- -- To evaluate its derivative at @x@: -- -- > y' = backprop (differentiableF (MkRevDiff x id)) 1.0 -- -- Here, the transition from type @a@ to @RevDiff' t a@ carries two parts: -- -- - `value`: the forward-pass result -- -- - `backprop`: a stack of backward-pass derivative transformations -- from type @CT a@ to 
@CT t@. -- -- In the example above, the value: -- -- > MkRevDiff x id :: RevDiff' Float Float -- -- represents the /initial value/ of the backpropagation stack. -- Applying @differentiableF@ -- results in: -- -- > MkRevDiff (f x) (\cx -> id ((f' x) * cx)) = MkRevDiff (f x) ((f' x) *) -- -- So applying `backprop` to @1.0 :: Float@ gives us the derivative value @f' x@. -- -- For convenience and flexibility, this package defines a three-parameter type: -- -- @ -- data RevDiff a b c = MkRevDiff {value :: c, backprop :: b -> a} -- @ -- -- This generalized structure allows us to separate the types involved in the -- forward pass (the value of type @c@) from those used in the backward pass -- (the gradient computation from @b@ to @a@). -- -- We also provide a specialized type alias for common use cases: -- -- > type RevDiff' a b = RevDiff (CT a) (CT b) b -- -- This three-parameter design enables powerful abstraction capabilities. -- In particular, it allows us to implement both profunctor and Van Laarhoven -- representations for differentiable functions, providing multiple ways to -- compose and manipulate automatic differentiation computations. -- -- These alternative representations can be accessed through the conversion -- functions `fromProfunctors`, `toProfunctors`, `fromVanLaarhoven`, and -- `toVanLaarhoven`, each offering different compositional properties suited -- to various use cases. -- -- Generalizing this to arbitrary compositions of differentiable functions -- is straightforward -- and follows the same backpropagation principle. -- $how-it-works-functions-overloading -- -- Our goal now is to make functions such as @sin@ and @(*)@ differentiable, -- while still being able to use them as ordinary functions — in particular, -- to apply them to arguments and compose them using '(.)'. -- -- To this end, we follow the approach used in the -- [numhask](https://hackage.haskell.org/package/numhask) package. 
-- In this package, functions like `sin` and `(*)` are defined as polymorphic methods -- of typeclasses. -- -- For instance, the function `sin` is a method of the typeclass: -- -- @ -- class TrigField a where -- ... -- sin :: a -> a -- @ -- -- Similarly, multiplication is defined via: -- -- @ -- class Multiplicative a where -- ... -- (*) :: a -> a -> a -- @ -- -- These typeclasses have instances, for example, for the type `Float`. -- Instances for `SE` are provided in the -- [simple-expr](https://hackage.haskell.org/package/simple-expr) -- package. -- -- To make `sin` and `(*)` differentiable in the backpropagation framework, -- it is enough to define instances for: -- -- > RevDiff' t Float -- -- These instances can be implemented as follows. -- (The type family 'CT' can be ignored for now; we may assume @CT a ~ a@ for simplicity.) -- -- @ -- instance Additive (CT t) => TrigField (RevDiff' t Float) where -- ... -- sin :: RevDiff' t Float -> RevDiff' t Float -- sin MkRevDiff {value = x, backprop = backpropX} = MkRevDiff { -- value = sin x, -- backprop = backpropX . ((cos x) *) -- } -- @ -- -- @ -- instance Additive (CT t) => Multiplicative (RevDiff' t Float) where -- ... -- (*) :: RevDiff' t Float -> RevDiff' t Float -> RevDiff' t Float -- MkRevDiff x backpropX * MkRevDiff y backpropY = -- MkRevDiff { -- value = x * y, -- backprop = backpropX . (y *) + backpropY . (x *) -- } -- @ -- -- To compute /second derivatives/, we can use a nested type like: -- -- > RevDiff' (RevDiff' Float Float) (RevDiff' Float Float) -- -- That is, the outer layer performs backpropagation through the inner derivative. -- Similarly, higher-order derivatives can be obtained by nesting @RevDiff'@ types further. -- -- These instances can also be generalized to any numeric type @a@, not just `Float`, -- allowing us to define /infinitely differentiable/ functions. -- $how-it-works-tangent-space -- -- In this section, we explain the purpose of the type family `CT` and how it is used. 
-- In most practical cases, we can assume @CT a ~ a@ and safely ignore it. -- -- One of the challenges in automatic differentiation is that -- the value type of a function and the value type of its derivative -- may not coincide, even when the input is scalar (for example, a real number). -- From a mathematical perspective, this corresponds to the need to work with -- tangent and cotangent bundles. -- -- For instance, if a scalar-valued function takes a vector as input, -- its derivative is also vector-valued. -- However, this correspondence does not hold in general. -- -- Consider the case where the input is an infinite sequence -- (such as an infinite list or stream). -- The derivative of a function on such inputs is a finite-length sequence -- (a sparse or finite-support vector; see the `FiniteSupportStream` type). -- Conversely, a function on finite-support streams has a derivative -- that is generally represented as an infinite stream. -- -- This distinction arises because the convolution of two infinite streams -- is not defined in general. -- On the other hand, every linear functional on streams can be represented -- as a convolution with a finite-length vector. -- Conversely, a convolution with a finite-length vector defines -- a linear functional on infinite streams. -- -- Similarly, any linear functional on all bounded finite-length vectors -- can be represented as a convolution with an infinite sequence. -- And conversely, convolution with an infinite sequence yields -- a linear functional on finite-length vectors. -- -- These distinctions are not just mathematical formalisms, -- but real practical constraints. -- In particular, the convolution of two infinite streams cannot be computed. -- Consequently, in this package the derivative of a function over `Stream` -- cannot have type `Stream`, and the derivative of a function over -- `FiniteSupportStream` cannot have type `FiniteSupportStream`; -- the two types are swapped instead. 
-- -- Another example comes from geometry. -- Consider a function defined on the surface of a unit sphere in 3D space. -- In this case, the derivative at each point must lie in the tangent plane -- to the sphere at that point — not just any 3D vector. -- Therefore, the derivative type differs from the function's argument type, -- even though both are represented by 3D vectors. -- -- More generally, in differential geometry, -- functions are defined on manifolds, -- and their derivatives take values in the cotangent bundle of the manifold. -- -- To model this distinction in Haskell, -- we introduce the type family `CT`, which stands for "cotangent type". -- For example: -- -- > CT Float = Float -- -- > CT (a, b) = (CT a, CT b) -- -- > CT (Vector a) = Vector (CT a) -- -- > CT (Stream a) = FiniteSupportStream (CT a) -- -- > CT (FiniteSupportStream a) = Stream (CT a) -- -- > CT (E2NormedVector a) = Vector (CT a) -- -- The type family `CT` is defined as a composition of two type families: -- `Tangent` and `Dual`: -- -- > CT a = Dual (Tangent a) -- -- The `Tangent` family describes the type of tangent vectors. -- For example: -- -- > Tangent Float = Float -- -- > Tangent (Stream a) = Stream (Tangent a) -- -- > Tangent (FiniteSupportStream a) = FiniteSupportStream (Tangent a) -- -- > Tangent (E2NormedVector a) = Vector (Tangent a) -- -- The `Dual` family encodes the dual space (linear functionals): -- -- > Dual Float = Float -- -- > Dual (Stream a) = FiniteSupportStream (Dual a) -- -- > Dual (FiniteSupportStream a) = Stream (Dual a) -- -- > Dual (E2NormedVector a) = Undefined -- -- In order to support differentiation over a new type that is not already -- handled by this package, one needs to define appropriate instances -- for both `Tangent` and `Dual` for that type. -- $differentiation-for-structured-types -- -- This library supports the differentiation of functions of type -- @ f :: a -> b @ -- for potentially any types @a@ and @b@. 
-- Thus, the derivative operator has the type: -- -- > (a -> b) -> (a -> c) -- -- The argument type @a@ is the same for both the original function -- @f :: a -> b@ and its derivative -- @f' :: a -> c@. -- However, the result type @c@ of the derivative -- depends in a non-trivial way on both @a@ and @b@. -- -- For example, the derivative of a vector-valued function of a tuple -- is a vector of tuples: -- -- >>> import Numeric.InfBackprop.Utils.SizedVector (BoxedVector) -- -- @ -- f :: (Float, Float) -> BoxedVector 3 Float -- f' :: (Float, Float) -> BoxedVector 3 (Float, Float) -- @ -- -- To illustrate the approach, consider a representative example: -- a function from a tuple to a 3D vector. -- -- >>> :{ -- sphericToVec :: (TrigField a) => (a, a) -> BoxedVector 3 a -- sphericToVec (theta, phi) = -- DVGS.fromTuple (cos theta * cos phi, cos theta * sin phi, sin theta) -- :} -- -- We will use the `customArgValDerivative` operator, which takes three arguments: -- -- 1. The argument structure descriptor — in this case, `tupleArg`, -- which is used for the @(a, a)@ input. -- -- 2. The value structure descriptor — in this case, `boxedVectorVal`, -- used for output type @BoxedVector 3 _@. -- -- 3. The function to differentiate — in this case, @sphericToVec@. -- -- The derivative is then defined as: -- -- >>> import Debug.SimpleExpr.Utils.Algebra (IntegerPower) -- >>> :{ -- sphericToVec'V1 :: (TrigField a, ExpField a, IntegerPower a, a ~ CT a) => -- (a, a) -> BoxedVector 3 (a, a) -- sphericToVec'V1 = customArgValDerivative tupleArg boxedVectorVal sphericToVec -- :} -- -- The type family `CT` and its meaning are explained in section -- [Tangent and cotangent spaces](#g:how-45-it-45-works-45-tangent-45-space). -- For now, it can be ignored. -- The types and definitions of `tupleArg` and `boxedVectorVal`, -- as well as how to construct them for other types, -- will be covered in the following sections. 
-- -- Alternatively, the argument and value structure can be inferred automatically -- using `autoArg` and `autoVal`: -- -- >>> :{ -- sphericToVec'V2 :: (TrigField a, ExpField a, IntegerPower a, a ~ CT a) => -- (a, a) -> BoxedVector 3 (a, a) -- sphericToVec'V2 = customArgValDerivative autoArg boxedVectorVal sphericToVec -- :} -- -- >>> :{ -- sphericToVec'V3 :: (TrigField a, ExpField a, IntegerPower a, a ~ CT a) => -- (a, a) -> BoxedVector 3 (a, a) -- sphericToVec'V3 = customArgValDerivative tupleArg autoVal sphericToVec -- :} -- -- Automatically deriving both the argument and value types -- is often problematic due to type inference limitations in Haskell. -- -- In summary, there are three common approaches to managing types -- in the derivative operator for a function @f :: a -> b@: -- -- 1. Define a derivative operator specialized for specific types @a@ and @b@. -- -- 2. Define a derivative operator that is polymorphic in the result type @b@, -- but has a fixed argument type @a@. -- See section -- [Structured Value Type](#g:differentiation-45-for-45-structured-45-types-45-structured-45-value). -- -- 3. Define a derivative operator that is polymorphic in the argument type @a@, -- but has a fixed result type @b@. -- See `scalarValDerivative` in the subsection -- [Structured Argument Type](#g:differentiation-45-for-45-structured-45-types-45-structured-45-argument-45-type). -- $differentiation-for-structured-types-structured-value -- -- This section explains how to compute derivatives of functions whose values -- have structured types (e.g., tuples, vectors, streams, or nested combinations). -- -- We begin with -- [basic examples](#g:differentiation-45-for-45-structured-45-types-45-structured-45-value-45-basic-45-examples) -- to demonstrate how derivatives work for common structured types. 
-- -- Then, in -- [custom derivative operators and value structure descriptors](#g:differentiation-45-for-45-structured-45-types-45-structured-45-value-45-custom-45-derivative), -- we explain how to define derivative operators for any structured value type using -- custom descriptors. -- -- In -- [how it works: structured value types](#g:differentiation-45-for-45-structured-45-types-45-structured-45-value-45-how-45-it-45-works), -- we delve into the type signatures and the underlying idea behind value type descriptors. -- -- Finally, in -- [defining custom differentiable value types](#g:differentiation-45-for-45-structured-45-types-45-structured-45-value-45-defining-45-custom-45-value-45-type), -- we outline how to define your own differentiable types—beyond the scope -- of the built-in descriptors provided by this package. -- $differentiation-for-structured-types-structured-value-basic-examples -- -- ==== Tuple-valued function -- As a first example, we define a symbolic function @f@ of one variable -- that returns a tuple of two values. -- -- >>> :{ -- f :: TrigField a => a -> (a, a) -- f t = (cos t, sin t) -- :} -- -- Define a symbolic variable @t@, as shown in the section -- [Derivatives for Symbolic Expressions](#g:quick-45-start-45-derivatives-45-for-45-symbolic-45-expressions). -- -- >>> t = variable "t" -- >>> f t -- (cos(t),sin(t)) -- -- The simplest way to take the derivative is to use the `scalarArgDerivative` operator, -- which is polymorphic over the function's value type. -- It is a special case of the `customArgValDerivative` operator introduced -- at the beginning of the section -- [Differentiation for Structured Types](#g:differentiation-45-for-45-structured-45-types). -- It is defined as: -- -- > scalarArgDerivative = customArgValDerivative scalarArg autoVal -- -- The first argument `scalarArg` indicates that the function's argument type is scalar. -- The second argument `autoVal` tells the system to infer the value type automatically.
-- -- The general type signature of `scalarArgDerivative` is discussed in a later section. -- In this case, it simplifies to: -- -- @ -- scalarArgDerivative :: Multiplicative (CT a) => -- (RevDiff a a -> (RevDiff a a, RevDiff a a)) -> -- a -> -- (CT a, CT a) -- @ -- -- We can now compute derivatives as follows: -- -- >>> f' = simplify . scalarArgDerivative f :: SE -> (SE, SE) -- >>> f' t -- (-(sin(t)),cos(t)) -- -- >>> f'' = simplify . scalarArgDerivative (scalarArgDerivative f) :: SE -> (SE, SE) -- >>> f'' t -- (-(cos(t)),-(sin(t))) -- -- >>> (scalarArgDerivative (scalarArgDerivative f)) t :: (SE, SE) -- (-(cos(t)*(1*1))+0,(-(sin(t))*(1*1))+0) -- -- >>> import Debug.SimpleExpr.Utils.Algebra ((^)) -- -- >>> (scalarArgDerivative (scalarArgDerivative (scalarArgDerivative f))) (0.0 :: Float) :: (Float, Float) -- (0.0,-1.0) -- -- >>> (scalarArgDerivative exp) t :: SE -- exp(t)*1 -- -- >>> (scalarArgDerivative (scalarArgDerivative (exp))) t :: SE -- (exp(t)*(1*1))+0 -- -- >>> f''' = simplify . scalarArgDerivative (scalarArgDerivative (scalarArgDerivative f)) :: SE -> (SE, SE) -- >>> f''' t -- (sin(t),-(cos(t))) -- -- Note that all the derivative functions have the same argument type -- as the original function, but their value types differ. -- Here the polymorphic property of the `scalarArgDerivative` operator -- comes into play, allowing us to nest it for higher-order derivatives without explicit -- type annotations for the intermediate results. -- -- ==== Vector-valued function -- In the next example, we take the derivative of a vector-valued symbolic function @v@ -- using boxed vectors from the -- [vector-sized](https://hackage.haskell.org/package/vector-sized) library.
-- -- >>> import Numeric.InfBackprop.Utils.SizedVector (BoxedVector) -- -- >>> :{ -- v :: SymbolicFunc a => a -> BoxedVector 3 a -- v t = DVGS.fromTuple ( -- unarySymbolicFunc "v_x" t, -- unarySymbolicFunc "v_y" t, -- unarySymbolicFunc "v_z" t -- ) -- :} -- -- >>> v t -- Vector [v_x(t),v_y(t),v_z(t)] -- -- >>> v' = simplify . scalarArgDerivative v :: SE -> BoxedVector 3 SE -- >>> v' t -- Vector [v_x'(t),v_y'(t),v_z'(t)] -- -- ==== Stream-valued function -- Functions returning other data types, including lazy types such as streams from the -- [stream](https://hackage.haskell.org/package/stream) -- library, -- can also be differentiated. -- -- >>> :{ -- s :: SymbolicFunc a => a -> Stream a -- s t = fromList [unarySymbolicFunc ("s_" <> show n) t | n <- [0..]] -- :} -- -- >>> take 5 (s t) -- [s_0(t),s_1(t),s_2(t),s_3(t),s_4(t)] -- -- >>> :{ -- s' :: SE -> Stream SE -- s' = simplify . scalarArgDerivative s -- :} -- -- >>> take 5 (s' t) -- [s_0'(t),s_1'(t),s_2'(t),s_3'(t),s_4'(t)] -- -- ==== Nested structured-valued function -- We can also differentiate functions returning values in nested types. For example: -- -- >>> :{ -- g :: SymbolicFunc a => a -> (BoxedVector 3 a, Stream a) -- g t = (v t, s t) -- :} -- -- This function has the type @a -> (BoxedVector 3 a, Stream a)@. -- Automatic differentiation remains straightforward: -- -- >>> :{ -- g' :: SE -> (BoxedVector 3 SE, Stream SE) -- g' = simplify . scalarArgDerivative g -- :} -- -- >>> fst $ g' t -- Vector [v_x'(t),v_y'(t),v_z'(t)] -- -- >>> take 5 $ snd $ g' t -- [s_0'(t),s_1'(t),s_2'(t),s_3'(t),s_4'(t)] -- $differentiation-for-structured-types-structured-value-custom-derivative -- -- Instead of the polymorphic `scalarArgDerivative` operator, -- which is defined as -- -- > scalarArgDerivative = customArgValDerivative scalarArg autoVal -- -- we can use a more specialized version tailored to the expected value type.
-- These customized derivatives still use `customArgValDerivative`, but with a specific -- value structure descriptor instead of `autoVal`. -- -- ==== Tuple-valued function -- -- Consider again the example from the previous subsection: -- -- >>> scalarTupleDerivative = customArgValDerivative scalarArg tupleVal -- -- Here, `scalarArg` indicates that -- the input of the function being differentiated -- is a scalar value, and -- `tupleVal` indicates that the output of the function being differentiated -- is a tuple of scalar values. -- -- >>> :{ -- t :: SE -- t = variable "t" -- f :: TrigField a => a -> (a, a) -- f t = (cos t, sin t) -- f' :: SE -> (SE, SE) -- f' = simplify . scalarTupleDerivative f -- :} -- -- >>> f' t -- (-(sin(t)),cos(t)) -- -- ==== Vector-valued function -- -- Similarly, we can define a derivative operator for a vector-valued function @v@: -- -- >>> scalarBoxedVectorDerivative = customArgValDerivative scalarArg boxedVectorVal -- -- Here, `boxedVectorVal` declares that the function returns a boxed vector -- of scalar values. -- -- ==== Nested structured output function -- -- In the last example from the previous subsection: -- -- >>> import Numeric.InfBackprop.Utils.SizedVector (BoxedVector) -- -- >>> :{ -- g :: SymbolicFunc a => a -> (BoxedVector 3 a, Stream a) -- g = undefined -- :} -- -- the value type of @g@ is more sophisticated, so we must construct -- a custom value structure descriptor manually: -- -- >>> tupleBoxedVectorStreamVal = mkTupleVal (mkBoxedVectorVal scalarVal) (mkStreamVal scalarVal) -- >>> scalarTupleBoxedVectorStreamDerivative = customArgValDerivative scalarArg tupleBoxedVectorStreamVal -- >>> _ = scalarTupleBoxedVectorStreamDerivative g :: SE -> (BoxedVector 3 SE, Stream SE) -- -- Here: -- -- - 'mkTupleVal' constructs a value descriptor for a tuple, -- - 'mkBoxedVectorVal' constructs a value descriptor for a boxed vector, -- - 'mkStreamVal' constructs a value descriptor for a stream, -- - and 'scalarVal' denotes the scalar leaf type.
-- -- In general, these building blocks combine to define custom value descriptors. -- For example: -- -- @ -- tupleVal = mkTupleVal scalarVal scalarVal -- boxedVectorVal = mkBoxedVectorVal scalarVal -- streamVal = mkStreamVal scalarVal -- @ -- -- And for a scalar-valued function, we simply use: -- -- > scalarScalarDerivative = customArgValDerivative scalarArg scalarVal -- $differentiation-for-structured-types-structured-value-how-it-works -- -- This section explains how the general backpropagation mechanism operates -- at the level of function result (value) types. -- -- ==== Derivative Operator Type Signature -- -- To differentiate a scalar-to-scalar function @f :: a -> b@, we use its -- differentiable form: -- -- > f :: RevDiff a a -> RevDiff a b -- -- (see the section -- [How it works: core type `RevDiff`](#g:how-45-it-45-works-45-core-45-type-45-RevDiff)). -- -- For functions returning structured values, we generalize this to: -- -- > f :: RevDiff a a -> c -- -- where @c@ is a structured result built from @RevDiff a b@ values. -- We then use a /value structure descriptor/ of type @c -> d@ -- to extract the final derivative result @d@ from the structure @c@. -- -- The resulting derivative operator has the following type: -- -- @ -- scalarCustomArgDerivative :: -- (c -> d) -> -- how to extract the final output -- (RevDiff a a -> c) -> -- the differentiable function -- (a -> d) -- scalar input to final output -- scalarCustomArgDerivative = customArgValDerivative scalarArg -- @ -- -- Here, the first argument of type @c -> d@ transforms the intermediate structured result -- into the final derivative value. -- -- In fact, @scalarCustomArgDerivative@ is simply function composition: -- -- > scalarCustomArgDerivative = (.) -- -- ==== Value Descriptor Examples -- -- Common value structure descriptors include: -- -- 1. /Scalar value/ -- -- > scalarVal :: Multiplicative (CT b) => RevDiff a b -> CT a -- -- Converts a single differentiable value into a scalar result.
-- -- 2. /Tuple/ -- -- > tupleVal :: -- > (Multiplicative (CT b0), Multiplicative (CT b1)) => -- > (RevDiff a0 b0, RevDiff a1 b1) -> (CT a0, CT a1) -- -- Converts a tuple of differentiable values into a tuple of scalars. -- -- 3. /Boxed Vector/ -- -- > boxedVectorVal :: -- > Multiplicative (CT b) => -- > BoxedVector n (RevDiff a b) -> BoxedVector n (CT a) -- -- Converts a boxed vector of differentiable values into a boxed vector of scalars. -- -- 4. /Stream/ -- -- > streamVal :: -- > Multiplicative (CT b) => -- > Stream (RevDiff a b) -> Stream (CT a) -- -- Converts a stream of differentiable values into a stream of scalars. -- -- 5. /Nested structure/ -- -- For example, a function returning a tuple of a boxed vector and a stream: -- -- > tupleBoxedVectorStreamVal :: -- > (Multiplicative (CT b0), Multiplicative (CT b1)) => -- > (BoxedVector n (RevDiff a0 b0), Stream (RevDiff a1 b1)) -> -- > (BoxedVector n (CT a0), Stream (CT a1)) -- -- ==== Constructing Value Descriptors -- -- You can construct value descriptors using standard higher-order functions: -- -- @ -- mkTupleVal :: (a0 -> b0) -> (a1 -> b1) -> (a0, a1) -> (b0, b1) -- mkTupleVal = cross -- -- mkBoxedVectorVal :: (a -> b) -> BoxedVector n a -> BoxedVector n b -- mkBoxedVectorVal = fmap -- -- mkStreamVal :: (a -> b) -> Stream a -> Stream b -- mkStreamVal = fmap -- @ -- -- This means that to define a derivative for any custom structured type @MyType a@, -- you only need to implement: -- -- > myTypeVal :: Multiplicative (CT b) => MyType (RevDiff a b) -> MyType (CT a) -- -- A typical approach is to define a mapping function: -- -- > mkMyTypeVal :: (a -> b) -> MyType a -> MyType b -- -- and then obtain the value descriptor by applying it to `scalarVal`: -- -- > myTypeVal = mkMyTypeVal scalarVal -- -- This approach allows you to differentiate functions returning arbitrarily -- nested combinations of types, as we did above with tuple @(,)@, -- `BoxedVector`@ n@, and `Stream`.
-- $differentiation-for-structured-types-structured-value-defining-custom-value-type -- -- ==== Making Custom Scalar Type Differentiable -- -- To make a scalar type @a@ differentiable, it is necessary and sufficient to: -- -- 1. Define the type families `Tangent` for @a@ and `Dual` for @Tangent a@ -- (see [Tangent and Cotangent Spaces](#g:how-45-it-45-works-45-tangent-45-space)). -- -- 2. Ensure that the type -- -- > type CT a = Dual (Tangent a) -- -- is an instance of `Multiplicative`. -- -- The second condition is required to initialize the backpropagation process -- with the value `one`. -- -- ==== Making Custom Type Constructors Differentiable -- -- To define derivatives over a custom type constructor @f :: Type -> Type@, -- the recommended approach is: -- -- 1. Define the /value descriptor constructor/: -- -- > mkFVal :: (a -> b) -> f a -> f b -- -- In most cases, this is just `fmap`, or an optimized equivalent (see previous section). -- -- 2. Provide an instance of the @AutoDifferentiableValue@ class: -- -- > instance (AutoDifferentiableValue a b) => -- > AutoDifferentiableValue (f a) (f b) where -- > autoVal :: f a -> f b -- > autoVal = mkFVal autoVal -- -- This recursively applies `autoVal` within the structure of @f a@. -- -- For more sophisticated custom types (e.g. two-parameter type constructors such as -- @g :: Type -> Type -> Type@), refer to the implementation of the instance for -- tuples @(,)@ in @AutoDifferentiableValue@ for guidance. -- $differentiation-for-structured-types-structured-argument-type -- -- In this section, we consider how to differentiate a function -- with a structured or nontrivial argument type. -- -- The simplest way to compute the derivative of a scalar-valued function (i.e. gradient) -- is by using the `scalarValDerivative` operator. -- This operator is polymorphic over the function’s argument type, -- but it is restricted to functions that return scalar values.
-- -- In terms of the more general `customArgValDerivative` operator -- (see [Differentiation for Structured Types](#g:differentiation-45-for-45-structured-45-types)), -- `scalarValDerivative` is equivalent to: -- -- > scalarValDerivative = customArgValDerivative autoArg scalarVal -- -- Here, the first argument `autoArg` indicates that the argument type -- (i.e. the structure of the input) is inferred automatically. -- -- The second argument `scalarVal` specifies that the return value of the function -- must be a scalar. -- $differentiation-for-structured-types-structured-argument-type-basic-examples -- -- ==== Gradient over the Euclidean Norm of a Vector -- Our first example involves a function over a sized boxed vector, -- `BoxedVector`. We define the squared Euclidean norm of a 3-dimensional vector: -- -- >>> import Debug.SimpleExpr.Utils.Algebra (IntegerPower, (^), MultiplicativeAction) -- >>> import Numeric.InfBackprop.Utils.SizedVector (BoxedVector) -- -- >>> :{ -- eNorm2 :: (IntegerPower a, Additive a) => BoxedVector 3 a -> a -- eNorm2 x = foldl' (+) zero (fmap (^2) x) -- :} -- -- This is not the most efficient way to define a function on large vectors, -- but for this example, we focus on type signatures and type inference -- rather than performance.
-- -- The gradient of @eNorm2@ can be computed as: -- -- >>> :{ -- eNorm2' :: ( -- IntegerPower a, -- MultiplicativeAction Integer a, -- Distributive a, -- CT a ~ a -- ) => BoxedVector 3 a -> BoxedVector 3 a -- eNorm2' = scalarValDerivative eNorm2 -- :} -- -- As usual, `scalarValDerivative` can be applied to symbolic expressions, -- such as values of type `SE`: -- -- >>> x = variable "x" -- >>> y = variable "y" -- >>> z = variable "z" -- >>> r = DVGS.fromTuple (x, y, z) :: BoxedVector 3 SE -- >>> simplify $ eNorm2' r :: BoxedVector 3 SE -- Vector [2*x,2*y,2*z] -- -- It also works with numeric types like `Float`: -- -- >>> v = DVGS.fromTuple (1, 2, 3) :: BoxedVector 3 Float -- >>> eNorm2' v :: BoxedVector 3 Float -- Vector [2.0,4.0,6.0] -- -- ==== Gradient over a Stream -- The `Stream` type can also be used as an argument. -- However, note that the result of the gradient is not a `Stream`, -- but rather a bounded stream: `FiniteSupportStream`. -- See -- [Tangent and Cotangent Spaces](#g:how-45-it-45-works-45-tangent-45-space) -- for a brief explanation. -- -- Define a formal series -- -- \[ -- s = s_0, s_1, s_2, s_3, \ldots -- \] -- -- as: -- -- >>> s = fromList [variable ("s_" <> show n) | n <- [0 :: Int ..]] :: Stream SE -- -- Next, define a function that sums the first four elements of the stream: -- -- \[ -- s \mapsto s_0 + s_1 + s_2 + s_3 -- \] -- -- >>> take4Sum = NH.sum . take 4 :: Additive a => Stream a -> a -- >>> simplify $ take4Sum s :: SE -- s_0+(s_1+(s_2+s_3)) -- -- The gradient of this function can be defined as: -- -- >>> :{ -- take4Sum' :: (Distributive a, Distributive (CT a)) => -- Stream a -> FiniteSupportStream (CT a) -- take4Sum' = scalarValDerivative take4Sum -- :} -- -- >>> simplify $ take4Sum' s -- [1,1,1,1,0,0,0,... -- -- The result is a finite support stream of the form: -- -- \[ -- 1, 1, 1, 1, 0, 0, 0, \ldots -- \] -- -- as expected. 
-- -- ==== Gradient over Nested Structured Types -- -- The `scalarValDerivative` operator can also handle more complex input types. -- For example, consider a function @g@ that takes both a 3-vector and a stream: -- -- >>> :{ -- g :: (IntegerPower a, Distributive a) => -- (BoxedVector 3 a, Stream a) -> a -- g (v, s) = eNorm2 v + take4Sum s -- :} -- -- Its gradient can be computed as: -- -- >>> :{ -- g' :: (IntegerPower a, MultiplicativeAction Integer a, Distributive a, CT a ~ a) => -- (BoxedVector 3 a, Stream a) -> (BoxedVector 3 a, FiniteSupportStream a) -- g' = scalarValDerivative g -- :} -- -- Evaluating the gradient at @(r, s)@ gives: -- -- >>> simplify $ fst $ g' (r, s) :: BoxedVector 3 SE -- Vector [2*x,2*y,2*z] -- -- >>> simplify $ snd $ g' (r, s) :: FiniteSupportStream SE -- [1,1,1,1,0,0,0,... -- -- as expected. -- $differentiation-for-structured-types-structured-argument-type-custom-gradient -- -- The `scalarValDerivative` operator from the previous section is polymorphic over -- the argument type, but it works only for scalar-valued functions. -- -- In this section, we consider how to /fix the argument type/ while keeping the -- value type polymorphic. This is especially useful when computing second or higher-order -- derivatives. -- -- ==== Derivatives over a Tuple of Scalars -- We begin with a function over a tuple of two scalars, which is equivalent -- to a function of two arguments.
-- -- As an example, consider the product of symbolic functions @f@ and @g@ applied -- to separate arguments: -- -- >>> :{ -- x = variable "x" -- y = variable "y" -- f :: SymbolicFunc a => a -> a -- f = unarySymbolicFunc "f" -- g :: SymbolicFunc a => a -> a -- g = unarySymbolicFunc "g" -- h :: (SymbolicFunc a, Multiplicative a) => (a, a) -> a -- h (x, y) = f x * g y -- :} -- -- Evaluating @h@ at @(x, y)@ gives: -- -- >>> h (x, y) :: SE -- f(x)*g(y) -- -- First, consider the derivative operator: -- -- >>> tupleScalarDerivative = customArgValDerivative tupleArg scalarVal -- -- It can be applied as follows: -- -- >>> h' = simplify . tupleScalarDerivative h :: (SE, SE) -> (SE, SE) -- >>> h' (x, y) -- (f'(x)*g(y),g'(y)*f(x)) -- -- However, we cannot use @tupleScalarDerivative@ to compute the second derivative of @h@, -- because it is restricted to scalar-valued functions. It is not polymorphic in the -- value type, unlike `tupleArgDerivative`. -- -- To define a version suitable for higher-order derivatives, we define: -- -- > tupleArgDerivative = customArgValDerivative tupleArg autoVal -- -- This operator is practically equivalent to `twoArgsDerivative` from the section -- [Gradient over a Two-Argument Function](#g:quick-45-start-45-function-45-of-45-two-45-argument-45-functions), -- except that it works on uncurried arguments. -- -- We can now compute the derivative of @h@ as: -- -- >>> :{ -- h' :: (SE, SE) -> (SE, SE) -- h' = simplify . tupleArgDerivative h -- :} -- -- >>> h' (x, y) -- (f'(x)*g(y),g'(y)*f(x)) -- -- Thanks to the polymorphism of `tupleArgDerivative`, we can compute higher-order -- derivatives of @h@: -- -- Second derivative: -- -- >>> :{ -- h'' :: (SE, SE) -> ((SE, SE), (SE, SE)) -- h'' = simplify . tupleArgDerivative (tupleArgDerivative h) -- :} -- -- >>> h'' (x, y) -- ((f''(x)*g(y),g'(y)*f'(x)),(f'(x)*g'(y),g''(y)*f(x))) -- -- Third derivative: -- -- >>> :{ -- h''' :: (SE, SE) -> (((SE, SE), (SE, SE)), ((SE, SE), (SE, SE))) -- h''' = simplify . 
tupleArgDerivative (tupleArgDerivative (tupleArgDerivative h)) -- :} -- -- >>> h''' (x, y) -- (((f'''(x)*g(y),g'(y)*f''(x)),(f''(x)*g'(y),g''(y)*f'(x))),((f''(x)*g'(y),g''(y)*f'(x)),(f'(x)*g''(y),g'''(y)*f(x)))) -- -- ==== Derivatives over Boxed Vectors -- The next example demonstrates derivatives over boxed vectors. -- -- > boxedVectorArgDerivative = customArgValDerivative boxedVectorArg autoVal -- -- Recall the function @eNorm2@, which computes the squared Euclidean norm -- of a 3-dimensional vector: -- -- >>> import Numeric.InfBackprop.Utils.SizedVector (BoxedVector) -- -- >>> :{ -- eNorm2 :: Distributive a => BoxedVector 3 a -> a -- eNorm2 x = foldl' (+) zero (x * x) -- :} -- -- We apply `boxedVectorArgDerivative` as follows: -- -- >>> v = DVGS.fromTuple (1, 2, 3) :: BoxedVector 3 Float -- >>> boxedVectorArgDerivative eNorm2 v :: BoxedVector 3 Float -- Vector [2.0,4.0,6.0] -- -- The second derivative gives the Hessian matrix represented here -- as a boxed Vector of boxed Vectors: -- -- >>> boxedVectorArgDerivative (boxedVectorArgDerivative eNorm2) v :: BoxedVector 3 (BoxedVector 3 Float) -- Vector [Vector [2.0,0.0,0.0],Vector [0.0,2.0,0.0],Vector [0.0,0.0,2.0]] -- -- The third derivative is a rank-3 tensor filled with zeros: -- -- >>> boxedVectorArgDerivative (boxedVectorArgDerivative (boxedVectorArgDerivative eNorm2)) v :: BoxedVector 3 (BoxedVector 3 (BoxedVector 3 Float)) -- Vector [Vector [Vector [0.0,0.0,0.0],Vector [0.0,0.0,0.0],Vector [0.0,0.0,0.0]],Vector [Vector [0.0,0.0,0.0],Vector [0.0,0.0,0.0],Vector [0.0,0.0,0.0]],Vector [Vector [0.0,0.0,0.0],Vector [0.0,0.0,0.0],Vector [0.0,0.0,0.0]]] -- $differentiation-for-structured-types-structured-argument-type-how-it-works -- -- In order to compute a derivative, we need a function with the following signature: -- -- > f :: RevDiff a a -> RevDiff a b -- -- (See section -- [Core type: RevDiff](#g:how-45-it-45-works-45-core-45-type-45-RevDiff).) 
-- -- Suppose we want to differentiate a scalar-valued function of a tuple @(b0, b1)@: -- -- > f :: (b0, b1) -> c -- -- Our strategy is to exploit the polymorphism of @f@ -- with respect to the types @b0@ and @b1@. -- This means that @f@ must also support the type: -- -- > f :: (RevDiff a b0, RevDiff a b1) -> RevDiff a c -- -- To differentiate such a function, we need a way to transform a single input of type -- @RevDiff a (b0, b1)@ into a pair of inputs @(RevDiff a b0, RevDiff a b1)@. -- -- This is exactly the role of the /argument structure descriptor/: -- -- > tupleArg :: (Additive (CT b0), Additive (CT b1)) => -- > RevDiff a (b0, b1) -> (RevDiff a b0, RevDiff a b1) -- -- Using this, we can define a new function (first split the input with @tupleArg@, then apply @f@): -- -- > f . tupleArg :: RevDiff a (b0, b1) -> RevDiff a c -- -- and apply `simpleDerivative`: -- -- > simpleDerivative (f . tupleArg) :: (b0, b1) -> (CT b0, CT b1) -- -- More generally, the expression: -- -- > customArgValDerivative arg scalarVal f -- -- is equivalent to: -- -- > simpleDerivative (f . arg) -- -- Similarly, we can define argument structure descriptors for vectors and streams: -- -- > boxedVectorArg :: (Additive (CT b), KnownNat n) => -- > RevDiff a (BoxedVector n b) -> BoxedVector n (RevDiff a b) -- -- > streamArg :: Additive (CT b) => -- > RevDiff a (Stream b) -> Stream (RevDiff a b) -- -- We can also combine them for more complex structured arguments. -- For example: -- -- >>> import Numeric.InfBackprop.Utils.SizedVector (BoxedVector) -- -- >>> :{ -- tupleBoxedVectorStreamArg :: (Additive b, Additive c, KnownNat n) => -- RevDiff a (BoxedVector n b, FiniteSupportStream c) (BoxedVector n d, Stream e) -> (BoxedVector n (RevDiff a b d), Stream (RevDiff a c e)) -- tupleBoxedVectorStreamArg = cross boxedVectorArg streamArg . tupleArg -- :} -- -- This allows us to differentiate functions whose arguments have a nested structure, -- such as @(BoxedVector n a, Stream a)@.
-- -- Alternatively, we can construct argument structure descriptors using the same style as for -- value structure descriptors (see -- [How it Works: Structured Value Types](#g:differentiation-45-for-45-structured-45-types-45-structured-45-value-45-how-45-it-45-works)): -- -- >>> :{ -- tupleBoxedVectorStreamArgV2 :: (Additive b, Additive c, KnownNat n) => -- RevDiff a (BoxedVector n b, FiniteSupportStream c) (BoxedVector n d, Stream e) -> (BoxedVector n (RevDiff a b d), Stream (RevDiff a c e)) -- tupleBoxedVectorStreamArgV2 = mkTupleArg (mkBoxedVectorArg id) (mkStreamArg id) -- :} -- -- Note that: -- -- > tupleArg = mkTupleArg id id -- > boxedVectorArg = mkBoxedVectorArg id -- > streamArg = mkStreamArg id -- -- where `id` is used for scalar arguments. -- $differentiation-for-structured-types-structured-argument-type-defining-custom-type -- -- To support differentiation with respect to a custom scalar type @a@, -- it is sufficient to define the associated type families: -- -- - `Tangent`@ a@ -- - `Dual`@ (@`Tangent`@ a)@ (we denote this as `CT`@ a@) -- -- (See -- [Tangent and Cotangent Spaces](#g:how-45-it-45-works-45-tangent-45-space) -- for more details.) -- -- Of course, you must also implement the differentiable function -- to be differentiated, for example: -- -- > func :: RevDiff a a -> RevDiff a b -- -- If @b@ is a scalar type, -- the derivative will have the type: -- -- > func' :: a -> CT a -- -- For structured types like @f :: Type -> Type@, we recommend the following: -- -- 1. Define the type families: -- -- - `Tangent`@ (f a)@ -- -- - `Dual`@ (@`Tangent`@ (f a))@ -- -- 2. Define the argument type descriptor, which essentially pushes @RevDiff a@ inside the structure @f@: -- -- > fArg :: Additive (CT b) => -- > RevDiff a (f b) -> f (RevDiff a b) -- -- 3. Define the argument type descriptor constructor: -- -- > mkFArg :: Additive (CT b) => -- > (RevDiff a b -> c) -> RevDiff a (f b) -> f c -- -- 4.
Provide an instance: -- -- > instance (AutoDifferentiableArgument a b c, Additive (CT b)) => -- > AutoDifferentiableArgument a (f b) (f c) where -- > autoArg = mkFArg autoArg -- -- For bifunctor types @g :: Type -> Type -> Type@: -- -- 1. Define type families: -- -- - `Tangent`@(g a0 a1)@ -- -- - `Dual`@(@`Tangent`@(g a0 a1))@ -- -- 2. Define the argument type descriptor: -- -- > gArg :: (Additive (CT b0), Additive (CT b1)) => -- > RevDiff a (g b0 b1) -> g (RevDiff a b0) (RevDiff a b1) -- -- 3. Define the argument type descriptor constructor: -- -- > mkGArg :: (Additive (CT b0), Additive (CT b1)) => -- > (RevDiff a b0 -> c0) -> -- > (RevDiff a b1 -> c1) -> -- > RevDiff a (g b0 b1) -> -- > g c0 c1 -- -- 4. Provide an instance: -- -- > instance ( -- > AutoDifferentiableArgument a b0 c0, -- > AutoDifferentiableArgument a b1 c1, -- > Additive (CT b0), -- > Additive (CT b1) -- > ) => -- > AutoDifferentiableArgument a (g b0 b1) (g c0 c1) where -- > autoArg = mkGArg autoArg autoArg -- $performance-remarks -- -- This section discusses performance considerations when using the library. -- $performance-remarks-subexpression-elimination -- -- Some intermediate results computed during the forward pass -- (see -- [The Backpropagation Derivative](#g:how-45-it-45-works-45-backpropagation)) -- can be reused during the backward pass. -- For deep neural networks, this reuse can result in significant computational savings. -- This optimization can be viewed as a form of /subexpression elimination/— -- a problem that Haskell’s evaluation model doesn't always handle automatically. 
-- -- Consider the following example: -- -- >>> :{ -- f, g, h :: SymbolicFunc a => a -> a -- f = unarySymbolicFunc "f" -- g = unarySymbolicFunc "g" -- h = unarySymbolicFunc "h" -- k :: BinarySymbolicFunc a => a -> a -> a -- k = binarySymbolicFunc "k" -- forwardV1 :: (SymbolicFunc a, BinarySymbolicFunc a, Additive a) => a -> a -- forwardV1 x_ = k (g y) (h y) where y = f x_ -- :} -- -- Here we define a function @forwardV1@ as a composition -- of functions. The intermediate result @f x@ is bound to a variable @y@, -- which is then passed to both @g@ and @h@. -- -- To trace the evaluation of functions @f@, @g@, @h@, and @k@, -- we use the `trace` function from @Debug.Trace@. -- To facilitate this, we define a traced version @Traced@ -- of the symbolic expression type @SE@: -- -- >>> x = MkTraced $ variable "x" :: Traced SE -- -- For example: -- -- >>> f x :: Traced SE -- <<< TRACING: Calculating f of x >>> -- f(x) -- -- The output: -- -- > <<< TRACING: Calculating f of x >>> -- -- is produced by the `trace` mechanism. -- -- Now consider the more complex function: -- -- > >>> simplify $ forwardV1 x :: Traced SimpleExpr -- > <<< TRACING: Calculating f of x >>> -- > <<< TRACING: Calculating g of f(x) >>> -- > <<< TRACING: Calculating h of f(x) >>> -- > <<< TRACING: Calculating k of g(f(x)) and h(f(x)) >>> -- > k(g(f(x)),h(f(x))) -- -- The output may vary in order, depending on GHC's optimizations, but importantly, -- note that @f x@ is only computed once and its result is reused, -- thanks to the local binding. 
-- -- By contrast, if we define @forwardV2@ -- without explicitly factoring out the shared subexpression: -- -- >>> :{ -- forwardV2 :: (SymbolicFunc a, BinarySymbolicFunc a, Additive a) => a -> a -- forwardV2 x_ = k (g (f x_)) (h (f x_)) -- :} -- -- the tracing output will show redundant evaluations: -- -- > >>> simplify $ forwardV2 x :: Traced SimpleExpr -- > <<< TRACING: Calculating f of x >>> -- > <<< TRACING: Calculating g of f(x) >>> -- > <<< TRACING: Calculating f of x >>> -- > <<< TRACING: Calculating h of f(x) >>> -- > <<< TRACING: Calculating k of g(f(x)) and h(f(x)) >>> -- > k(g(f(x)),h(f(x))) -- -- Here, @f x@ is computed twice. -- This illustrates that /GHC does not always automatically eliminate subexpressions/. -- -- Now consider tracing the derivative of @forwardV1@. -- In the long output below, observe that @f'@ -- is /not/ computed twice during the backward pass: -- -- > >>> simplify $ simpleDerivative forwardV1 x :: Traced SimpleExpr -- > <<< TRACING: Calculating f' of x >>> -- > <<< TRACING: Calculating f of x >>> -- > <<< TRACING: Calculating g' of f(x) >>> -- > <<< TRACING: Calculating g of f(x) >>> -- > <<< TRACING: Calculating h of f(x) >>> -- > <<< TRACING: Calculating k'_1 of g(f(x)) and h(f(x)) >>> -- > <<< TRACING: Calculating (*) of k'_1(g(f(x)),h(f(x))) and 1 >>> -- > <<< TRACING: Calculating (*) of g'(f(x)) and k'_1(g(f(x)),h(f(x)))*1 >>> -- > <<< TRACING: Calculating (*) of f'(x) and g'(f(x))*(k'_1(g(f(x)),h(f(x)))*1) >>> -- > <<< TRACING: Calculating h' of f(x) >>> -- > <<< TRACING: Calculating k'_2 of g(f(x)) and h(f(x)) >>> -- > <<< TRACING: Calculating (*) of k'_2(g(f(x)),h(f(x))) and 1 >>> -- > <<< TRACING: Calculating (*) of h'(f(x)) and k'_2(g(f(x)),h(f(x)))*1 >>> -- > <<< TRACING: Calculating (*) of f'(x) and h'(f(x))*(k'_2(g(f(x)),h(f(x)))*1) >>> -- > <<< TRACING: Calculating (+) of f'(x)*(g'(f(x))*(k'_1(g(f(x)),h(f(x)))*1)) and f'(x)*(h'(f(x))*(k'_2(g(f(x)),h(f(x)))*1)) >>> -- > 
(f'(x)*(g'(f(x))*k'_1(g(f(x)),h(f(x)))))+(f'(x)*(h'(f(x))*k'_2(g(f(x)),h(f(x))))) -- -- This duplication of computations becomes more severe as function composition grows deeper — -- a major performance issue in neural network applications. -- -- For further illustration, consider the first and second derivatives -- of the composition @(g . f)@: -- -- > >>> simpleDerivative (g . f) x :: Traced SimpleExpr -- > <<< TRACING: Calculating f' of x >>> -- > <<< TRACING: Calculating f of x >>> -- > <<< TRACING: Calculating g' of f(x) >>> -- > <<< TRACING: Calculating (*) of g'(f(x)) and 1 >>> -- > <<< TRACING: Calculating (*) of f'(x) and g'(f(x))*1 >>> -- > f'(x)*(g'(f(x))*1) -- -- > >>> simpleDerivative (simpleDerivative (g . f)) x :: Traced SimpleExpr -- > <<< TRACING: Calculating f'' of x >>> -- > <<< TRACING: Calculating f of x >>> -- > <<< TRACING: Calculating g' of f(x) >>> -- > <<< TRACING: Calculating (*) of g'(f(x)) and 1 >>> -- > <<< TRACING: Calculating (*) of g'(f(x))*1 and 1 >>> -- > <<< TRACING: Calculating (*) of f''(x) and (g'(f(x))*1)*1 >>> -- > <<< TRACING: Calculating f' of x >>> -- > <<< TRACING: Calculating g'' of f(x) >>> -- > <<< TRACING: Calculating f' of x >>> -- > <<< TRACING: Calculating (*) of f'(x) and 1 >>> -- > <<< TRACING: Calculating (*) of 1 and f'(x)*1 >>> -- > <<< TRACING: Calculating (*) of g''(f(x)) and 1*(f'(x)*1) >>> -- > <<< TRACING: Calculating (*) of f'(x) and g''(f(x))*(1*(f'(x)*1)) >>> -- > <<< TRACING: Calculating (+) of f'(x)*(g''(f(x))*(1*(f'(x)*1))) and 0 >>> -- > <<< TRACING: Calculating (+) of f''(x)*((g'(f(x))*1)*1) and (f'(x)*(g''(f(x))*(1*(f'(x)*1))))+0 >>> -- > (f''(x)*((g'(f(x))*1)*1))+((f'(x)*(g''(f(x))*(1*(f'(x)*1))))+0) -- -- Here we observe that @f'(x)@ is computed /twice/ in the second derivative. -- This occurs because it appears in two different branches of the expression tree: -- once as the outer factor of the first derivative @f'(x)*(g'(f(x))*1)@, and once more -- when the inner term @g'(f(x))@ is differentiated, since by the chain rule its derivative is @g''(f(x))*f'(x)@.
-- -- Unfortunately, the current implementation of `simplify` is /not able/ to eliminate -- this redundancy, as it lacks full common subexpression elimination. -- -- Nevertheless, for typical neural network applications, -- the current backpropagation implementation for the first derivative -- is performant enough in practice. -- $performance-remarks-forward-step-results-reusage -- -- Some results from the forward pass can be reused during the backward pass, -- leading to significant computational savings. Let's explore this optimization -- through a concrete example. -- -- Consider differentiating the hyperbolic functions: -- -- \[ -- \cosh x = \sinh' x = \frac{e^x + e^{-x}}{2} -- \] -- and -- \[ -- \sinh x = \cosh' x = \frac{e^x - e^{-x}}{2} -- \] -- -- Notice that both functions require computing the same exponentials: -- \(e^x\) and \(e^{-x}\). -- During the forward pass, we calculate these exponentials to compute the function value. -- Then, during the backward pass for derivative computation, we need exactly the same -- exponentials again. Rather than recomputing them, we can reuse the forward pass results. -- -- This optimization becomes particularly valuable when dealing with computationally -- expensive operations, such as matrix exponentials, where avoiding redundant -- calculations can dramatically improve performance. -- -- While automatic subexpression elimination techniques exist, we'll explore a different -- approach: manual subexpression elimination implemented directly in the backpropagation -- definition. This gives us explicit control over which intermediate results to preserve -- and reuse. 
-- -- Here's how we implement this optimization: -- -- We define an @ExpFieldV2@ typeclass that produces the same function values as -- `ExpField` -- but differs in how it handles intermediate computations, specifically designed to -- enable result reuse: -- -- >>> :{ -- class ExpFieldV2 a where -- expV2 :: a -> a -- sinhV2 :: a -> a -- coshV2 :: a -> a -- instance ExpFieldV2 SE where -- expV2 = exp -- sinhV2 x_ = (exp x_ - exp (negate x_)) / number 2 -- coshV2 x_ = (exp x_ + exp (negate x_)) / number 2 -- instance (ExpFieldV2 a, Distributive a, Subtractive a, Divisive a, FromInteger a) => -- ExpFieldV2 (RevDiff t a a) where -- expV2 = simpleDifferentiableFunc expV2 expV2 -- sinhV2 (MkRevDiff x bpc) = -- MkRevDiff ((expP - expM) NH./ fromInteger 2) (bpc . ((expP + expM) *)) where -- expP = expV2 x -- expM = expV2 (negate x) -- coshV2 (MkRevDiff x bpc) = -- MkRevDiff ((expP + expM) NH./ fromInteger 2) (bpc . ((expP - expM) *)) where -- expP = expV2 x -- expM = expV2 (negate x) -- instance (ExpFieldV2 a, ExpField a, FromInteger a, Show a) => -- ExpFieldV2 (Traced a) where -- expV2 = addTraceUnary "exp" expV2 -- sinhV2 x_ = (expV2 x_ - expV2 (negate x_)) / fromInteger 2 -- coshV2 x_ = (expV2 x_ + expV2 (negate x_)) / fromInteger 2 -- :} -- -- The key insight is in the @RevDiff@ instance: we manually store the exponentials -- @expP@ (for \(e^x\)) and @expM@ (for \(e^{-x}\)) as local bindings. -- This ensures they're -- computed only once and then reused both for the forward value calculation and -- the backward pass derivative computation. -- -- Note that the backward passes above multiply the incoming cotangent by @expP + expM@ (in @sinhV2@) -- and by @expP - expM@ (in @coshV2@) without dividing by 2. The derivatives they produce are therefore -- \(2\cosh x\) and \(2\sinh x\) rather than \(\cosh x\) and \(\sinh x\). -- An exact implementation would also divide these factors by 2, which still reuses @expP@ and @expM@. -- The traces below correspond to the code exactly as written above.
-- -- Let's verify this optimization works as expected by tracing the computations: -- -- >>> x = MkTraced $ variable "x" :: Traced SE -- -- > >>> coshV2 x -- > <<< TRACING: Calculating exp of x >>> -- > <<< TRACING: Calculating negate of x >>> -- > <<< TRACING: Calculating exp of -(x) >>> -- > <<< TRACING: Calculating (+) of exp(x) and exp(-(x)) >>> -- > <<< TRACING: Calculating (/) of exp(x)+exp(-(x)) and 2 >>> -- > (exp(x)+exp(-(x)))/2 -- -- Now let's examine what happens when we compute both the value and derivative. -- To this end, we use a function `simpleValueAndDerivative` -- that computes both the value and derivative: -- -- > >>> simpleValueAndDerivative coshV2 x :: (Traced SE, Traced SE) -- > ( <<< TRACING: Calculating exp of x >>> -- > <<< TRACING: Calculating negate of x >>> -- > <<< TRACING: Calculating exp of -(x) >>> -- > <<< TRACING: Calculating (+) of exp(x) and exp(-(x)) >>> -- > <<< TRACING: Calculating (/) of exp(x)+exp(-(x)) and 2 >>> -- > (exp(x)+exp(-(x)))/2, <<< TRACING: Calculating (-) of exp(x) and exp(-(x)) >>> -- > <<< TRACING: Calculating (*) of exp(x)-exp(-(x)) and 1 >>> -- > (exp(x)-exp(-(x)))*1) -- -- Notice how the exponential calculations (exp of x and exp of -(x)) appear only -- once in the trace, even though they're used in both the forward and backward passes. -- This demonstrates that our manual subexpression elimination successfully avoids -- redundant computations, reusing the exponential results as intended. 
-- -- Moreover, we can compute the second derivative without recomputing the exponentials: -- -- > >>> simpleDerivative (simpleDerivative coshV2) x :: Traced SE -- > <<< TRACING: Calculating exp of x >>> -- > <<< TRACING: Calculating (*) of 1 and 1 >>> -- > <<< TRACING: Calculating (*) of exp(x) and 1*1 >>> -- > <<< TRACING: Calculating negate of x >>> -- > <<< TRACING: Calculating exp of -(x) >>> -- > <<< TRACING: Calculating negate of 1*1 >>> -- > <<< TRACING: Calculating (*) of exp(-(x)) and -(1*1) >>> -- > <<< TRACING: Calculating negate of exp(-(x))*-(1*1) >>> -- > <<< TRACING: Calculating (+) of exp(x)*(1*1) and -(exp(-(x))*-(1*1)) >>> -- > <<< TRACING: Calculating (+) of (exp(x)*(1*1))+-(exp(-(x))*-(1*1)) and 0 >>> -- > ((exp(x)*(1*1))+-(exp(-(x))*-(1*1)))+0 -- $what-is-next -- -- Unboxed vectors and tensors are not currently supported in the library.