-- | Internal module. Not part of the public API.
module Database.Bolty.Connection.Pipe
  ( connect
  , close
  , reset
  , ping
  , flush
  , fetch
  , sendRequest
  , receiveResponse
  , requireState
  , setState
  , getState
  , logon
  , logoff
  , sendTelemetry
  , MonadPipe
  -- * Connection accessors
  , connectionVersion
  , connectionAgent
  , connectionId
  , connectionTelemetryEnabled
  , connectionServerIdleTimeout
  , touchConnection
  , connectionLastActivity
  -- * Transaction primitives (plain IO)
  , beginTx
  , commitTx
  , rollbackTx
  , tryRollback
  -- * Plain IO wire helpers
  , requireStateIO
  , flushIO
  , fetchIO
  ) where

import           Control.Exception      (Exception, SomeAsyncException, SomeException, catch, fromException, throwIO, try)
import           Control.Monad          (when, unless)
import           Data.Bits              (shiftR, (.&.))
import           Data.Int               (Int64)
import           Data.IORef             (newIORef, readIORef, writeIORef)
import           Data.Kind              (Constraint, Type)
import           Data.Word
import           GHC.Clock              (getMonotonicTimeNSec)
import           GHC.Stack              (HasCallStack)
import           Prelude

import           Control.Monad.Except
import           Control.Monad.Extra    (whenMaybe, whileJustM)
import           Control.Monad.Trans    (MonadIO(..))
import qualified Data.ByteString        as BS
import qualified Data.ByteString.Lazy   as BSL
import qualified Data.HashMap.Lazy      as H
import qualified Data.PackStream        as PS
import           Data.PackStream.Integer        (fromPSInteger)
import           Data.PackStream.Ps     (Ps(..))
import           Data.Persist            (putBE, runPutLazy)
import qualified Data.Persist            as P
import qualified Data.Text              as T
import qualified Data.Vector            as V
import qualified Network.Connection     as NC

import           Database.Bolty.Connection.Type
import qualified Database.Bolty.Connection.Connection as C
import           Database.Bolty.Message.Response
import           Database.Bolty.Message.Request
import           Database.Bolty.Value.Helpers (isNewVersion, versionMajor, supportsLogonLogoff, supportsTelemetry)
import           Database.Bolty.Util


-- | Everything a BOLT wire operation needs from its monad: the ability to
-- lift 'IO', to raise the driver's 'Error' type, and a call stack so
-- failures can be traced back to the caller.
type MonadPipe :: (Type -> Type) -> Constraint
type MonadPipe m =
  ( MonadIO m
  , MonadError Error m
  , HasCallStack
  )


-- | Run an 'ExceptT' computation and turn a 'Left' into an exception thrown
-- in 'IO' (via 'throwIO'). A 'Right' result is returned unchanged.
exceptToThrow :: (Exception e, MonadIO m) => ExceptT e m b -> m b
exceptToThrow action = runExceptT action >>= either (liftIO . throwIO) pure


-- | Connect to a Neo4j server, perform the BOLT handshake and authentication.
--
-- Opens the socket with 'C.connect', then runs 'handshake'. If the handshake
-- fails for any reason, the socket is closed before the failure is
-- re-thrown, so a failed connect never leaks a socket. On success the
-- connection starts in the 'Ready' state, with its last-activity timestamp
-- set to the current monotonic clock reading.
connect :: (MonadIO m, HasCallStack) => ValidatedConfig -> m Connection
connect cfg@ValidatedConfig{..} = do
  (rawConn, tout) <- C.connect use_tls host port timeout
  (server_version, agent, conn_id, telem, idleTimeout) <- liftIO $
    catch (exceptToThrow (handshake rawConn tout cfg)) $ \(err :: SomeException) -> do
      -- Best-effort cleanup: a failure while closing must not mask the
      -- original handshake error.
      _ <- try @SomeException (C.close rawConn tout)
      throwIO err
  stateRef <- liftIO $ newIORef Ready
  now <- liftIO getMonotonicTimeNSec
  actRef <- liftIO $ newIORef now
  pure $ Connection { rawConnection = rawConn, timeout_milliseconds = tout
                    , version = server_version, server_state = stateRef
                    , server_agent = agent, connection_id = conn_id, lastActivity = actRef
                    , telemetry_enabled = telem, serverIdleTimeout = idleTimeout
                    , queryLogger = queryLogger
                    , notificationHandler = notificationHandler }


