-- | Mutable BLAS vectors: whole mutable arrays, slices of them,
-- and in-place updates implemented with BLAS @axpy@.
module Numeric.BLAS.Vector.Mutable
   ( T
   , Sourced
   , C
   , shape
   , fromVector
   , sourcedFromVector
   , slice
   , sliceVector
   , slices
   , slicesVector
   , new
   , thawSlice
   , fromChunk
   , add
   , sub
   , mac
   ) where

import qualified Numeric.BLAS.Subobject.Mutable as MutSub
import qualified Numeric.BLAS.Vector.SlicePrivate as VectorSlice
import qualified Numeric.BLAS.Vector.Chunk as Chunk
import qualified Numeric.BLAS.Subobject.Shape as Subshape
-- import qualified Numeric.BLAS.Subobject.View as View
import qualified Numeric.BLAS.Subobject.Layout.Class as LayoutClass
import qualified Numeric.BLAS.Subobject.Layout as Layout
import qualified Numeric.BLAS.Slice as Slice
import qualified Numeric.BLAS.Scalar as Scalar
import Numeric.BLAS.Subobject.Mutable (startArg, shape)
import Numeric.BLAS.Vector.SlicePrivate ((<*|>))
import qualified Numeric.BLAS.FFI.Generic as Blas
-- import qualified Numeric.BLAS.FFI.Complex as BlasComplex
-- import qualified Numeric.BLAS.FFI.Real as BlasReal
import qualified Numeric.Netlib.Class as Class
import qualified Numeric.Netlib.Utility as Call
import Numeric.BLAS.Private (fill)

import qualified Data.Array.Comfort.Storable.Mutable.Private as Array
import qualified Data.Array.Comfort.Shape as Shape
import Data.Array.Comfort.Storable.Mutable.Private (Array(Array))

import qualified Foreign.Marshal.Array.Guarded as ForeignArray
import Foreign.Marshal.Array (advancePtr)
-- import Foreign.ForeignPtr (withForeignPtr, castForeignPtr)
import Foreign.Storable (Storable)
import Foreign.Ptr (Ptr)
import Foreign.C.Types (CInt)

import Control.Monad.Primitive (PrimMonad, unsafeIOToPrim)
import Control.Monad.Trans.Cont (ContT(ContT), evalContT)
import Control.Monad.IO.Class (liftIO)
import Control.Applicative (liftA2, (<$>))


-- | Local name for the mutable array type used throughout this module.
type Vector = Array

-- | Zero-based integer shape, as produced by 'fromChunk'.
type ShapeInt = Shape.ZeroBased Int

-- | A mutable vector view whose layout is a 'Layout.Slice'.
type T = MutSub.T Layout.Slice

-- | A mutable array of shape @sh@ together with a 'Slice.T' that selects
-- a part of shape @slice@ from it.
data Sourced sh m slice a = Sourced (Slice.T slice) (Vector m sh a)
   -- deriving (Show)


-- | View an entire mutable array as a 'T', with its layout obtained from
-- 'Subshape.fromVector'.
fromVector :: (Shape.C sh) => Vector m sh a -> T m sh a
fromVector = MutSub.Cons . Array.mapShape Subshape.fromVector

-- | Pair a mutable array with the slice that 'Slice.fromShape' builds
-- from the array's own shape.
sourcedFromVector :: (Shape.C sh) => Vector m sh a -> Sourced sh m sh a
sourcedFromVector arr = Sourced (Slice.fromShape (Array.shape arr)) arr


-- | Apply a slice transformation to a 'T'.  Only the shape descriptor is
-- rewritten (through 'Subshape.focus' and 'Array.mapShape').
--
-- ToDo: generalize and move to Subobject.Mutable
slice :: (Slice.T shA -> Slice.T shB) -> T m shA a -> T m shB a
slice select (MutSub.Cons arr) =
   MutSub.Cons (Array.mapShape (Subshape.focus select) arr)

-- | 'slice' applied to a whole mutable array (see 'fromVector').
sliceVector ::
   (Shape.C shA) =>
   (Slice.T shA -> Slice.T shB) -> Vector m shA a -> T m shB a
sliceVector select arr = slice select (fromVector arr)

-- | Like 'slice', but the transformation yields a functor of slices
-- (e.g. a pair or a list).  Every resulting 'T' reuses the same
-- underlying buffer; only the shape differs.
slices ::
   (Functor f) =>
   (Slice.T shA -> f (Slice.T shB)) -> T m shA a -> f (T m shB a)
slices select (MutSub.Cons (Array sh buffer)) =
   (\subSh -> MutSub.Cons (Array subSh buffer))
      <$> Subshape.focusMany select sh

