{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE DerivingStrategies #-}
{-# LANGUAGE GeneralizedNewtypeDeriving #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE NumericUnderscores #-}
{-# LANGUAGE ScopedTypeVariables #-}

module Network.Transport.QUIC.Internal.Messaging
  ( -- * Connections
    ServerConnId,
    serverSelfConnId,
    firstNonReservedServerConnId,
    ClientConnId,
    createConnectionId,
    sendMessage,
    receiveMessage,
    MessageReceived (..),

    -- * Specialized messages
    sendAck,
    sendRejection,
    recvAck,
    recvWord32,
    sendCloseConnection,
    sendCloseEndPoint,

    -- * Handshake protocol
    handshake,

    -- * Re-exported for testing
    encodeMessage,
    decodeMessage,
  )
where

import Control.Exception (SomeAsyncException, SomeException, catch, displayException, fromException, mask, throwIO, try)
import Control.Monad (replicateM)
import Data.Binary (Binary)
import Data.Binary qualified as Binary
import Data.Bits (shiftL, (.|.))
import Data.ByteString (ByteString)
import Data.ByteString qualified as BS
import Data.Functor ((<&>))
import Data.Word (Word32, Word8)
import GHC.Exception (Exception)
import Network.QUIC (Stream)
import Network.QUIC qualified as QUIC
import Network.Transport (ConnectionId, EndPointAddress)
import Network.Transport.Internal (decodeWord32, encodeWord32)
import Network.Transport.QUIC.Internal.QUICAddr (QUICAddr (QUICAddr), decodeQUICAddr)
import System.Timeout (timeout)

-- | Send a message, framed by 'encodeMessage', on the given stream,
-- addressed to the connection identified by the 'ClientConnId'.
--
-- 'QUIC.QUICException's raised while sending are returned as 'Left'; any
-- other exception propagates to the caller.
--
-- Note that this function does /not/ mask asynchronous exceptions and does no
-- locking of its own: nothing here prevents an asynchronous exception from
-- interrupting the write part-way through, nor two threads interleaving their
-- frames on the same 'Stream'. Callers that need either guarantee must
-- provide it themselves.
sendMessage ::
  Stream ->
  ClientConnId ->
  [ByteString] ->
  IO (Either QUIC.QUICException ())
sendMessage stream connId messages =
  try (QUIC.sendStreamMany stream (encodeMessage connId messages))

-- | Receive one message from the stream, including the ID of the connection
-- it is addressed to, and decode it with 'decodeMessage'.
--
-- 'QUIC.recvStream' may return fewer bytes than requested, so it is wrapped
-- in 'getAllBytes' to read each field in full. Error reporting is delegated
-- to 'decodeMessage'.
--
-- Asynchronous exceptions are not masked while receiving, and this function
-- does no locking of its own: concurrent calls on the same 'Stream' may
-- interleave their reads.
receiveMessage ::
  Stream ->
  IO (Either String MessageReceived)
receiveMessage stream =
  -- The previous @mask $ \restore -> restore (...)@ wrapper and the handler
  -- that merely re-threw 'QUIC.QUICException' were both no-ops; they have
  -- been removed without changing behaviour.
  decodeMessage (getAllBytes (QUIC.recvStream stream))

-- | Encode a message into the list of chunks written by 'sendMessage'.
--
-- The first chunk is the message header, composed of:
--
-- 1. a control byte ('messageControlByte'), telling the receiver how the rest
--    of the message should be parsed;
--
-- 2. a 32-bit word encoding the client-allocated connection ID;
--
-- 3. a 32-bit word encoding the number of payload frames.
--
-- Every following chunk is one payload frame, prefixed with its length as a
-- 32-bit word.
--
-- NOTE(review): the frame count and each frame length are narrowed to
-- 'Word32' with 'fromIntegral', so values of 2^32 or more would be silently
-- truncated. Confirm callers never send payloads that large.
encodeMessage ::
  ClientConnId ->
  [ByteString] ->
  [ByteString]
encodeMessage connId messages = header : map frame messages
  where
    header =
      BS.concat
        [ BS.singleton messageControlByte,
          encodeWord32 (fromIntegral connId),
          encodeWord32 (fromIntegral (length messages))
        ]
    -- Length-prefix each payload so the receiver knows how much to read.
    frame message = encodeWord32 (fromIntegral (BS.length message)) <> message

-- | Decode one message produced by 'encodeMessage', or one of the control
-- messages written by 'sendCloseConnection' and 'sendCloseEndPoint'.
--
-- The argument fetches the requested number of bytes. It is expected to
-- return exactly that many unless the input has ended (see 'getAllBytes').
--
-- * If no control byte can be read (an empty first chunk), the stream is
--   considered closed and @'Right' 'StreamClosed'@ is returned. Exceptions
--   raised while fetching this first byte propagate to the caller.
--
-- * An unknown control byte, truncated input, and any synchronous exception
--   raised while decoding the rest of the message are reported as 'Left'.
--
-- * Asynchronous exceptions (e.g. 'Control.Exception.ThreadKilled') are
--   re-thrown rather than turned into a 'Left', so the calling thread can
--   still be cancelled.
decodeMessage ::
  (Int -> IO ByteString) ->
  IO (Either String MessageReceived)
decodeMessage get = do
  firstChunk <- get 1
  case BS.indexMaybe firstChunk 0 of
    Nothing -> pure (Right StreamClosed)
    Just controlByte -> go controlByte `catch` handler
  where
    go :: ControlByte -> IO (Either String MessageReceived)
    go ctrl
      | ctrl == closeEndPointControlByte = pure $ Right CloseEndPoint
      | ctrl == closeConnectionControlByte =
          Right . CloseConnection . fromIntegral <$> getWord32
      | ctrl == messageControlByte = do
          connId <- getWord32
          numMessages <- getWord32
          -- Each payload frame is prefixed with its length.
          messages <-
            replicateM (fromIntegral numMessages) $
              getWord32 >>= getExactly . fromIntegral
          pure . Right $ Message (fromIntegral connId) messages
      | otherwise = pure . Left $ "Unsupported control byte: " <> show ctrl

    -- Fetch exactly @n@ bytes. If the input ends early, raise an error;
    -- 'handler' turns it into a 'Left' instead of letting a short or
    -- truncated frame through.
    getExactly :: Int -> IO ByteString
    getExactly n = do
      bytes <- get n
      if BS.length bytes == n
        then pure bytes
        else
          throwIO . userError $
            "decodeMessage: expected "
              <> show n
              <> " bytes, got "
              <> show (BS.length bytes)

    -- The decoded word is forced here (@$!@) so that a decoding failure is
    -- raised inside the scope of 'handler', not later from a lazy thunk
    -- hidden inside the returned 'Right'.
    getWord32 :: IO Word32
    getWord32 = do
      bytes <- getExactly 4
      pure $! decodeWord32 bytes

    -- Report synchronous failures as 'Left'; never swallow asynchronous ones.
    handler :: SomeException -> IO (Either String MessageReceived)
    handler ex
      | Just (_ :: SomeAsyncException) <- fromException ex = throwIO ex
      | otherwise = pure . Left $ displayException ex

-- | Wrap a method to fetch bytes, so that it keeps fetching until the
-- intended number of bytes has been read.
--
-- The wrapped function is called repeatedly until @n@ bytes have been
-- accumulated. An empty chunk is how 'QUIC.recvStream' signals the end of the
-- stream, and it is also what 'decodeMessage' interprets as a closed stream.
-- If the wrapped function ever returns an empty chunk, reading stops and
-- whatever has been accumulated so far is returned. The result may therefore
-- be /shorter/ than @n@, and callers that need exactly @n@ bytes must check
-- its length. Without this stop, a closed stream would make the loop spin
-- forever.
getAllBytes ::
  -- | Function to fetch at most 'n' bytes
  (Int -> IO ByteString) ->
  -- | Function to fetch exactly 'n' bytes (or fewer, if the input ends)
  (Int -> IO ByteString)
getAllBytes get n = go n []
  where
    -- @acc@ holds the chunks read so far, most recent first. Consing and
    -- reversing once at the end avoids quadratic list appends.
    go :: Int -> [ByteString] -> IO ByteString
    go remaining acc
      -- '<=' rather than '==': stop even if a chunk overshoots the request.
      | remaining <= 0 = pure (finish acc)
      | otherwise = do
          bytes <- get remaining
          if BS.null bytes
            then pure (finish acc) -- end of input: return a short result
            else go (remaining - BS.length bytes) (bytes : acc)

    finish :: [ByteString] -> ByteString
    finish = BS.concat . reverse

-- | Result of decoding one message from a stream (see 'decodeMessage').
data MessageReceived
  = -- | A data message addressed to the given connection, with its payload
    -- frames in order.
    Message
      {-# UNPACK #-} !ClientConnId
      -- A list cannot be unpacked (GHC ignores UNPACK here with a warning),
      -- so this field is only strict.
      ![ByteString]
  | -- | A request to close the given connection.
    CloseConnection {-# UNPACK #-} !ClientConnId
  | -- | A request to close the whole endpoint.
    CloseEndPoint
  | -- | The stream ended before a control byte could be read.
    StreamClosed
  deriving (Show, Eq)

-- | Thrown by 'recvAck' when the reply to a handshake is missing or is not a
-- recognised acknowledgement. The 'String' describes the problem.
newtype AckException = AckException String
  deriving (Eq, Show)

instance Exception AckException

-- | One-byte reply, sent by 'sendAck', that 'recvAck' treats as "connection
-- accepted".
ackMessage :: ByteString
ackMessage = BS.pack [connectionAcceptedControlByte]

-- | One-byte reply, sent by 'sendRejection', that 'recvAck' treats as
-- "connection rejected".
rejectMessage :: ByteString
rejectMessage = BS.pack [connectionRejectedControlByte]

-- | Accept an incoming connection by writing 'ackMessage' to the stream.
sendAck :: Stream -> IO ()
sendAck stream = QUIC.sendStream stream ackMessage

-- | Reject an incoming connection by writing 'rejectMessage' to the stream.
sendRejection :: Stream -> IO ()
sendRejection stream = QUIC.sendStream stream rejectMessage

-- | Wait for the one-byte reply to a handshake.
--
-- Returns @'Right' ()@ if the reply is 'ackMessage' and @'Left' ()@ if it is
-- 'rejectMessage'. Throws an 'AckException' if nothing arrives within half a
-- second, or if the reply is anything else.
recvAck :: Stream -> IO (Either () ())
recvAck stream = do
  -- TODO: make timeout configurable
  reply <- timeout ackTimeoutMicros (QUIC.recvStream stream 1)
  case reply of
    Nothing ->
      throwIO (AckException "Connection ack not received within acceptable timeframe")
    Just response
      | response == ackMessage -> pure (Right ())
      | response == rejectMessage -> pure (Left ())
      | otherwise -> throwIO (AckException "Unexpected ack response")
  where
    -- 'timeout' takes microseconds: this is 0.5 s.
    ackTimeoutMicros :: Int
    ackTimeoutMicros = 500_000

-- | Receive a 'Word32' from the stream, decoded with 'decodeWord32'.
--
-- 'QUIC.recvStream' may return fewer bytes than requested, so this keeps
-- reading until 4 bytes have arrived. If the stream ends first (an empty
-- chunk), or any other synchronous exception is raised, the failure is
-- returned as a 'Left'. The decoded value is forced before it is returned, so
-- a decoding error cannot escape later from inside a lazy 'Right'.
--
-- Asynchronous exceptions are not masked, and they are re-thrown rather than
-- converted into a 'Left', so the calling thread can still be cancelled.
recvWord32 ::
  Stream ->
  IO (Either String Word32)
recvWord32 stream = (Right <$> readWord32) `catch` handler
  where
    readWord32 :: IO Word32
    readWord32 = do
      bytes <- readUpTo 4 []
      if BS.length bytes == 4
        then pure $! decodeWord32 bytes
        else
          throwIO . userError $
            "recvWord32: stream ended after "
              <> show (BS.length bytes)
              <> " of 4 bytes"

    -- Accumulate chunks (most recent first) until @n@ more bytes have been
    -- read, stopping early if 'QUIC.recvStream' returns an empty chunk.
    readUpTo :: Int -> [ByteString] -> IO ByteString
    readUpTo n acc
      | n <= 0 = pure (BS.concat (reverse acc))
      | otherwise = do
          chunk <- QUIC.recvStream stream n
          if BS.null chunk
            then pure (BS.concat (reverse acc))
            else readUpTo (n - BS.length chunk) (chunk : acc)

    -- Report synchronous failures as 'Left'; never swallow asynchronous ones.
    handler :: SomeException -> IO (Either String Word32)
    handler ex
      | Just (_ :: SomeAsyncException) <- fromException ex = throwIO ex
      | otherwise = pure (Left (displayException ex))

-- | The first byte of each message on a stream. It determines how the rest of
-- the message is interpreted (see 'decodeMessage' and 'recvAck').
type ControlByte = Word8

-- | Handshake reply meaning the connection was accepted (see 'ackMessage').
connectionAcceptedControlByte :: ControlByte
connectionAcceptedControlByte = 0x00

-- | Handshake reply meaning the connection was rejected (see 'rejectMessage').
connectionRejectedControlByte :: ControlByte
connectionRejectedControlByte = 0x01

-- | Header of a data message carrying payload frames (see 'encodeMessage').
messageControlByte :: ControlByte
messageControlByte = 0x02

-- | Request to close the whole endpoint (see 'sendCloseEndPoint').
closeEndPointControlByte :: ControlByte
closeEndPointControlByte = 0x7F

-- | Request to close a single connection (see 'sendCloseConnection').
closeConnectionControlByte :: ControlByte
closeConnectionControlByte = 0xFF

-- | Ask the remote side to close the connection identified by the given
-- 'ClientConnId'.
--
-- The message is 'closeConnectionControlByte' followed by the connection ID
-- as a 32-bit word. 'QUIC.QUICException's raised while sending are returned
-- as 'Left'.
sendCloseConnection :: ClientConnId -> Stream -> IO (Either QUIC.QUICException ())
sendCloseConnection connId stream = try (QUIC.sendStream stream closeMessage)
  where
    closeMessage :: ByteString
    closeMessage =
      BS.singleton closeConnectionControlByte <> encodeWord32 (fromIntegral connId)

-- | Send a message signalling that the whole endpoint is being closed: a
-- single 'closeEndPointControlByte', which 'decodeMessage' decodes as
-- 'CloseEndPoint'.
--
-- 'QUIC.QUICException's raised while sending are returned as 'Left'.
sendCloseEndPoint :: Stream -> IO (Either QUIC.QUICException ())
sendCloseEndPoint stream =
  try (QUIC.sendStream stream (BS.singleton closeEndPointControlByte))

-- | Handshake protocol that a client, connecting to a remote endpoint,
-- has to perform.
--
-- The client sends a length-prefixed payload containing the
-- 'Binary'-encoded pair of our own address and the endpoint ID parsed from
-- the remote address. It then waits for the server's reply via 'recvAck'.
--
-- Returns @'Left' ()@ if the payload could not be sent or the server
-- rejected the connection, and @'Right' ()@ once the server has acknowledged
-- it. Throws an 'IOError' if the remote address cannot be decoded, and lets
-- any exception thrown by 'recvAck' propagate. Asynchronous exceptions raised
-- while sending are re-thrown rather than reported as a failed handshake.
--
-- TODO: encode server part of the handshake
handshake ::
  (EndPointAddress, EndPointAddress) ->
  Stream ->
  IO (Either () ())
handshake (ourAddress, theirAddress) stream =
  case decodeQUICAddr theirAddress of
    Left errmsg ->
      throwIO . userError $ "Could not decode QUIC address: " <> errmsg
    Right (QUICAddr _ _ serverEndPointId) -> do
      -- Handshake on connection creation, which simply involves
      -- sending our address over, and
      -- the endpoint ID of the endpoint we want to communicate with
      let encodedPayload :: ByteString
          encodedPayload = BS.toStrict $ Binary.encode (ourAddress, serverEndPointId)
          payloadLength :: ByteString
          payloadLength = encodeWord32 $ fromIntegral (BS.length encodedPayload)
      sent <- try (QUIC.sendStream stream (BS.concat [payloadLength, encodedPayload]))
      case sent of
        Left (exc :: SomeException)
          -- Do not turn e.g. thread cancellation into a "rejected" result.
          | Just (_ :: SomeAsyncException) <- fromException exc -> throwIO exc
          | otherwise -> pure (Left ())
        Right () ->
          -- Server acknowledgement that the handshake is complete
          -- means that we cannot send messages until the server
          -- is ready for them
          recvAck stream

-- | Part of the connection ID that is client-allocated.
--
-- A thin wrapper around 'Word32'. Every instance below is the underlying
-- 'Word32' instance, so for example 'show' prints just the number.
newtype ClientConnId = ClientConnId Word32
  deriving newtype (Eq, Ord, Show, Bounded, Enum, Num, Real, Integral, Binary)

-- | The server-allocated part of a connection ID.
--
-- A thin wrapper around 'Word32'. Every instance is inherited unchanged
-- from 'Word32' via newtype deriving. Consequently:
--
--   * arithmetic wraps modulo 2^32, exactly as for 'Word32';
--   * 'show' prints the bare number, without the @ServerConnId@ constructor.
newtype ServerConnId = ServerConnId Word32
  deriving newtype (Bounded, Enum, Eq, Integral, Num, Ord, Real, Show)

-- | Server connection ID used for an endpoint's connection to itself.
--
-- This is the smallest possible 'ServerConnId' (zero).
serverSelfConnId :: ServerConnId
serverSelfConnId = ServerConnId 0

-- | The first server connection ID that is not reserved.
--
-- IDs strictly below this value are reserved for special heavyweight
-- connections. With the current value of 1, only ID 0
-- ('serverSelfConnId') is reserved.
firstNonReservedServerConnId :: ServerConnId
firstNonReservedServerConnId = ServerConnId 1

-- | Combine a server-allocated and a client-allocated ID into one
-- 'ConnectionId'.
--
-- The server part is converted and shifted left by 32 bits. The client
-- part is converted and bitwise-OR'ed into the result.
--
-- NOTE(review): the two halves only stay disjoint if 'ClientConnId'
-- values fit in 32 bits. Its representation is not visible here, so
-- confirm against its definition.
createConnectionId ::
  ServerConnId ->
  ClientConnId ->
  ConnectionId
createConnectionId sid cid = serverBits .|. clientBits
  where
    serverBits, clientBits :: ConnectionId
    serverBits = fromIntegral sid `shiftL` 32
    clientBits = fromIntegral cid