-- | Send GOODBYE (when 'isNewVersion' holds for the negotiated version) and
-- close the underlying socket.
--
-- The socket is always released and the state set to 'Disconnected', even
-- when sending GOODBYE fails (e.g. the server already dropped the link).
-- A GOODBYE failure is re-thrown only after that cleanup has run.
close :: (MonadIO m, HasCallStack) => Connection -> m ()
close Connection{rawConnection, timeout_milliseconds, version, server_state} = liftIO $ do
  goodbye <- try @SomeException $
    whenMaybe (isNewVersion version) $
      exceptToThrow (sendRequest rawConnection timeout_milliseconds RGoodbye)
  writeIORef server_state Disconnected
  C.close rawConnection timeout_milliseconds
  -- Surface a GOODBYE failure only after the socket has been released.
  either throwIO (const (pure ())) goodbye


-- | Send RESET and wait for the reply.
--
-- On SUCCESS the connection state becomes 'Ready'. Any other reply throws
-- 'ResetFailed' as an exception, and the state is left unchanged.
reset :: (MonadIO m, HasCallStack) => Connection -> m ()
reset Connection{rawConnection = sock, timeout_milliseconds = ms, server_state = stateRef} =
  exceptToThrow $ do
    sendRequest sock ms RReset
    receiveResponse sock ms >>= \case
      RSuccess _ -> liftIO $ writeIORef stateRef Ready
      _          -> throwError ResetFailed


-- | Check whether a connection is alive by sending RESET and expecting
-- SUCCESS (the same exchange as 'reset'). On success the state is set to
-- 'Ready'.
--
-- Returns 'True' when healthy and 'False' for any synchronous failure
-- (non-SUCCESS reply, decode error, I/O error).
--
-- Asynchronous exceptions (e.g. from 'Control.Concurrent.killThread' or
-- 'System.Timeout.timeout') are re-thrown rather than swallowed. Turning
-- them into 'False' would keep a cancelled thread running.
ping :: MonadIO m => Connection -> m Bool
ping conn = liftIO $ do
  outcome <- try @SomeException (reset conn)
  case outcome of
    Right () -> pure True
    Left err
      | isAsync err -> throwIO err
      | otherwise   -> pure False
  where
    isAsync :: SomeException -> Bool
    isAsync err = case fromException err :: Maybe SomeAsyncException of
      Just _  -> True
      Nothing -> False


-- | Read the current server state and throw @'InvalidState' state action@
-- unless it is one of the allowed states. The 'T.Text' argument names the
-- attempted operation and is carried in the error.
requireState :: MonadPipe m => Connection -> [ServerState] -> T.Text -> m ()
requireState Connection{server_state = stateRef} allowed action = do
  current <- liftIO (readIORef stateRef)
  when (current `notElem` allowed) $
    throwError (InvalidState current action)


-- | Overwrite the connection's tracked server state.
setState :: MonadIO m => Connection -> ServerState -> m ()
setState Connection{server_state = stateRef} = liftIO . writeIORef stateRef


-- | Read the connection's tracked server state.
getState :: MonadIO m => Connection -> m ServerState
getState conn = case conn of
  Connection{server_state = stateRef} -> liftIO $ readIORef stateRef


