-- | Module    :  Numeric.InfBackprop.Utils.Tuple
-- Copyright   :  (C) 2025 Alexey Tochin
-- License     :  BSD3 (see the file LICENSE)
-- Maintainer  :  Alexey Tochin <Alexey.Tochin@gmail.com>
--
-- Utility functions for working with tuples.
module Numeric.InfBackprop.Utils.Tuple
  ( cross,
    cross3,
    fork,
    fork3,
    curry3,
    uncurry3,
    biCross,
    biCross3,
  )
where

-- | Applies two functions to the components of a pair, the first function
-- to the first component and the second function to the second one:
-- @cross f g (x, y) = (f x, g y)@.
--
-- ==== __Examples__
--
-- >>> import GHC.Num ((+), (*))
--
-- >>> cross (+1) (*2) (3, 4)
-- (4,8)
cross :: (a -> b) -> (c -> d) -> (a, c) -> (b, d)
{-# INLINE cross #-}
cross f g (x, y) = (f x, g y)

-- | Applies three functions to the components of a triple, each function to
-- the component in the same position:
-- @cross3 f g h (x, y, z) = (f x, g y, h z)@.
--
-- ==== __Examples__
--
-- >>> import GHC.Num ((+), (-), (*))
--
-- >>> cross3 (+1) (*2) (\x -> x - 3) (3, 4, 10)
-- (4,8,7)
cross3 :: (a0 -> b0) -> (a1 -> b1) -> (a2 -> b2) -> (a0, a1, a2) -> (b0, b1, b2)
{-# INLINE cross3 #-}
cross3 f g h (x, y, z) = (f x, g y, h z)

-- | Applies two functions to the same argument and returns the pair of
-- results: @fork f g x = (f x, g x)@.
--
-- ==== __Examples__
--
-- >>> import GHC.Num ((+), (*))
--
-- >>> fork (+1) (*2) 3
-- (4,6)
fork :: (t -> a) -> (t -> b) -> t -> (a, b)
{-# INLINE fork #-}
fork f g x = (f x, g x)

-- | Applies three functions to the same argument and returns the triple of
-- results: @fork3 f0 f1 f2 x = (f0 x, f1 x, f2 x)@.
--
-- ==== __Examples__
--
-- >>> import GHC.Num ((+), (-), (*))
--
-- >>> fork3 (+1) (*2) (\x -> x - 3) 5
-- (6,10,2)
fork3 :: (t -> a0) -> (t -> a1) -> (t -> a2) -> t -> (a0, a1, a2)
{-# INLINE fork3 #-}
fork3 f0 f1 f2 x = (f0 x, f1 x, f2 x)

-- | Curries a function on triples, turning a function of one triple argument
-- into a function of three separate arguments:
-- @curry3 f x y z = f (x, y, z)@.
--
-- ==== __Examples__
--
-- >>> import GHC.Num ((+))
--
-- >>> f (x, y, z) = x + y + z
-- >>> g = curry3 f
-- >>> g 1 2 3
-- 6
curry3 :: ((a, b, c) -> d) -> a -> b -> c -> d
{-# INLINE curry3 #-}
curry3 f x y z = f (x, y, z)

-- | Uncurries a function of three arguments, turning it into a function of
-- a single triple argument:
-- @uncurry3 f (x, y, z) = f x y z@.
--
-- ==== __Examples__
--
-- >>> import GHC.Num ((+))
--
-- >>> f x y z = x + y + z
-- >>> g = uncurry3 f
-- >>> g (1, 2, 3)
-- 6
uncurry3 :: (a -> b -> c -> d) -> ((a, b, c) -> d)
{-# INLINE uncurry3 #-}
uncurry3 f (x, y, z) = f x y z

-- | Applies two binary functions component-wise to two pairs. The first
-- function combines the first components and the second function combines
-- the second components:
-- @biCross f g (x0, x1) (y0, y1) = (f x0 y0, g x1 y1)@.
--
-- ==== __Examples__
--
-- >>> import GHC.Num ((+), (*))
--
-- >>> biCross (+) (*) (1, 2) (3, 4)
-- (4,8)
biCross :: (a -> b -> c) -> (d -> e -> f) -> (a, d) -> (b, e) -> (c, f)
{-# INLINE biCross #-}
biCross f g (x0, x1) (y0, y1) = (f x0 y0, g x1 y1)

-- | Applies three binary functions component-wise to two triples. The
-- function in position @i@ combines the @i@-th components of both triples:
-- @biCross3 f g h (x0, x1, x2) (y0, y1, y2) = (f x0 y0, g x1 y1, h x2 y2)@.
--
-- ==== __Examples__
--
-- >>> import GHC.Num ((+), (*), (-))
--
-- >>> biCross3 (+) (*) (-) (1, 2, 10) (3, 4, 5)
-- (4,8,5)
biCross3 ::
  (a -> b -> c) ->
  (d -> e -> f) ->
  (g -> h -> l) ->
  (a, d, g) ->
  (b, e, h) ->
  (c, f, l)
{-# INLINE biCross3 #-}
biCross3 f g h (x0, x1, x2) (y0, y1, y2) = (f x0 y0, g x1 y1, h x2 y2)