{-# LANGUAGE PatternSynonyms #-}

-- | Internal module. Not part of the public API.
module Data.PackStream.Get.Internal
  ( getNull, tryNull
  , getBoolean, tryBoolean

  , getFloat, tryFloat

  , getString, tryString
  , getBytes, tryBytes

  , getList, tryList
  , getDictionary, tryDictionary
  , tryStructure
  ) where

import           Compat.Prelude

import qualified Data.ByteString       as S
import qualified Data.Text             as T
import qualified Data.Text.Encoding    as T
import qualified Data.Vector           as V
import qualified Data.HashMap.Lazy     as H
import Data.Hashable (Hashable)

import           Compat.Binary
import           Data.PackStream.Tags

-- | Turn a tag-dispatching @try*@ primitive into a standalone decoder.
--
-- One tag byte is read and passed to the primitive together with the
-- success handler; @'fail'@ with the given message is supplied as the
-- primitive's fallback continuation.
mkGet :: (Word8 -> t -> Get a -> Get b) -> t -> String -> Get b
mkGet tryTag onMatch errMsg =
  getWord8 >>= \tag -> tryTag tag onMatch (fail errMsg)

-- | Decode a PackStream null value.
--
-- Reads one tag byte and fails with @"expected PackStream null"@ when it
-- is not the null tag.
getNull :: Get ()
getNull = mkGet tryNull id msg
  where
    msg = "expected PackStream null"

-- | Decode a PackStream boolean value.
--
-- Reads one tag byte and fails with @"expected PackStream boolean"@ when
-- it is neither the true nor the false tag.
getBoolean :: Get Bool
getBoolean = mkGet tryBoolean id msg
  where
    msg = "expected PackStream boolean"

-- | Decode a PackStream 64-bit float.
--
-- Reads one tag byte and fails with @"expected PackStream float"@ when it
-- is not the float tag.
getFloat :: Get Double
getFloat = mkGet tryFloat id msg
  where
    msg = "expected PackStream float"

-- | Decode a PackStream UTF-8 string.
--
-- Reads one tag byte and fails with @"expected PackStream string"@ when it
-- is not one of the string tags.
getString :: Get T.Text
getString = mkGet tryString id msg
  where
    msg = "expected PackStream string"

-- | Decode a PackStream byte array.
--
-- Reads one tag byte and fails with @"expected PackStream bytes"@ when it
-- is not one of the byte-array tags.
getBytes :: Get S.ByteString
getBytes = mkGet tryBytes id msg
  where
    msg = "expected PackStream bytes"

-- | Decode a PackStream list, decoding each element with the given decoder.
--
-- Reads one tag byte and fails with @"expected PackStream list"@ when it
-- is not one of the list tags.
getList :: Get a -> Get (V.Vector a)
getList elemGet = mkGet (tryList elemGet) id msg
  where
    msg = "expected PackStream list"

-- | Decode a PackStream dictionary, decoding keys with the first decoder
-- and values with the second.
--
-- Reads one tag byte and fails with @"expected PackStream dictionary"@
-- when it is not one of the dictionary tags. The message matches the
-- wording used by the other @get*@ decoders in this module.
getDictionary :: Hashable a => Get a -> Get b -> Get (H.HashMap a b)
getDictionary k v =
  mkGet (tryDictionary k v) id "expected PackStream dictionary"


----------------------------------------------------------------------------
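-- Primitives that take an already-read tag byte as their first argument:
-- when the tag matches they decode the payload and apply the success
-- handler, otherwise they return the supplied continuation unchanged.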

-- | Match a PackStream null tag.
--
-- On @TAG_Null@ the handler is applied to @()@ and its result is forced to
-- WHNF before being returned; no further input is read. Any other tag
-- yields the continuation as-is.
{-# INLINE tryNull #-}
tryNull :: Word8 -> (() -> a) -> Get a -> Get a
tryNull TAG_Null onNull _    = pure $! onNull ()
tryNull _        _      cont = cont

-- | Match a PackStream boolean tag.
--
-- @TAG_false@ / @TAG_true@ apply the handler to 'False' / 'True' and force
-- the result to WHNF; no further input is read. Any other tag yields the
-- continuation as-is.
{-# INLINE tryBoolean #-}
tryBoolean :: Word8 -> (Bool -> a) -> Get a -> Get a
tryBoolean TAG_false onBool _    = pure $! onBool False
tryBoolean TAG_true  onBool _    = pure $! onBool True
tryBoolean _         _      cont = cont


-- | Match a PackStream float tag.
--
-- On @TAG_Float@ a big-endian 64-bit float is read and passed to the
-- handler; any other tag yields the continuation as-is.
{-# INLINE tryFloat #-}
tryFloat :: Word8 -> (Double -> a) -> Get a -> Get a
tryFloat tag onFloat cont
  | TAG_Float <- tag = fmap onFloat getFloat64be
  | otherwise        = cont

-- | Match a PackStream string tag.
--
-- The byte length comes either from the tag byte itself (short strings,
-- via @is_TAG_STRING_SHORT@) or from a following 8-, 16- or 32-bit
-- big-endian length field. The payload is decoded as UTF-8. Invalid
-- UTF-8 fails with @"tryString: invalid UTF-8 encoding"@; otherwise the
-- handler is applied and its result forced to WHNF. Any non-string tag
-- yields the continuation as-is.
{-# INLINE tryString #-}
tryString :: Word8 -> (T.Text -> a) -> Get a -> Get a
tryString tag onText cont
  | Just shortLen <- is_TAG_STRING_SHORT tag = decodeBody shortLen
  | otherwise = case tag of
      TAG_STRING_8  -> getWord8    >>= decodeBody . intCast
      TAG_STRING_16 -> getWord16be >>= decodeBody . intCast
      TAG_STRING_32 -> getWord32be >>= decodeBody
      _             -> cont
  where
    -- Read @len@ raw bytes and decode them as UTF-8.
    decodeBody len = do
      n   <- fromSizeM "tryString: data exceeds capacity of ByteString/Text" len
      raw <- getByteString n
      either (\_ -> fail "tryString: invalid UTF-8 encoding")
             (\txt -> pure $! onText txt)
             (T.decodeUtf8' raw)

-- | Match a PackStream byte-array tag.
--
-- The byte length follows the tag as an 8-, 16- or 32-bit big-endian
-- field. That many raw bytes are then read and passed to the handler.
-- Any other tag yields the continuation as-is.
{-# INLINE tryBytes #-}
tryBytes :: Word8 -> (S.ByteString -> a) -> Get a -> Get a
tryBytes tag onBytes cont =
  case tag of
    TAG_BYTE_ARRAY_8  -> do { n <- getWord8;    readPayload (intCast n) }
    TAG_BYTE_ARRAY_16 -> do { n <- getWord16be; readPayload (intCast n) }
    TAG_BYTE_ARRAY_32 -> getWord32be >>= readPayload
    _                 -> cont
  where
    -- Read exactly @len@ bytes of payload.
    readPayload len = do
      size  <- fromSizeM "tryBytes: data exceeds capacity of ByteString" len
      bytes <- getByteString size
      pure (onBytes bytes)

-- | Match a PackStream list tag, decoding each element with the given
-- decoder.
--
-- The element count comes either from the tag byte itself (short lists,
-- via @is_TAG_LIST_SHORT@) or from a following 8-, 16- or 32-bit
-- big-endian count field. The elements are collected into a 'V.Vector'
-- for the handler. Any other tag yields the continuation as-is.
{-# INLINE tryList #-}
tryList :: Get b -> Word8 -> (V.Vector b -> a) -> Get a -> Get a
tryList elemGet tag onList cont =
    maybe dispatchLong readElems (is_TAG_LIST_SHORT tag)
  where
    -- Tags that carry an explicit count field after the tag byte.
    dispatchLong = case tag of
      TAG_LIST_8  -> readElems . intCast =<< getWord8
      TAG_LIST_16 -> readElems . intCast =<< getWord16be
      TAG_LIST_32 -> readElems =<< getWord32be
      _           -> cont

    -- Decode @count@ elements in sequence.
    readElems count = do
      n <- fromSizeM "tryList: data exceeds capacity of Vector" count
      onList <$> V.replicateM n elemGet

-- | Match a PackStream dictionary tag, decoding keys and values with the
-- given decoders.
--
-- The entry count comes either from the tag byte itself (short
-- dictionaries, via @is_TAG_DICTIONARY_SHORT@) or from a following 8-, 16-
-- or 32-bit big-endian count field. Each entry is decoded as a key
-- followed by its value. The pairs are combined with 'H.fromList', so a
-- key that appears more than once keeps its later value. Any other tag
-- yields the continuation as-is.
{-# INLINE tryDictionary #-}
tryDictionary :: Hashable k => Get k -> Get v -> Word8 -> (H.HashMap k v -> a) -> Get a -> Get a
tryDictionary k v tag f cont = case tag of
    t | Just sz <- is_TAG_DICTIONARY_SHORT t -> cont' sz
    TAG_DICTIONARY_8  -> cont' . intCast =<< getWord8
    TAG_DICTIONARY_16 -> cont' . intCast =<< getWord16be
    TAG_DICTIONARY_32 -> cont' =<< getWord32be
    _                 -> cont
  where
    -- Decode @len@ key/value pairs. A plain list is enough here because
    -- the pairs go straight into 'H.fromList'. The overflow message names
    -- the container actually being built; no Vector is involved.
    cont' len = do
      len' <- fromSizeM "tryDictionary: data exceeds capacity of HashMap" len
      f . H.fromList <$> replicateM len' ((,) <$> k <*> v)

-- {-# INLINE tryStructure #-}
-- tryStructure :: Word8 -> (H.HashMap T.Text Ps -> a) -> Get a -> Get a
-- tryStructure tag f cont = case tag of
--   TAG_float64 -> f <$> getFloat64be
--   _           -> cont

-- | Match a PackStream structure tag, decoding each field with the given
-- decoder.
--
-- @is_TAG_STRUCTURE@ supplies the field count from the tag byte. The next
-- byte is read as the structure's own tag, then that many fields are
-- decoded into a 'V.Vector'. The handler receives @(structTag, fields)@
-- and its result is forced to WHNF. Any other tag yields the continuation
-- as-is.
{-# INLINE tryStructure #-}
tryStructure :: Get b -> Word8 -> ((Word8, V.Vector b) -> a) -> Get a -> Get a
tryStructure fieldGet tag onStruct cont =
    maybe cont readStruct (is_TAG_STRUCTURE tag)
  where
    readStruct fieldCount = do
      structTag <- getWord8
      n         <- fromSizeM "tryStructure: data exceeds capacity of Vector" fieldCount
      fields    <- V.replicateM n fieldGet
      pure $! onStruct (structTag, fields)

-- | Convert a decoded 'Word32' length or count to an 'Int'.
--
-- Fails in 'Get' with the given message when 'intCastMaybe' reports that
-- the value does not fit (e.g. when 'Int' is only 32 bits wide).
fromSizeM :: String -> Word32 -> Get Int
fromSizeM label sz = case intCastMaybe sz of
  Just n  -> pure n
  Nothing -> fail label