-- | Perform the Bolt handshake on a freshly opened socket.
--
-- Sends the magic preamble and the configured version proposals, then
-- checks the server's choice with 'versionAccepted'. After that it
-- authenticates:
--
-- * when 'supportsLogonLogoff' holds for the negotiated version: a HELLO
--   without credentials, followed by LOGON carrying the auth scheme;
-- * otherwise: a single HELLO carrying the credentials, requesting
--   @patch_bolt@ on Bolt 4.x.
--
-- Returns @(negotiated version, server agent, connection id,
-- telemetry enabled, recv-timeout hint in seconds)@. A version the client
-- did not propose raises 'UnsupportedServerVersion'. Any non-SUCCESS reply
-- to HELLO or LOGON raises 'AuthentificationFailed'.
handshake :: MonadPipe m => NC.Connection -> Int -> ValidatedConfig -> m (Word32, T.Text, T.Text, Bool, Maybe Int)
handshake rawConn tout ValidatedConfig{versions, user_agent, scheme = authScheme, routing} = do
  -- https://neo4j.com/docs/bolt/current/bolt/handshake/#_bolt_identification
  C.send rawConn tout (encodeStrict (0x6060B017 :: Word32))
  C.send rawConn tout $ P.runPut $ mapM_ putBE versions
  server_version :: Word32 <- C.receiveBinary rawConn tout 4
  unless (versionAccepted server_version versions) $
    throwError $ UnsupportedServerVersion server_version
  if supportsLogonLogoff server_version
    then do
      -- Bolt 5.1+: HELLO without credentials; they follow in LOGON.
      sendPs rawConn tout $ PsStructure 0x01 $ V.singleton $ helloDict server_version
      meta <- expectSuccess
      sendRequest rawConn tout $ RLogon $ Logon authScheme
      _ <- expectSuccess
      pure ( server_version
           , metaText "server" meta
           , metaText "connection_id" meta
           , parseTelemetryHint meta
           , parseIdleTimeoutHint meta
           )
    else do
      -- Bolt 5.0 and below: HELLO with credentials.
      -- BOLT 4.x: include patch_bolt to get UTC-based DateTime (0x49/0x69)
      -- instead of the legacy format.
      let needsPatch = versionMajor server_version == 4
      sendRequest rawConn tout $ RHello $ Hello user_agent authScheme routing needsPatch
      meta <- expectSuccess
      pure ( server_version
           , metaText "server" meta
           , metaText "connection_id" meta
           , False
           , parseIdleTimeoutHint meta
           )
  where
    -- Read one reply; return SUCCESS metadata, fail on anything else.
    expectSuccess = receiveResponse rawConn tout >>= \case
      RSuccess meta -> pure meta
      _             -> throwError AuthentificationFailed

    -- Credential-less HELLO extra dictionary for LOGON-capable servers.
    helloDict :: Word32 -> Ps
    helloDict negotiated = PsDictionary $ H.fromList $
      [("user_agent", PS.toPs user_agent)]
        <> [ ("bolt_agent", PsDictionary $ H.singleton "product" (PS.toPs user_agent))
           | needsBoltAgent negotiated
           ]
        <> routingEntry

    -- bolt_agent is mandatory from Bolt 5.3 onwards. That also covers every
    -- later major version, whatever its minor number.
    needsBoltAgent :: Word32 -> Bool
    needsBoltAgent v = versionMajor v > 5 || minorOf v >= 3

    minorOf :: Word32 -> Word8
    minorOf v = fromIntegral ((v `shiftR` 8) .&. 0xFF)

    -- The optional "routing" entry of the HELLO dictionary.
    routingEntry :: [(T.Text, Ps)]
    routingEntry = case routing of
      NoRouting -> []
      Routing   -> [("routing", PsDictionary H.empty)]
      RoutingSpec address urlQueryParams ->
        [ ( "routing"
          , PsDictionary $ H.fromList $
              ("address", PS.toPs address)
                : [ (k, PS.toPs v) | (k, v) <- H.toList urlQueryParams ]
          )
        ]

    -- A string field of the SUCCESS metadata, or "" if absent or non-string.
    metaText :: T.Text -> H.HashMap T.Text Ps -> T.Text
    metaText key meta = case H.lookup key meta of
      Just (PsString t) -> t
      _                 -> ""

    -- Look up @key@ inside the "hints" dictionary of the SUCCESS metadata.
    hint :: T.Text -> H.HashMap T.Text Ps -> Maybe Ps
    hint key meta = case H.lookup "hints" meta of
      Just (PsDictionary hints) -> H.lookup key hints
      _                         -> Nothing

    -- Parse hints.telemetry.enabled from HELLO SUCCESS metadata.
    parseTelemetryHint :: H.HashMap T.Text Ps -> Bool
    parseTelemetryHint meta = case hint "telemetry.enabled" meta of
      Just (PsBoolean b) -> b
      _                  -> False

    -- Parse hints.connection.recv_timeout_seconds from HELLO SUCCESS metadata.
    parseIdleTimeoutHint :: H.HashMap T.Text Ps -> Maybe Int
    parseIdleTimeoutHint meta = case hint "connection.recv_timeout_seconds" meta of
      Just (PsInteger n) -> fromIntegral <$> (fromPSInteger n :: Maybe Int64)
      _                  -> Nothing


