module Numeric.InfBackprop.Utils.Vector
( fromTuple,
safeHead,
safeLast,
trimArrayHead,
trimArrayTail,
zipWith,
)
where
import Control.Monad (MonadPlus, mzero)
import Data.Bool (otherwise)
import Data.Eq (Eq, (==))
import Data.Function (($))
import qualified Data.IndexedListLiterals as DILL
import Data.Maybe (Maybe (Just, Nothing))
import Data.Ord (Ordering (EQ, GT, LT), compare)
import qualified Data.Vector.Generic as DVG
import GHC.Base (pure, (.))
-- | Build a generic vector from any value with a
-- 'DILL.IndexedListLiterals' instance (for example a homogeneous tuple).
--
-- Elements appear in the vector in the order produced by 'DILL.toList'.
fromTuple ::
  (DVG.Vector v a) =>
  (DILL.IndexedListLiterals input length a) =>
  input ->
  v a
fromTuple = DVG.fromList . DILL.toList
-- | Return the first element of a vector, or 'mzero' when it is empty.
--
-- With @m ~ Maybe@ this gives @Nothing@ for an empty vector and
-- @Just x@ (the first element) otherwise.
safeHead :: (DVG.Vector v a, MonadPlus m) => v a -> m a
safeHead vec
  | DVG.null vec = mzero
  -- The vector is non-empty here, so the unchecked access cannot fail.
  | otherwise = pure $ DVG.unsafeHead vec
-- | Return the last element of a vector, or 'mzero' when it is empty.
--
-- With @m ~ Maybe@ this gives @Nothing@ for an empty vector and
-- @Just x@ (the last element) otherwise.
safeLast :: (DVG.Vector v a, MonadPlus m) => v a -> m a
safeLast vec
  | DVG.null vec = mzero
  -- The vector is non-empty here, so the unchecked access cannot fail.
  | otherwise = pure $ DVG.unsafeLast vec
-- | Drop every leading element equal to the given value.
--
-- Trimming stops at the first element that differs from the value; later
-- occurrences of the value are kept. An empty vector, or one made up only
-- of that value, yields an empty vector.
trimArrayHead :: (DVG.Vector v a, Eq a) => a -> v a -> v a
trimArrayHead x = DVG.dropWhile (== x)
-- | Drop every trailing element equal to the given value.
--
-- Trimming stops at the last element that differs from the value; earlier
-- occurrences of the value are kept. An empty vector, or one made up only
-- of that value, yields an empty vector.
trimArrayTail :: (DVG.Vector v a, Eq a) => a -> v a -> v a
trimArrayTail x array = case safeLast array of
  Nothing -> DVG.empty
  Just lastVal
    -- 'safeLast' succeeded, so the vector is non-empty and 'DVG.init' is safe.
    | lastVal == x -> trimArrayTail x (DVG.init array)
    | otherwise -> array
-- | Combine two vectors element-wise, keeping the surplus of the longer one.
--
-- The overlapping prefix (the first @min (length a0) (length a1)@
-- positions) is combined with the binary function. If the first vector is
-- longer, its remaining elements are mapped with the second function and
-- appended. If the second vector is longer, its remaining elements are
-- mapped with the third function and appended. The result therefore has
-- the length of the longer input.
--
-- Example, writing vectors as lists:
-- @zipWith (+) id id [1, 2, 3] [10, 20]@ is @[11, 22, 3]@.
zipWith ::
  (DVG.Vector v a, DVG.Vector v b, DVG.Vector v c) =>
  (a -> b -> c) ->
  (a -> c) ->
  (b -> c) ->
  v a ->
  v b ->
  v c
zipWith f g h a0 a1 = case compare l0 l1 of
  EQ -> base
  -- First vector is longer: append its tail beyond the common prefix.
  GT -> base DVG.++ DVG.map g (DVG.drop l1 a0)
  -- Second vector is longer: append its tail beyond the common prefix.
  LT -> base DVG.++ DVG.map h (DVG.drop l0 a1)
  where
    l0 = DVG.length a0
    l1 = DVG.length a1
    -- 'DVG.zipWith' stops at the end of the shorter input.
    base = DVG.zipWith f a0 a1