-- | Internal module. Not part of the public API.
module Database.Bolty.Connection.Connection
  ( connect
  , close
  , send
  , receiveBytestring
  , receiveBinary
  ) where

import           Control.Exception      (bracketOnError)
import           Control.Monad.Trans    (MonadIO(..))
import           Data.Default           (Default(..))
import           Data.Persist           (HasEndianness, runGet, getBE)
import           Data.Word              (Word16)
import           GHC.Stack              (HasCallStack, withFrozenCallStack)
import           Network.Connection     (Connection, ConnectionParams(..), connectFromSocket, connectionClose,
                                        connectionGetExact, connectionPut,
                                        initConnectionContext)
import qualified Data.ByteString        as BS
import qualified Data.Text              as T
import qualified Network.Socket         as NS
import qualified System.Timeout         as ST (timeout)
import           TextShow               (TextShow, showt)


-- | Run an IO action under a deadline of @milliseconds@ milliseconds.
--
-- If the action completes in time its result is returned unchanged. Otherwise
-- 'ST.timeout' interrupts it and this function calls 'fail' with the message
-- @"Timeout: "@ followed by the 'show'n 'showt' rendering of
-- @associated_data@. Because 'show' is applied to the 'T.Text' produced by
-- 'showt', the label appears in double quotes.
--
-- NOTE(review): 'ST.timeout' treats a negative duration as "no timeout", so a
-- negative @milliseconds@ disables the deadline entirely. An interrupted
-- receive may presumably have consumed part of the stream already — confirm
-- whether callers discard the connection after a timeout.
timeoutThrow :: (HasCallStack, TextShow b) => Int -> b -> IO a -> IO a
timeoutThrow milliseconds associated_data action =
  withFrozenCallStack $
    ST.timeout microseconds action >>= maybe expired pure
  where
    -- ST.timeout takes microseconds; callers pass milliseconds.
    microseconds = milliseconds * 1_000
    expired = fail ("Timeout: " ++ show (showt associated_data))


-- | Open a raw TCP connection to @host@:@port@, optionally wrapped in TLS.
--
-- @host@ is resolved with 'NS.getAddrInfo' (stream sockets only) and the first
-- returned address is used. TCP_NODELAY is enabled on the new socket. The TCP
-- connect and the 'connectFromSocket' step together are bounded by @timeout@
-- milliseconds via 'timeoutThrow'.
--
-- When @use_tls@ is 'True', TLS is requested with default 'TLSSettings';
-- otherwise the connection is plain TCP. No SOCKS proxy is used.
--
-- If anything fails after the socket has been opened (option setting, connect,
-- 'connectFromSocket', or a timeout), the socket is closed before the exception
-- propagates, so failed attempts do not leak file descriptors.
--
-- Returns the 'Connection' paired with the same @timeout@ value.
connect :: (MonadIO m, HasCallStack) => Bool -> T.Text -> Word16 -> Int -> m (Connection, Int)
connect use_tls host port timeout = liftIO $ do
  ctx <- initConnectionContext
  let params = ConnectionParams
        { connectionHostname  = T.unpack host
        , connectionPort      = fromIntegral port
        , connectionUseSecure = if use_tls then Just def else Nothing
        , connectionUseSocks  = Nothing
        }
      hints = NS.defaultHints { NS.addrSocketType = NS.Stream }
  addrs <- NS.getAddrInfo (Just hints) (Just $ T.unpack host) (Just $ show port)
  case addrs of
    -- Defensive: getAddrInfo is documented to throw rather than return [],
    -- but keep an explicit error just in case.
    [] -> fail $ "Cannot resolve " ++ T.unpack host
    (addr : _) ->
      -- bracketOnError closes the socket only if the body throws; on success
      -- the socket is owned by the returned Connection.
      bracketOnError (NS.openSocket addr) NS.close $ \sock -> do
        NS.setSocketOption sock NS.NoDelay 1
        timeoutThrow timeout ("connecting to " <> host <> ":" <> showt port) $ do
          NS.connect sock (NS.addrAddress addr)
          conn <- connectFromSocket ctx sock params
          pure (conn, timeout)


-- | Close a raw TCP connection, giving up after @timeout@ milliseconds.
--
-- Delegates to 'connectionClose'. If closing does not finish in time,
-- 'timeoutThrow' fails with a message labelled "closing connection".
close :: (MonadIO m, HasCallStack) => Connection -> Int -> m ()
close conn timeout =
  let label = "closing connection" :: T.Text
  in liftIO (timeoutThrow timeout label (connectionClose conn))


-- | Read exactly @size@ bytes from the connection within @timeout@ milliseconds.
--
-- The read is performed by 'connectionGetExact'. On timeout, 'timeoutThrow'
-- fails with a label of the form "receiving N bytes".
receiveBytestring :: (MonadIO m, HasCallStack) => Connection -> Int -> Int -> m BS.ByteString
receiveBytestring conn timeout size = liftIO (timeoutThrow timeout description readExact)
  where
    description :: T.Text
    description = "receiving " <> showt size <> " bytes"

    readExact :: IO BS.ByteString
    readExact = connectionGetExact conn size


-- | Read exactly @size@ bytes and decode them as a big-endian value of type @a@.
--
-- The read is bounded by @timeout@ milliseconds via 'timeoutThrow' (label
-- "receiving data"). Decoding uses @runGet (getBE \@a)@.
--
-- A decode failure raises an 'error' ("receiveBinary: " followed by the parser
-- message) at the moment the returned action runs. Previously the decode was
-- 'fmap'ped over the result as a pure function, so the failure sat in a lazy
-- thunk. It then surfaced only wherever the value was eventually forced,
-- possibly outside the caller's exception handlers.
receiveBinary :: forall a m. (HasEndianness a, MonadIO m, HasCallStack) => Connection -> Int -> Int -> m a
receiveBinary conn timeout size = liftIO $ do
  bs <- timeoutThrow timeout ("receiving data" :: T.Text) $ connectionGetExact conn size
  -- Choosing between the error action and 'pure' inside IO forces the decode
  -- result when this action executes, not when the value is later inspected.
  either (\e -> error $ "receiveBinary: " ++ e) pure (runGet (getBE @a) bs)


-- | Write all of @bytes@ to the connection within @timeout@ milliseconds.
--
-- The write is performed by 'connectionPut'. On timeout, 'timeoutThrow' fails
-- with a label of the form "sending N bytes", where N is the payload length.
send :: (MonadIO m, HasCallStack) => Connection -> Int -> BS.ByteString -> m ()
send conn timeout bytes = liftIO (timeoutThrow timeout description (connectionPut conn bytes))
  where
    description :: T.Text
    description = T.concat ["sending ", showt (BS.length bytes), " bytes"]