-- | Module    :  Numeric.InfBackprop.Utils.Vector
-- Copyright   :  (C) 2025 Alexey Tochin
-- License     :  BSD3 (see the file LICENSE)
-- Maintainer  :  Alexey Tochin <Alexey.Tochin@gmail.com>
--
-- Utility functions for working with vectors.
module Numeric.InfBackprop.Utils.Vector
  ( fromTuple,
    safeHead,
    safeLast,
    trimArrayHead,
    trimArrayTail,
    zipWith,
  )
where

import Control.Monad (MonadPlus, mzero)
import Data.Bool (otherwise)
import Data.Eq (Eq, (==))
import Data.Function (($))
import qualified Data.IndexedListLiterals as DILL
import Data.Maybe (Maybe (Just, Nothing))
import Data.Ord (Ordering (EQ, GT, LT), compare)
import qualified Data.Vector.Generic as DVG
import GHC.Base (pure, (.))

-- | Builds a vector from a tuple-like literal.
--
-- The input is first flattened into a list with 'DILL.toList', and that list
-- is then turned into a vector with 'DVG.fromList'. The target can therefore
-- be any generic vector type, for example a boxed @Data.Vector.Vector@.
--
-- === __Examples__
--
-- >>> import GHC.Int (Int)
-- >>> import qualified Data.Vector as DV
--
-- >>> fromTuple (1 :: Int, 2 :: Int, 3 :: Int) :: DV.Vector Int
-- [1,2,3]
fromTuple ::
  (DVG.Vector v a) =>
  (DILL.IndexedListLiterals input length a) =>
  input ->
  v a
fromTuple tuple = DVG.fromList (DILL.toList tuple)

-- | Safely extracts the first element of a vector.
--
-- An empty vector yields 'mzero'. Used at type 'Maybe', that is 'Nothing'.
-- Otherwise the first element is wrapped with 'pure'.
--
-- ==== __Examples__
--
-- >>> import GHC.Int (Int)
-- >>> import Data.Vector (fromList)
--
-- >>> safeHead (fromList [1, 2, 3]) :: Maybe Int
-- Just 1
--
-- >>> safeHead (fromList []) :: Maybe Int
-- Nothing
safeHead :: (DVG.Vector v a, MonadPlus m) => v a -> m a
safeHead xs =
  if DVG.null xs
    then mzero
    -- The emptiness check above makes the unchecked access safe.
    else pure (DVG.unsafeHead xs)

-- | Safely extracts the last element of a vector.
--
-- An empty vector yields 'mzero'. Used at type 'Maybe', that is 'Nothing'.
-- Otherwise the last element is wrapped with 'pure'.
--
-- ==== __Examples__
--
-- >>> import GHC.Int (Int)
-- >>> import Data.Vector (fromList, empty)
--
-- >>> safeLast (fromList [1, 2, 3]) :: Maybe Int
-- Just 3
--
-- >>> safeLast empty :: Maybe Int
-- Nothing
safeLast :: (DVG.Vector v a, MonadPlus m) => v a -> m a
safeLast xs =
  if DVG.null xs
    then mzero
    -- The emptiness check above makes the unchecked access safe.
    else pure (DVG.unsafeLast xs)

-- | Drops the longest prefix of the vector whose elements all equal the
-- given value.
--
-- Only the leading run is removed. Later occurrences of the value are kept.
-- An empty input, or one consisting solely of the value, yields an empty
-- vector.
--
-- ==== __Examples__
--
-- >>> import Data.Vector (fromList, empty)
--
-- >>> trimArrayHead 1 (fromList [1, 1, 1, 2, 3])
-- [2,3]
--
-- >>> trimArrayHead 1 (fromList [1, 2, 1])
-- [2,1]
--
-- >>> trimArrayHead 1 empty
-- []
trimArrayHead :: (DVG.Vector v a, Eq a) => a -> v a -> v a
-- Library 'DVG.dropWhile' replaces the hand-written safeHead/tail recursion.
-- The section @(== x)@ keeps the original comparison order @element == x@.
trimArrayHead x = DVG.dropWhile (== x)

-- | Drops the longest suffix of the vector whose elements all equal the
-- given value.
--
-- Elements are removed from the end one at a time until the last remaining
-- element differs from the value or the vector becomes empty.
--
-- ==== __Examples__
--
-- >>> import Data.Vector (fromList, empty)
--
-- >>> trimArrayTail 3 (fromList [1, 2, 3, 3, 3])
-- [1,2]
--
-- >>> trimArrayTail 3 empty
-- []
trimArrayTail :: (DVG.Vector v a, Eq a) => a -> v a -> v a
trimArrayTail x xs
  | DVG.null xs = DVG.empty
  -- Non-empty here, so reading the last element unchecked is safe.
  | DVG.unsafeLast xs == x = trimArrayTail x (DVG.init xs)
  | otherwise = xs

-- | Zips two vectors that may have different lengths.
--
-- The result is as long as the longer input:
--
-- * At positions present in both vectors, the elements are combined with the
--   binary function.
-- * The trailing elements of the longer vector are transformed with the
--   corresponding unary function: the first for leftovers of the left vector,
--   the second for leftovers of the right vector.
--
-- ==== __Examples__
--
-- >>> import Prelude (id, negate, (-), Int)
-- >>> import qualified Data.Vector as DV
--
-- This subtracts two vectors of different lengths, treating the missing
-- entries of the shorter vector as zero:
--
-- >>> :{
--  zipWith
--    (-)                         -- subtract corresponding elements
--    id                          -- keep leftover elements of the first vector
--    negate                      -- negate leftover elements of the second vector
--    (DV.fromList [10, 20, 30])  -- first vector
--    (DV.fromList [1, 2])        -- second vector
-- :}
-- [9,18,30]
--
-- >>> let v0 :: DV.Vector Int = DV.fromList [10, 20, 30]
-- >>> let v1 :: DV.Vector Int = DV.fromList [1, 2]
-- >>> zipWith (-) id negate v0 v1
-- [9,18,30]
zipWith ::
  (DVG.Vector v a, DVG.Vector v b, DVG.Vector v c) =>
  (a -> b -> c) ->
  (a -> c) ->
  (b -> c) ->
  v a ->
  v b ->
  v c
zipWith combine onlyLeft onlyRight lefts rights =
  case compare lenLeft lenRight of
    -- Right input is longer: transform its elements past the common prefix.
    LT -> overlap DVG.++ DVG.map onlyRight (DVG.drop lenLeft rights)
    EQ -> overlap
    -- Left input is longer: transform its elements past the common prefix.
    GT -> overlap DVG.++ DVG.map onlyLeft (DVG.drop lenRight lefts)
  where
    lenLeft = DVG.length lefts
    lenRight = DVG.length rights
    -- Element-wise combination over the shared prefix (shorter length).
    overlap = DVG.zipWith combine lefts rights