-- | Cluster-aware routing: automatic server discovery and failover for Neo4j causal clusters.
module Database.Bolty.Routing
  ( -- * Routing table
    getRoutingTable
  , RoutingTable(..)
    -- * Access mode
  , AccessMode(..)
    -- * Routing pool
  , RoutingPool(..)
  , RoutingPoolConfig(..)
  , defaultRoutingPoolConfig
  , createRoutingPool
  , destroyRoutingPool
  , withRoutingConnection
  , acquireRoutingConnection
  , withRoutingTransaction
  , invalidateRoutingTable
    -- * Internals (exported for testing)
  , parseAddress
  ) where

import           Control.Concurrent             (threadDelay)
import           Control.Concurrent.MVar        (MVar, newMVar, withMVar)
import           Control.Exception              (SomeAsyncException, SomeException,
                                                 bracket, fromException, onException,
                                                 throwIO, try)
import           Control.Monad                  (when)
import           Data.HashMap.Lazy              (HashMap)
import qualified Data.HashMap.Lazy              as H
import           Data.IORef                     (IORef, newIORef, readIORef, writeIORef, atomicModifyIORef')
import           Data.Kind                      (Type)
import           Data.List                      (isInfixOf)
import           Data.Text                      (Text)
import qualified Data.Text                      as T
import qualified Data.Vector                    as V
import           Data.Word                      (Word16, Word64)
import           GHC.Clock                      (getMonotonicTimeNSec)
import           GHC.Stack                      (HasCallStack)

import           Database.Bolty.Connection.Type
import qualified Database.Bolty.Connection.Pipe as P
import           Database.Bolty.Message.Request (Request(..), Route(..), RouteExtra(..), Begin(..))
import           Database.Bolty.Message.Response (Response(..), Failure(..), RoutingTable(..), parseRoutingTable)
import           Database.Bolty.Pool


-- | Which kind of cluster member a routed operation should be sent to.
--
-- 'ReadAccess' picks from the routing table's readers, 'WriteAccess' from
-- its writers.
type AccessMode :: Type
data AccessMode
  = ReadAccess   -- ^ Route to a reader.
  | WriteAccess  -- ^ Route to a writer.
  deriving stock (Show, Eq)


-- | Send a ROUTE request on @conn@ and return the routing table the server
-- answers with.
--
-- The connection must be in the 'Ready' state. @dbName@ becomes the @db@
-- field of the request; no bookmarks, routing context or impersonated user
-- are sent.
--
-- Throws:
--
-- * 'RoutingTableError' if the SUCCESS metadata cannot be parsed;
-- * 'ResponseErrorFailure' on FAILURE, after marking the connection
--   'Failed' and resetting it;
-- * 'WrongMessageFormat' on any other reply, after resetting the connection.
getRoutingTable :: HasCallStack => Connection -> Maybe Text -> IO RoutingTable
getRoutingTable conn dbName = do
  P.requireStateIO conn [Ready] "ROUTE"
  P.flushIO conn (RRoute routeRequest)
  reply <- P.fetchIO conn
  handleReply reply
  where
    routeRequest :: Route
    routeRequest = Route
      { routing   = H.empty
      , bookmarks = V.empty
      , extra     = RouteExtra { db = dbName, imp_user = Nothing }
      }

    handleReply :: Response -> IO RoutingTable
    handleReply (RSuccess meta) =
      either (throwIO . RoutingTableError) pure (parseRoutingTable meta)
    handleReply (RFailure Failure{code, message}) = do
      P.setState conn Failed
      P.reset conn
      throwIO (ResponseErrorFailure code message)
    handleReply _ = do
      P.reset conn
      throwIO (WrongMessageFormat "Unexpected response to ROUTE")


-- | User-facing settings for 'createRoutingPool'.
type RoutingPoolConfig :: Type
data RoutingPoolConfig = RoutingPoolConfig
  { poolConfig    :: PoolConfig
    -- ^ Pool settings applied to each per-server pool.
  , routingDb     :: Maybe Text
    -- ^ Database whose routing table is requested ('Nothing' means the
    -- server's default database).
  , refreshBuffer :: Int
    -- ^ How many seconds before the table's TTL runs out it is already
    -- treated as stale and proactively refreshed.
  }

-- | Sensible defaults: 'defaultPoolConfig' for every server, routing for the
-- server's default database, and a 10-second refresh buffer.
defaultRoutingPoolConfig :: RoutingPoolConfig
defaultRoutingPoolConfig = RoutingPoolConfig
  { routingDb     = Nothing
  , refreshBuffer = 10
  , poolConfig    = defaultPoolConfig
  }


-- | A routing table paired with the instant, on the 'getMonotonicTimeNSec'
-- clock, at which it expires.
type CachedRoutingTable :: Type
data CachedRoutingTable = CachedRoutingTable
  { cachedTable :: !RoutingTable
    -- ^ The table returned by the ROUTE request.
  , expiresAtNs :: !Word64
    -- ^ Expiry instant in nanoseconds of monotonic time.
  }

-- | A routing-aware connection pool. It holds a cached routing table and one
-- 'BoltPool' per server address; the per-address pools are created on first
-- use and chosen according to the requested 'AccessMode'.
type RoutingPool :: Type
data RoutingPool = RoutingPool
  { rpConfig      :: !ValidatedConfig
    -- ^ Seed connection settings (routing enabled). Per-address settings are
    -- derived from it by substituting host and port.
  , rpPoolConfig  :: !PoolConfig
    -- ^ Settings used for every per-address pool.
  , rpRoutingDb   :: !(Maybe Text)
    -- ^ Database to request routing information for.
  , rpRefreshBuf  :: !Int
    -- ^ Seconds before expiry at which the table is considered stale.
  , rpCacheRef    :: !(IORef (Maybe CachedRoutingTable))
    -- ^ Most recently fetched routing table, if any.
  , rpPoolsRef    :: !(IORef (HashMap Text BoltPool))
    -- ^ Per-address pools keyed by address text.
  , rpRefreshLock :: !(MVar ())
    -- ^ Serialises routing-table refreshes.
  , rpCounter     :: !(IORef Int)
    -- ^ Shared round-robin counter for address selection.
  }


-- | Build a routing-aware pool from a seed configuration.
--
-- Routing is switched on in the seed config, and the initial routing table
-- is fetched before this returns, so failure to reach a router surfaces
-- here as an exception. Per-address pools are not opened yet; they are
-- created lazily when an address is first used.
createRoutingPool :: ValidatedConfig -> RoutingPoolConfig -> IO RoutingPool
createRoutingPool cfg RoutingPoolConfig{poolConfig, routingDb, refreshBuffer} = do
  -- Fields in declaration order: config, pool config, db, buffer,
  -- cache ref, pools ref, refresh lock, round-robin counter.
  pool <- RoutingPool (setRouting cfg Routing) poolConfig routingDb refreshBuffer
            <$> newIORef Nothing
            <*> newIORef H.empty
            <*> newMVar ()
            <*> newIORef 0
  -- Bootstrap the routing table from the seed address.
  pool <$ refreshRoutingTable pool


-- | Destroy every per-address pool owned by the routing pool.
--
-- The pool map is detached and replaced with an empty one in a single
-- atomic step, and the detached pools are destroyed afterwards. Reading and
-- clearing in two separate steps could drop a pool that a concurrent
-- 'getOrCreatePool' inserted in between; such a pool would never be
-- destroyed.
destroyRoutingPool :: RoutingPool -> IO ()
destroyRoutingPool RoutingPool{rpPoolsRef} = do
  detached <- atomicModifyIORef' rpPoolsRef $ \current -> (H.empty, current)
  mapM_ destroyPool detached


-- | Mark the cached routing table as expired, forcing a refresh on the next
-- operation. Use this when a routing error (e.g. NotALeader) indicates the
-- table is stale.
--
-- The table itself is kept and only its expiry is reset to 0. Because of
-- this, the refresh can still contact the routers it lists before falling
-- back to the seed address. Clearing the cache to 'Nothing' would discard
-- those routers, so the refresh could succeed only if the seed were
-- reachable. Does nothing if no table has been cached yet.
invalidateRoutingTable :: RoutingPool -> IO ()
invalidateRoutingTable RoutingPool{rpCacheRef} =
  atomicModifyIORef' rpCacheRef $ \cached -> (expire <$> cached, ())
  where
    -- Expiry 0 can never satisfy the "now + buffer < expiry" freshness test.
    expire CachedRoutingTable{cachedTable} =
      CachedRoutingTable{cachedTable, expiresAtNs = 0}


-- | Choose an address round-robin, using a counter shared by every caller of
-- this pool. @addrs@ must be non-empty; all callers check this beforehand.
roundRobin :: RoutingPool -> V.Vector Text -> IO Text
roundRobin RoutingPool{rpCounter} addrs =
  pickAt <$> atomicModifyIORef' rpCounter (\ticket -> (ticket + 1, ticket))
  where
    -- 'mod' takes the sign of the divisor, so the index stays non-negative
    -- even after the counter wraps around.
    pickAt ticket = addrs V.! (ticket `mod` V.length addrs)


-- | Run @action@ on a pooled connection to a server selected by access mode.
--
-- Addresses come from the routing table's readers ('ReadAccess') or writers
-- ('WriteAccess') and are tried round-robin. When a server is unavailable,
-- another address is tried, with at most as many attempts as there are
-- addresses. Throws 'RoutingTableError' when the table lists no server for
-- the requested mode.
withRoutingConnection :: HasCallStack => RoutingPool -> AccessMode -> (Connection -> IO a) -> IO a
withRoutingConnection rp mode action = do
  table <- getOrRefreshTable rp
  let targets = serversFor mode table
  if V.null targets
    then throwIO $ RoutingTableError $ "No servers available for " <> T.pack (show mode)
    else tryAddresses rp targets (V.length targets) Nothing action
  where
    serversFor ReadAccess  = readers
    serversFor WriteAccess = writers


-- | Check out a connection to a server selected by access mode.
--
-- The returned 'CheckedOutConnection' must be released by the caller.
-- Addresses come from the routing table's readers ('ReadAccess') or writers
-- ('WriteAccess') and are tried round-robin, failing over when a server is
-- unavailable. Throws 'RoutingTableError' when the table lists no server
-- for the requested mode.
acquireRoutingConnection :: HasCallStack => RoutingPool -> AccessMode -> IO CheckedOutConnection
acquireRoutingConnection rp mode = do
  table <- getOrRefreshTable rp
  let targets = serversFor mode table
  if V.null targets
    then throwIO $ RoutingTableError $ "No servers available for " <> T.pack (show mode)
    else tryAcquireAddresses rp targets (V.length targets) Nothing
  where
    serversFor ReadAccess  = readers
    serversFor WriteAccess = writers


-- | Check out a connection to an address picked round-robin from @addrs@,
-- failing over to another pick while the server is unavailable.
--
-- The 'Int' is the number of attempts left. The 'Maybe SomeException' holds
-- the previous failure so it can be re-thrown once attempts run out; if
-- there was none, 'RoutingTableError' \"All servers unavailable\" is thrown.
-- Only failures accepted by 'isServerUnavailable' trigger a retry, and only
-- while more than one attempt remains. Anything else is re-thrown at once.
tryAcquireAddresses :: HasCallStack
                    => RoutingPool -> V.Vector Text -> Int -> Maybe SomeException -> IO CheckedOutConnection
tryAcquireAddresses _ _ 0 lastFailure =
  maybe (throwIO $ RoutingTableError "All servers unavailable") throwIO lastFailure
tryAcquireAddresses rp addrs attemptsLeft _ = do
  target  <- roundRobin rp addrs
  pool    <- getOrCreatePool rp target
  outcome <- try (acquireConnection pool)
  either failOver pure outcome
  where
    failOver :: SomeException -> IO CheckedOutConnection
    failOver err
      | attemptsLeft > 1 && isServerUnavailable err =
          tryAcquireAddresses rp addrs (attemptsLeft - 1) (Just err)
      | otherwise = throwIO err


-- | Run @action@ on a connection to an address picked round-robin from
-- @addrs@, failing over to another pick while the server is unavailable.
--
-- The 'Int' is the number of attempts left. The 'Maybe SomeException' holds
-- the previous failure so it can be re-thrown once attempts run out; if
-- there was none, 'RoutingTableError' \"All servers unavailable\" is thrown.
-- Only failures accepted by 'isServerUnavailable' trigger a retry, and only
-- while more than one attempt remains. Anything else is re-thrown at once.
--
-- NOTE(review): the 'try' also wraps @action@ itself. An exception thrown
-- by the action that 'isServerUnavailable' accepts therefore re-runs the
-- action on another server. Confirm this is intended.
tryAddresses :: HasCallStack
             => RoutingPool -> V.Vector Text -> Int -> Maybe SomeException -> (Connection -> IO a) -> IO a
tryAddresses _ _ 0 lastFailure _ =
  maybe (throwIO $ RoutingTableError "All servers unavailable") throwIO lastFailure
tryAddresses rp addrs attemptsLeft _ action = do
  target <- roundRobin rp addrs
  pool   <- getOrCreatePool rp target
  try (withConnection pool action) >>= either failOver pure
  where
    failOver err
      | attemptsLeft > 1 && isServerUnavailable err =
          tryAddresses rp addrs (attemptsLeft - 1) (Just err) action
      | otherwise = throwIO err


-- | Heuristic: does this exception mean the server could not be used, as
-- opposed to a query-level error?
--
-- It returns 'True' in two cases:
--
-- * the exception is a 'NonboltyError', which wraps non-Bolt exceptions such
--   as I/O failures (connection refused, broken pipe, ...);
-- * its 'show' text contains \"no healthy connection\", as reported by an
--   exhausted pool.
isServerUnavailable :: SomeException -> Bool
isServerUnavailable e =
  wrapsForeignFailure (fromException e) || "no healthy connection" `isInfixOf` show e
  where
    wrapsForeignFailure :: Maybe Error -> Bool
    wrapsForeignFailure (Just (NonboltyError _)) = True
    wrapsForeignFailure _                        = False


-- | Run @action@ in an explicit transaction on a server selected by access
-- mode, retrying according to the pool's 'RetryConfig'.
--
-- Each attempt works as follows:
--
-- * It re-reads the routing table, refreshing it when stale.
-- * It picks an address round-robin from the readers ('ReadAccess') or
--   writers ('WriteAccess').
-- * It sends BEGIN with mode @'r'@ or @'w'@, runs @action@, and then sends
--   COMMIT.
-- * If @action@ throws, the transaction is rolled back (best effort) before
--   the exception propagates.
--
-- Failed attempts are handled as follows:
--
-- * Errors satisfying 'isTransient' are retried after a delay.
-- * Errors satisfying 'isRoutingError' first invalidate the routing table,
--   then are retried after a delay.
-- * Anything else propagates immediately.
--
-- The delay starts at 'initialDelay' and doubles each retry, capped at
-- 'maxDelay'; values are passed to 'threadDelay' (microseconds). After
-- 'maxRetries' retries, one final attempt runs and its exceptions propagate.
-- Note that @action@ may therefore run more than once.
withRoutingTransaction :: HasCallStack => RoutingPool -> AccessMode -> (Connection -> IO a) -> IO a
withRoutingTransaction rp mode action =
  retrying (maxRetries policy) (initialDelay policy)
  where
    policy   = retryConfig (rpPoolConfig rp)
    delayCap = maxDelay policy

    -- Out of retries: one last attempt whose errors escape unchanged.
    retrying 0 _ = runOnce
    retrying left pause = do
      outcome <- try runOnce
      case outcome of
        Right value -> pure value
        Left (problem :: SomeException) ->
          case fromException problem :: Maybe Error of
            Just err
              | isTransient err    -> backoffAndRetry
              | isRoutingError err -> invalidateRoutingTable rp >> backoffAndRetry
            _ -> throwIO problem
      where
        backoffAndRetry = do
          threadDelay pause
          retrying (left - 1) (min delayCap (pause * 2))

    txMode = case mode of
      ReadAccess  -> 'r'
      WriteAccess -> 'w'

    runOnce = do
      table <- getOrRefreshTable rp
      let candidates = case mode of
            ReadAccess  -> readers table
            WriteAccess -> writers table
      when (V.null candidates) $
        throwIO $ RoutingTableError $ "No servers available for " <> T.pack (show mode)
      target <- roundRobin rp candidates
      pool   <- getOrCreatePool rp target
      withConnection pool $ \conn -> do
        P.beginTx conn $ Begin V.empty Nothing H.empty txMode Nothing Nothing
        result <- action conn `onException` P.tryRollback conn
        _ <- P.commitTx conn
        pure result


-- | Return the cached routing table while it is still fresh. A table is
-- fresh when the current monotonic time plus the refresh buffer is strictly
-- before its expiry. If there is no table, or it is stale, 'refreshRoutingTable'
-- is called instead.
--
-- This fast path takes no lock; 'refreshRoutingTable' repeats the check
-- after acquiring its lock.
getOrRefreshTable :: RoutingPool -> IO RoutingTable
getOrRefreshTable rp = do
  cached <- readIORef (rpCacheRef rp)
  now    <- getMonotonicTimeNSec
  let margin :: Word64
      margin = fromIntegral (rpRefreshBuf rp) * 1_000_000_000
  case cached of
    Just entry | now + margin < expiresAtNs entry -> pure (cachedTable entry)
    _                                            -> refreshRoutingTable rp


-- | Refresh the routing table while holding the pool's refresh lock, so that
-- only one thread fetches at a time.
--
-- Once the lock is held, the cache is checked again. If another thread
-- refreshed it while this one waited, that table is returned. Otherwise the
-- routers listed in the cached (possibly stale) table are tried first, then
-- the seed address, with duplicates removed.
refreshRoutingTable :: RoutingPool -> IO RoutingTable
refreshRoutingTable rp@RoutingPool{rpConfig, rpCacheRef, rpRefreshLock, rpRefreshBuf} =
  withMVar rpRefreshLock $ \_ -> do
    cached <- readIORef rpCacheRef
    now    <- getMonotonicTimeNSec
    let bufferNs :: Word64
        bufferNs = fromIntegral rpRefreshBuf * 1_000_000_000
        fresh entry = now + bufferNs < expiresAtNs entry
    case cached of
      Just entry | fresh entry -> pure (cachedTable entry)
      _ -> do
        let knownRouters = maybe [] (V.toList . routers . cachedTable) cached
            candidates   = dedupAddrs (knownRouters <> [seedAddress rpConfig])
        tryRefreshFrom rp candidates now


-- | Try to fetch a routing table from each router address in turn. The first
-- router that answers wins: its table is cached with an expiry of
-- @now + ttl@ seconds and returned.
--
-- Each probe opens a dedicated connection, which is always closed. This
-- includes the case where connecting succeeds but the ROUTE request fails;
-- an unconditional close avoids leaking the socket.
--
-- Synchronous failures fall through to the next address. Asynchronous
-- exceptions (e.g. 'killThread', timeouts) are re-thrown immediately rather
-- than being absorbed by the fallback loop.
--
-- A negative TTL is treated as 0, i.e. the table is already expired. Without
-- this, the negative value would wrap to a huge 'Word64' and the table
-- would never expire.
--
-- Throws 'RoutingTableError' if the list is empty or every router fails. In
-- the latter case the message includes the last underlying error.
tryRefreshFrom :: RoutingPool -> [Text] -> Word64 -> IO RoutingTable
tryRefreshFrom _rp [] _now =
  throwIO $ RoutingTableError "Could not reach any router to refresh routing table"
tryRefreshFrom rp@RoutingPool{rpConfig, rpRoutingDb, rpCacheRef} (addr : rest) now = do
  result <- try fetchFromRouter
  case result of
    Right rt -> do
      writeIORef rpCacheRef $ Just CachedRoutingTable
        { cachedTable = rt
        , expiresAtNs = now + ttlNs rt
        }
      pure rt
    Left (e :: SomeException)
      | Just (_ :: SomeAsyncException) <- fromException e -> throwIO e
      | not (null rest) -> tryRefreshFrom rp rest now
      | otherwise ->
          throwIO $ RoutingTableError $
            "Could not reach any router to refresh routing table: " <> T.pack (show e)
  where
    -- One short-lived connection per probe, closed on success and failure.
    fetchFromRouter :: IO RoutingTable
    fetchFromRouter = do
      let (h, p) = parseAddress addr
      bracket (P.connect (setHostPort rpConfig h p)) P.close $ \conn ->
        getRoutingTable conn rpRoutingDb

    ttlNs :: RoutingTable -> Word64
    ttlNs rt = fromIntegral (max 0 (ttl rt)) * 1_000_000_000


-- | Drop repeated addresses, keeping the first occurrence of each and
-- otherwise preserving order. Quadratic, but router lists are tiny.
dedupAddrs :: [Text] -> [Text]
dedupAddrs []           = []
dedupAddrs (addr : more) = addr : dedupAddrs (filter (/= addr) more)


-- | Return the pool for @addr@, creating it on first use from the seed config
-- with host and port replaced by those parsed from @addr@.
--
-- Two threads can miss the lookup for the same address at the same time,
-- and both will then create a pool. The insert is therefore
-- insert-if-absent, done atomically. The thread that loses the race
-- destroys its own pool and returns the winner's, instead of overwriting
-- (and leaking) the pool already in the map.
getOrCreatePool :: RoutingPool -> Text -> IO BoltPool
getOrCreatePool RoutingPool{rpConfig, rpPoolConfig, rpPoolsRef} addr = do
  existing <- H.lookup addr <$> readIORef rpPoolsRef
  case existing of
    Just pool -> pure pool
    Nothing -> do
      let (h, p) = parseAddress addr
      fresh <- createPool (setHostPort rpConfig h p) rpPoolConfig
      (winner, lostRace) <- atomicModifyIORef' rpPoolsRef $ \pools ->
        case H.lookup addr pools of
          Just other -> (pools, (other, True))
          Nothing    -> (H.insert addr fresh pools, (fresh, False))
      when lostRace $ destroyPool fresh
      pure winner


-- | Parse an address into @(host, port)@, defaulting the port to 7687.
--
-- Accepted forms:
--
-- * @host:port@, e.g. @db1.example.com:7688@
-- * @host@ with no port
-- * @[ipv6]:port@ and @[ipv6]@; the brackets are stripped from the host
-- * a bare IPv6 literal such as @::1@ (several colons, no brackets), which
--   is returned unchanged with the default port
--
-- A port that is empty, non-numeric or outside @0..65535@ yields the
-- default port. Reading straight into 'Word16' would silently wrap such
-- values instead (e.g. @70000@ would become 4464 and @-1@ would become 65535).
parseAddress :: Text -> (Text, Word16)
parseAddress addr =
  case T.uncons addr of
    Just ('[', inner)
      | (h, afterHost) <- T.break (== ']') inner
      , Just (']', suffix) <- T.uncons afterHost ->
          case T.uncons suffix of
            Nothing       -> (h, defaultPort)
            Just (':', p) -> (h, parsePort p)
            Just _        -> (addr, defaultPort)
    _ ->
      case T.split (== ':') addr of
        [h, p] -> (h, parsePort p)
        _      -> (addr, defaultPort)
  where
    defaultPort :: Word16
    defaultPort = 7687

    -- Digits only, at most 5 of them, and within Word16 range.
    parsePort :: Text -> Word16
    parsePort p
      | not (T.null p)
      , T.length p <= 5
      , T.all (\c -> c >= '0' && c <= '9') p
      , let n = T.foldl' (\acc c -> acc * 10 + (fromEnum c - fromEnum '0')) 0 p
      , n <= 65535
      = fromIntegral n
      | otherwise = defaultPort


-- | Return a copy of the given 'ValidatedConfig' whose @routing@ field is the
-- supplied value. Every other field is carried over unchanged.
--
-- The copy is built by explicitly taking the record apart and rebuilding it
-- with the constructor, which avoids an ambiguous record update on these
-- field names.
setRouting :: ValidatedConfig -> Routing -> ValidatedConfig
setRouting cfg newRouting =
  case cfg of
    ValidatedConfig { host = h, port = p, scheme = sch, use_tls = tls
                    , versions = vs, timeout = tmo, user_agent = ua
                    , queryLogger = ql, notificationHandler = nh } ->
      ValidatedConfig { host = h
                      , port = p
                      , scheme = sch
                      , use_tls = tls
                      , versions = vs
                      , timeout = tmo
                      , routing = newRouting
                      , user_agent = ua
                      , queryLogger = ql
                      , notificationHandler = nh
                      }


-- | Return a copy of the given 'ValidatedConfig' aimed at a different server.
-- @host@ and @port@ are replaced by the supplied values, and every other field
-- is kept.
--
-- The copy is built by explicitly taking the record apart and rebuilding it
-- with the constructor, which avoids an ambiguous record update on these
-- field names.
setHostPort :: ValidatedConfig -> Text -> Word16 -> ValidatedConfig
setHostPort cfg newHost newPort =
  case cfg of
    ValidatedConfig { scheme = sch, use_tls = tls, versions = vs
                    , timeout = tmo, routing = rt, user_agent = ua
                    , queryLogger = ql, notificationHandler = nh } ->
      ValidatedConfig { host = newHost
                      , port = newPort
                      , scheme = sch
                      , use_tls = tls
                      , versions = vs
                      , timeout = tmo
                      , routing = rt
                      , user_agent = ua
                      , queryLogger = ql
                      , notificationHandler = nh
                      }


-- | Render the configured host and port as a @"host:port"@ address.
--
-- A host that contains a colon (an IPv6 literal) and is not already in
-- brackets is wrapped as @"[host]"@, so the result is @"[::1]:7687"@ rather
-- than the ambiguous @"::1:7687"@. This is the bracketed form used for IPv6
-- literals in URI authorities. All other hosts are rendered as-is.
seedAddress :: ValidatedConfig -> Text
seedAddress ValidatedConfig{host = h, port = p} =
  hostPart <> ":" <> T.pack (show p)
  where
    hostPart
      | T.any (== ':') h, not ("[" `T.isPrefixOf` h) = "[" <> h <> "]"
      | otherwise                                   = h