-- | Check whether the version chosen by the server matches one of the
-- client's version proposals.
--
-- Client proposals are compact specs: word32 = [0, range, minor, major]
-- (big-endian bytes). The server responds with [0, 0, minor, major], or all
-- zeros when it accepts none of the proposals. A server version matches a
-- spec if the majors are equal and the server minor lies in
-- [spec_minor - range, spec_minor].
--
-- * A reply with major version 0 (in particular the all-zero "none
--   accepted" reply) never matches, even against zero-padded proposal
--   slots.
-- * A range larger than the spec minor reaches down to minor 0 instead of
--   wrapping around in 'Word8' arithmetic.
versionAccepted :: Word32 -> [Word32] -> Bool
versionAccepted serverVer specs = serverMajor /= 0 && any matches specs
  where
    byte :: Int -> Word32 -> Word8
    byte i w = fromIntegral ((w `shiftR` (8 * i)) .&. 0xFF)

    serverMajor, serverMinor :: Word8
    serverMajor = byte 0 serverVer
    serverMinor = byte 1 serverVer

    matches :: Word32 -> Bool
    matches spec =
      let specMajor = byte 0 spec
          specMinor = byte 1 spec
          specRange = byte 2 spec
      in  serverMajor == specMajor
            && serverMinor <= specMinor
            -- No underflow: serverMinor <= specMinor was checked just above.
            && specMinor - serverMinor <= specRange


-- | Receive one chunked BOLT message from the wire and decode it.
--
-- Each chunk is a big-endian 'Word16' size header followed by that many
-- bytes. A zero header terminates the message. Zero headers that arrive
-- before any data are NOOP keep-alive chunks (Bolt 4.1+) and are skipped;
-- a message itself is never empty. A payload that fails to decode raises
-- 'CannotReadResponse'.
receiveResponse :: MonadPipe m => NC.Connection -> Int -> m Response
receiveResponse rawConn tout = do
  let nextHeader = C.receiveBinary rawConn tout 2
      readChunk (size :: Word16) =
        BSL.fromStrict <$> C.receiveBytestring rawConn tout (fromIntegral size)
      skipNoops = do
        size :: Word16 <- nextHeader
        if size == 0 then skipNoops else pure size
  firstSize  <- skipNoops
  firstChunk <- readChunk firstSize
  -- Accumulate as a lazy ByteString (appends copy no bytes) and flatten
  -- once, instead of re-copying a growing strict buffer for every chunk.
  restChunks <- whileJustM $ do
    size :: Word16 <- nextHeader
    if size == 0 then pure Nothing else Just <$> readChunk size
  case PS.unpack' (BSL.toStrict (firstChunk <> restChunks)) of
    Left e         -> throwError $ CannotReadResponse e
    Right response -> pure response


-- | Convert a 'Request' to PackStream and send it as one chunked message.
sendRequest :: MonadPipe m => NC.Connection -> Int -> Request -> m ()
sendRequest rawConn tout = sendPs rawConn tout . PS.toPs

-- | Pack a PackStream value and send it as one chunked BOLT message.
--
-- The packed bytes are split into chunks of at most 65535 bytes. Each chunk
-- is prefixed with its big-endian 'Word16' length, a @00 00@ end marker is
-- appended, and the whole frame goes out in a single 'C.send'.
sendPs :: MonadPipe m => NC.Connection -> Int -> Ps -> m ()
sendPs rawConn tout ps =
  C.send rawConn tout . BSL.toStrict $
    foldMap frame (chunksOfBSL 65_535 (PS.pack ps)) <> endMarker
  where
    frame :: BSL.ByteString -> BSL.ByteString
    frame chunk = runPutLazy (putBE (fromIntegral (BSL.length chunk) :: Word16)) <> chunk

    endMarker :: BSL.ByteString
    endMarker = BSL.pack [0, 0]


-- | Send a request over a 'Connection' using its configured timeout.
flush :: MonadPipe m => Connection -> Request -> m ()
flush Connection{rawConnection = sock, timeout_milliseconds = ms} = sendRequest sock ms


-- | Read and decode the next response on a 'Connection' using its
-- configured timeout.
fetch :: MonadPipe m => Connection -> m Response
fetch Connection{rawConnection = sock, timeout_milliseconds = ms} = receiveResponse sock ms


-- | Send LOGOFF and wait for the reply.
--
-- On SUCCESS the connection state becomes 'Authentication'. Any other reply
-- throws 'WrongMessageFormat' as an exception. Only valid on Bolt 5.1+
-- connections in the Ready state.
logoff :: (MonadIO m, HasCallStack) => Connection -> m ()
logoff Connection{rawConnection = sock, timeout_milliseconds = ms, server_state = stateRef} =
  exceptToThrow $ do
    sendRequest sock ms RLogoff
    receiveResponse sock ms >>= \case
      RSuccess _ -> liftIO $ writeIORef stateRef Authentication
      _          -> throwError $ WrongMessageFormat "Expected SUCCESS after LOGOFF"


-- | Authenticate an already-connected session by sending LOGON with the given
-- credentials scheme (Bolt 5.1+, connection expected in Authentication state).
--
-- A SUCCESS reply moves the local state to 'Ready'. Any other reply is
-- reported as 'AuthentificationFailed'.
logon :: (MonadIO m, HasCallStack) => Connection -> Scheme -> m ()
logon Connection{rawConnection, timeout_milliseconds, server_state} authScheme =
  exceptToThrow $ do
    sendRequest rawConnection timeout_milliseconds (RLogon (Logon authScheme))
    reply <- receiveResponse rawConnection timeout_milliseconds
    case reply of
      RSuccess _ -> liftIO (writeIORef server_state Ready)
      _          -> throwError AuthentificationFailed


-- | Send a TELEMETRY message if the server supports it.
-- No-op unless the connection's @telemetry_enabled@ flag is set and
-- 'supportsTelemetry' accepts the negotiated version.
--
-- Replies are handled as follows:
--
-- * SUCCESS: returns normally.
-- * FAILURE: the local state is set to 'Failed', the connection is 'reset',
--   and 'ResponseErrorFailure' with the server's code and message is thrown.
--   This matches the handling in 'beginTx', 'commitTx' and 'rollbackTx'.
-- * anything else: 'WrongMessageFormat' is thrown.
sendTelemetry :: (MonadIO m, HasCallStack) => Connection -> TelemetryApi -> m ()
sendTelemetry conn@Connection{ telemetry_enabled, version
                             , rawConnection, timeout_milliseconds
                             , server_state } api
  | telemetry_enabled && supportsTelemetry version = exceptToThrow $ do
      sendRequest rawConnection timeout_milliseconds $ RTelemetry api
      response <- receiveResponse rawConnection timeout_milliseconds
      case response of
        RSuccess _ -> pure ()
        RFailure Failure{code, message} -> do
          -- Do not leave the local state untouched after a FAILURE: record
          -- it, recover the connection, then report the server's error.
          liftIO $ writeIORef server_state Failed
          liftIO $ reset conn
          throwError $ ResponseErrorFailure code message
        _ -> throwError $ WrongMessageFormat "Expected SUCCESS after TELEMETRY"
  | otherwise = pure ()


-- ---------------------------------------------------------------------------
-- Connection accessors
-- ---------------------------------------------------------------------------

-- | The BOLT protocol version stored on the connection (the @version@ field).
connectionVersion :: Connection -> Word32
connectionVersion Connection{version = negotiated} = negotiated

-- | The server agent string recorded on the connection (@server_agent@).
connectionAgent :: Connection -> T.Text
connectionAgent Connection{server_agent = agent} = agent

-- | The connection identifier recorded on the connection (@connection_id@).
connectionId :: Connection -> T.Text
connectionId Connection{connection_id = cid} = cid

-- | The server-advertised idle timeout in seconds, or 'Nothing' when none was
-- recorded (the @serverIdleTimeout@ field).
connectionServerIdleTimeout :: Connection -> Maybe Int
connectionServerIdleTimeout Connection{serverIdleTimeout = idle} = idle

-- | The connection's @telemetry_enabled@ flag. 'sendTelemetry' only sends
-- when this is 'True' and the negotiated version supports telemetry.
connectionTelemetryEnabled :: Connection -> Bool
connectionTelemetryEnabled Connection{telemetry_enabled = enabled} = enabled

-- | Stamp the connection's @lastActivity@ ref with the current monotonic
-- clock reading in nanoseconds ('getMonotonicTimeNSec').
touchConnection :: MonadIO m => Connection -> m ()
touchConnection Connection{lastActivity = stampRef} =
  liftIO (getMonotonicTimeNSec >>= writeIORef stampRef)

-- | Read the monotonic timestamp (nanoseconds) last written to the
-- connection's @lastActivity@ ref, e.g. by 'touchConnection'.
connectionLastActivity :: MonadIO m => Connection -> m Word64
connectionLastActivity Connection{lastActivity = stampRef} =
  liftIO (readIORef stampRef)


-- ---------------------------------------------------------------------------
-- Transaction primitives (plain IO)
-- ---------------------------------------------------------------------------

-- | Open an explicit transaction by sending BEGIN with the given parameters
-- (bookmarks, access mode, ...).
--
-- Requires the connection to be in 'Ready'. Replies are handled as follows:
--
-- * SUCCESS: the local state becomes 'TXready'.
-- * FAILURE: the state is set to 'Failed', the connection is 'reset', and
--   'ResponseErrorFailure' is thrown.
-- * anything else: the connection is 'reset' and 'WrongMessageFormat' is
--   thrown.
beginTx :: HasCallStack => Connection -> Begin -> IO ()
beginTx conn@Connection{rawConnection, timeout_milliseconds, server_state} params =
  exceptToThrow $ do
    requireState conn [Ready] "BEGIN"
    sendRequest rawConnection timeout_milliseconds (RBegin params)
    reply <- receiveResponse rawConnection timeout_milliseconds
    case reply of
      RSuccess _ ->
        liftIO (writeIORef server_state TXready)
      RFailure Failure{code, message} -> do
        liftIO (writeIORef server_state Failed)
        liftIO (reset conn)
        throwError (ResponseErrorFailure code message)
      _ -> do
        liftIO (reset conn)
        throwError (WrongMessageFormat "Unexpected response to BEGIN")

-- | Send COMMIT for the open transaction. Returns the bookmark that
-- 'extractBookmark' finds in the SUCCESS metadata, if any.
--
-- Requires the connection to be in 'TXready'. Replies are handled as follows:
--
-- * SUCCESS: the local state returns to 'Ready'.
-- * FAILURE: the state is set to 'Failed', the connection is 'reset', and
--   'ResponseErrorFailure' is thrown.
-- * anything else: the connection is 'reset' and 'WrongMessageFormat' is
--   thrown.
commitTx :: HasCallStack => Connection -> IO (Maybe T.Text)
commitTx conn@Connection{rawConnection, timeout_milliseconds, server_state} =
  exceptToThrow $ do
    requireState conn [TXready] "COMMIT"
    sendRequest rawConnection timeout_milliseconds RCommit
    reply <- receiveResponse rawConnection timeout_milliseconds
    case reply of
      RSuccess meta ->
        extractBookmark meta <$ liftIO (writeIORef server_state Ready)
      RFailure Failure{code, message} -> do
        liftIO (writeIORef server_state Failed)
        liftIO (reset conn)
        throwError (ResponseErrorFailure code message)
      _ -> do
        liftIO (reset conn)
        throwError (WrongMessageFormat "Unexpected response to COMMIT")

-- | Send ROLLBACK for the open transaction.
--
-- Requires the connection to be in 'TXready'. Replies are handled as follows:
--
-- * SUCCESS: the local state returns to 'Ready'.
-- * FAILURE: the state is set to 'Failed', the connection is 'reset', and
--   'ResponseErrorFailure' is thrown.
-- * anything else: the connection is 'reset' and 'WrongMessageFormat' is
--   thrown.
rollbackTx :: HasCallStack => Connection -> IO ()
rollbackTx conn@Connection{rawConnection, timeout_milliseconds, server_state} =
  exceptToThrow $ do
    requireState conn [TXready] "ROLLBACK"
    sendRequest rawConnection timeout_milliseconds RRollback
    reply <- receiveResponse rawConnection timeout_milliseconds
    case reply of
      RSuccess _ ->
        liftIO (writeIORef server_state Ready)
      RFailure Failure{code, message} -> do
        liftIO (writeIORef server_state Failed)
        liftIO (reset conn)
        throwError (ResponseErrorFailure code message)
      _ -> do
        liftIO (reset conn)
        throwError (WrongMessageFormat "Unexpected response to ROLLBACK")


-- | Try to roll back the current transaction, ignoring errors.
-- Used as cleanup in onException handlers.
--
-- Behaviour depends on the current state:
--
-- * 'TXready': roll back with 'rollbackTx'.
-- * 'TXstreaming': 'rollbackTx' only accepts 'TXready', so calling it here
--   would only raise (and then swallow) 'InvalidState', leaving the
--   transaction open. Instead the connection is 'reset'. This presumably
--   sends RESET, which under the Bolt protocol discards the open
--   transaction — confirm against 'reset'.
-- * any other state: there is no transaction to undo, so this is a no-op.
--
-- Only exceptions of type 'Error' are ignored. Other exceptions still
-- propagate, for example those raised by the underlying network connection.
tryRollback :: Connection -> IO ()
tryRollback conn = do
  st <- getState conn
  case st of
    TXready     -> ignoreErrors (rollbackTx conn)
    TXstreaming -> ignoreErrors (reset conn)
    _           -> pure ()
  where
    ignoreErrors :: IO () -> IO ()
    ignoreErrors action = action `catch` handler

    handler :: Error -> IO ()
    handler _ = pure ()


-- ---------------------------------------------------------------------------
-- Plain IO wire helpers
-- ---------------------------------------------------------------------------

-- | Plain-IO guard on the connection state: throws @InvalidState current
-- action@ via 'throwIO' when the current state is not one of @allowed@.
requireStateIO :: HasCallStack => Connection -> [ServerState] -> T.Text -> IO ()
requireStateIO conn allowed action = do
  current <- getState conn
  when (current `notElem` allowed) $
    throwIO (InvalidState current action)

-- | Plain-IO wrapper around 'sendRequest': writes @req@ to the underlying
-- connection. Any 'Error' is rethrown as an exception by 'exceptToThrow'.
flushIO :: HasCallStack => Connection -> Request -> IO ()
flushIO Connection{rawConnection, timeout_milliseconds} req =
  exceptToThrow (sendRequest rawConnection timeout_milliseconds req)

-- | Plain-IO wrapper around 'receiveResponse': reads the next response from
-- the underlying connection. Any 'Error' is rethrown as an exception by
-- 'exceptToThrow'.
fetchIO :: HasCallStack => Connection -> IO Response
fetchIO Connection{rawConnection, timeout_milliseconds} =
  exceptToThrow (receiveResponse rawConnection timeout_milliseconds)