-- | 'slices' applied to a whole mutable array (see 'fromVector').
slicesVector ::
   (Functor f, Shape.C shA) =>
   (Slice.T shA -> f (Slice.T shB)) -> Vector m shA a -> f (T m shB a)
slicesVector select arr = slices select (fromVector arr)


-- | The selected part of a 'Sourced' vector.  Its shape is the slice's
-- shape, and its start pointer is the array buffer advanced by the
-- slice's first component (the start offset in elements).
instance MutSub.C (Sourced sh) where
   shape (Sourced (Slice.Cons _ _ selected) _) = selected
   startArg (Sourced (Slice.Cons offset _ _) (Array _ buffer)) =
      (`advancePtr` offset) <$> ContT (ForeignArray.withMutablePtr buffer)


-- | Mutable vectors that can also report the increment passed to BLAS
-- alongside their start pointer (see 'sliceArg').
--
-- ToDo: increment must not be zero (maybe even positive)
class (MutSub.C v) => C v where
   increment :: (PrimMonad m) => v m slice a -> Int

-- | Plain arrays are passed with increment 1.
instance C Array where
   increment = const 1

-- | The increment comes from the layout of the subobject's shape.
instance (LayoutClass.C lay) => C (MutSub.T lay) where
   increment (MutSub.Cons arr) =
      LayoutClass.elemInc (Subshape.layout (Array.shape arr))

-- | The increment is derived with 'LayoutClass.elemInc' from the second
-- component of the slice.
instance C (Sourced sh) where
   increment (Sourced (Slice.Cons _ stride _) _) = LayoutClass.elemInc stride


-- | Marshal a vector for a BLAS call: its start pointer (from 'startArg')
-- paired with a pointer to its 'increment' as a 'CInt'.
sliceArg ::
   (PrimMonad m, C v, Storable a) =>
   v m slice a -> Call.FortranIO r (Ptr a, Ptr CInt)
sliceArg vec = liftA2 (,) (startArg vec) (Call.cint (increment vec))


-- | Allocate a mutable vector of the given shape and initialise its
-- storage with 'fill' using the given value.
new ::
   (PrimMonad m, Shape.C sh, Class.Floating a) =>
   sh -> a -> m (Vector m sh a)
new sh x = Array.unsafeCreateWithSize sh (fill x)

-- | Turn a 'Chunk.T' into a zero-based mutable vector
-- (delegates to 'Chunk.toMutableVector').
fromChunk :: (PrimMonad m, Storable a) => Chunk.T a -> m (Vector m ShapeInt a)
fromChunk chunk = Chunk.toMutableVector chunk
-- | Copy a 'VectorSlice.C' slice into a freshly allocated mutable vector
-- of the same shape, using BLAS @copy@ with destination increment 1.
--
-- cf. Vector.Slice.toVector
thawSlice ::
   (PrimMonad m, VectorSlice.C v, Shape.C sh, Class.Floating a) =>
   v sh a -> m (Vector m sh a)
thawSlice src =
   Array.unsafeCreateWithSize (VectorSlice.shape src) $ \size dstPtr ->
   evalContT . Call.run $
      pure Blas.copy
         <*> Call.cint size
         <*|> src
         <*> pure dstPtr
         <*> Call.cint 1


-- | @add y x@ adds @x@ to @y@ in place; @sub y x@ subtracts @x@ from @y@
-- in place.  Both are 'mac' with coefficient one or minus one.
add, sub ::
   (PrimMonad m, C v, VectorSlice.C w, Shape.C sh, Eq sh, Class.Floating a) =>
   v m sh a -> w sh a -> m ()
add y x = mac y Scalar.one x
sub y x = mac y Scalar.minusOne x

-- | @mac y alpha x@ overwrites @y@ with @y + alpha*x@ using BLAS @axpy@.
--
-- The shapes of @x@ and @y@ are compared with 'Call.assert'
-- (message @mac: shapes mismatch@) before BLAS is called.
mac ::
   (PrimMonad m, C v, VectorSlice.C w, Shape.C sh, Eq sh, Class.Floating a) =>
   v m sh a -> a -> w sh a -> m ()
mac y alpha x =
   unsafeIOToPrim $ do
      let shX = VectorSlice.shape x
      Call.assert "mac: shapes mismatch" (shX == shape y)
      evalContT $ do
         nPtr <- Call.cint (Shape.size shX)
         alphaPtr <- Call.number alpha
         (xPtr, incxPtr) <- VectorSlice.sliceArg x
         (yPtr, incyPtr) <- sliceArg y
         liftIO $ Blas.axpy nPtr alphaPtr xPtr incxPtr yPtr incyPtr