{-# LANGUAGE ScopedTypeVariables #-}

module SoPSat.Internal.NewtonsMethod
where

import Data.Map (Map)
import qualified Data.Map as M
import Data.Maybe (fromJust)

import SoPSat.Internal.SoP (
  Atom (..),
  Product (..),
  SoP (..),
  Symbol (..),
 )
import SoPSat.SoP (atoms)

-- | Evaluates a sum of products at the point described by the atom bindings.
--
-- The value is the sum of 'evalProduct' over every product of the sum, so an
-- expression with no products evaluates to @0@.
evalSoP ::
  (Ord f, Ord c, Floating n) =>
  -- | Expression to evaluate
  SoP f c ->
  -- | Bindings from atoms to values
  Map (Atom f c) n ->
  -- | Evaluation result
  n
evalSoP (S terms) binds = sum [evalProduct term binds | term <- terms]

{- | Evaluates a product of symbols at the point described by the atom
bindings.

Each symbol is evaluated with 'evalSymbol' and the values are multiplied
together, so a product with no symbols evaluates to @1@.

Used by 'evalSoP'.
-}
evalProduct ::
  (Ord f, Ord c, Floating n) =>
  -- | Product to evaluate
  Product f c ->
  -- | Atom bindings
  Map (Atom f c) n ->
  -- | Evaluation result
  n
evalProduct (P syms) binds = product [evalSymbol sym binds | sym <- syms]

{- | Evaluates a single symbol at the point described by the atom bindings.

* @I i@ evaluates to the integer literal @i@.
* @A a@ evaluates to the value bound to @a@, or @0@ when @a@ is unbound.
* @E b p@ evaluates to @b@ raised to the power @p@.

Used by 'evalProduct'.
-}
evalSymbol ::
  (Ord f, Ord c, Floating n) =>
  -- | Symbol to evaluate
  Symbol f c ->
  -- | Atom bindings
  Map (Atom f c) n ->
  -- | Evaluation result
  n
evalSymbol (I i) _ = fromInteger i
-- Atoms missing from the bindings are treated as 0.
evalSymbol (A a) binds = M.findWithDefault 0 a binds
-- Use (**) instead of @exp (p * log b)@. For 'Double' and 'Float', (**)
-- delegates to C @pow@, which gives 1 for @0 ** 0@ and a finite result for a
-- negative base with an integral exponent; the exp/log form yields NaN for
-- both. Newton iterates can make atom values negative, so this matters.
-- For other 'Floating' instances the class default of (**) is
-- @exp (log b * p)@, i.e. the previous formula.
evalSymbol (E b p) binds = evalSoP b binds ** evalProduct p binds

{- | Analytically computes the derivative of an expression with respect to
an atom.

The derivative of a sum is the sum of the derivatives of its products, each
obtained from 'derivativeProduct'. The result has the same shape as
'evalSoP' partially applied to an expression: it takes atom bindings (the
point) and returns the value of the derivative there.
-}
derivative ::
  (Ord f, Ord c, Floating n) =>
  -- | Expression to take a derivative of
  SoP f c ->
  -- | Atom to take a derivative with respect to
  Atom f c ->
  -- | Function from bindings, representing point,
  -- to value of the derivative at that point
  (Map (Atom f c) n -> n)
derivative sop atom binds =
  sum [derivativeProduct term atom binds | term <- unS sop]

{- | Analytically computes the derivative of a product with respect to an
atom.

The product rule is applied recursively, splitting off the first symbol:
@d(s * rest) = ds * rest + s * d(rest)@. The derivative of an empty
product is @0@.

Used by 'derivative'.
-}
derivativeProduct ::
  (Ord f, Ord c, Floating n) =>
  -- | Product to take a derivative of
  Product f c ->
  -- | Atom to take a derivative with respect to
  Atom f c ->
  -- | Function from bindings to a value
  (Map (Atom f c) n -> n)
derivativeProduct (P syms) atom binds = case syms of
  [] -> 0
  sym : rest ->
    let restProd = P rest
        dHead = derivativeSymbol sym atom binds
        vRest = evalProduct restProd binds
        vHead = evalSymbol sym binds
        dRest = derivativeProduct restProd atom binds
     in dHead * vRest + vHead * dRest

{- | Analytically computes the derivative of a symbol with respect to an
atom.

* An integer literal has derivative @0@.
* An atom has derivative @1@ when it is the atom being differentiated by,
  and @0@ otherwise.
* For @E b p@ (@b@ raised to @p@) the generalised power rule is used:
  @d(b^p) = b^p * (b' * p / b + log b * p')@.

Used by 'derivativeProduct'.
-}
derivativeSymbol ::
  (Ord f, Ord c, Floating n) =>
  -- | Symbol to take a derivative of
  Symbol f c ->
  -- | Atom to take a derivative with respect to
  Atom f c ->
  -- | Function from bindings to a value
  (Map (Atom f c) n -> n)
derivativeSymbol sym atom = case sym of
  I _ -> const 0
  A a -> const (if a == atom then 1 else 0)
  E base expnt -> \binds ->
    let powerVal = evalSymbol sym binds
        baseVal = evalSoP base binds
        expntVal = evalProduct expnt binds
        dBase = derivative base atom binds
        dExpnt = derivativeProduct expnt atom binds
     in powerVal * (dBase * expntVal / baseVal + log baseVal * dExpnt)

{- | Searches for atom bindings at which an expression is (close to) zero,
using a Newton-style iteration.

Every atom of the expression starts at @10@. The search runs for at most 40
steps. At each step the expression is evaluated:

* If the value is at most @0.1@, the current bindings are returned in
  'Right'.
* Otherwise every atom @c@ is moved simultaneously by
  @-value / (d expr / d c)@, with all partial derivatives taken at the
  current point.

If the step budget runs out, the last bindings are returned in 'Left'.
-}
newtonMethod ::
  forall f c n.
  (Ord f, Ord c, Ord n, Floating n) =>
  -- | Expression to check
  SoP f c ->
  -- | @Right binds@ - Atom bindings when expression is equal to zero
  --   @Left binds@ - Last checked bindings
  Either (Map (Atom f c) n) (Map (Atom f c) n)
newtonMethod sop = go init_guess steps
 where
  consts = atoms sop
  -- Partial derivative of the expression with respect to each of its atoms.
  derivs = M.fromSet (derivative sop) consts
  init_guess = M.fromSet (const 10) consts
  steps = 40

  go :: Map (Atom f c) n -> Word -> Either (Map (Atom f c) n) (Map (Atom f c) n)
  go guess 0 = Left guess
  go guess fuel
    -- NOTE(review): this accepts any value <= 0.1, including large negative
    -- ones, although the result is documented as "equal to zero".
    -- Presumably intentional; confirm before tightening to abs val <= 0.1.
    | val <= 0.1 = Right guess
    | otherwise = go new_guess (fuel - 1)
   where
    val = evalSoP sop guess
    -- Every coordinate is updated from the same (old) point. 'guess' and
    -- 'derivs' always share the key set 'consts', so the intersection visits
    -- every atom and no lookup can fail.
    new_guess = M.intersectionWith step guess derivs
    step x deriv
      -- A zero partial derivative gives no Newton direction for this atom.
      -- Dividing by it would turn the coordinate into +/-Infinity or NaN and
      -- poison every later evaluation, so leave the coordinate unchanged.
      | d == 0 = x
      | otherwise = x - val / d
     where
      d = deriv guess