{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TupleSections #-}
{-# LANGUAGE TypeApplications #-}
module Network.Transport.QUIC.Internal.Client (
streamToEndpoint,
)
where
import Control.Concurrent (forkIOWithUnmask, newEmptyMVar)
import Control.Concurrent.MVar (MVar, putMVar, takeMVar, tryPutMVar)
import Control.Exception (SomeException, bracket, catch, mask, mask_, throwIO)
import Data.List.NonEmpty (NonEmpty)

import Control.Concurrent.Async (link, withAsync)
import Network.QUIC qualified as QUIC
import Network.QUIC.Client qualified as QUIC.Client
import Network.Transport (ConnectErrorCode (ConnectFailed, ConnectNotFound), EndPointAddress, TransportError (..))

import Network.Transport.QUIC.Internal.Configuration (Credential, mkClientConfig)
import Network.Transport.QUIC.Internal.Messaging (MessageReceived (..), handshake, receiveMessage)
import Network.Transport.QUIC.Internal.QUICAddr (QUICAddr (QUICAddr), decodeQUICAddr)
-- | Dial the endpoint at the remote address and perform the transport
-- handshake on a fresh QUIC stream.
--
-- The QUIC connection is driven by a background thread. This call blocks until
-- that thread publishes the handshake outcome (or dies), then returns:
--
--   * @Left (TransportError ConnectNotFound _)@ if the remote address cannot be
--     decoded or the handshake fails;
--   * @Left (TransportError ConnectFailed _)@ if the background thread dies
--     with an exception before a handshake result is available;
--   * @Right (doneMVar, stream)@ on success. The background thread keeps the
--     stream open until @doneMVar@ is filled. The close-listener below fills it
--     when the peer announces a close, and the thread then takes it, closes
--     the stream and ends.
streamToEndpoint ::
  -- | Credentials handed to 'mkClientConfig'.
  NonEmpty Credential ->
  -- | Flag handed to 'mkClientConfig' (presumably whether to validate the
  -- server's credentials; confirm in the Configuration module).
  Bool ->
  -- | Our own address, sent in the handshake.
  EndPointAddress ->
  -- | The remote address: decoded to find host and service, and sent in the
  -- handshake.
  EndPointAddress ->
  -- | Called with any exception that terminates the background client thread.
  (SomeException -> IO ()) ->
  -- | Run by the close-listener after the peer sends 'CloseEndPoint'.
  IO () ->
  IO (Either (TransportError ConnectErrorCode) (MVar (), QUIC.Stream))
streamToEndpoint creds validateCreds ourAddress theirAddress onExc onCloseForcibly =
  case decodeQUICAddr theirAddress of
    Left errmsg -> pure $ Left (TransportError ConnectNotFound errmsg)
    Right (QUICAddr hostname servicename _) -> do
      clientConfig <- mkClientConfig hostname servicename creds validateCreds
      -- Receives the handshake outcome (or a connection failure) exactly once.
      streamMVar <- newEmptyMVar
      -- Filled when the stream should be torn down.
      doneMVar <- newEmptyMVar

      let runClient :: QUIC.Connection -> IO ()
          runClient conn = mask $ \restore -> do
            QUIC.waitEstablished conn
            restore $
              bracket (QUIC.stream conn) QUIC.closeStream $ \stream ->
                handshake (ourAddress, theirAddress) stream >>= \case
                  -- A stream whose handshake failed is useless. Report the
                  -- failure and return, so 'bracket' closes the stream instead
                  -- of waiting on a 'doneMVar' the caller never receives.
                  Left _ ->
                    putMVar streamMVar . Left $
                      TransportError ConnectNotFound "handshake failed"
                  Right _ -> do
                    putMVar streamMVar (Right stream)
                    withAsync (listenForClose stream doneMVar) $ \listener -> do
                      -- Re-raise listener failures in this thread. Otherwise an
                      -- unexpected message leaves us blocked on doneMVar forever.
                      link listener
                      takeMVar doneMVar

          -- If the client thread dies before publishing a handshake result,
          -- release the caller blocked on streamMVar. 'tryPutMVar' is a no-op
          -- once a result has been published. Then report the exception.
          onClientException :: SomeException -> IO ()
          onClientException exc = do
            _ <- tryPutMVar streamMVar (Left $ TransportError ConnectFailed (show exc))
            onExc exc

      _ <-
        mask_ $
          forkIOWithUnmask $ \unmask ->
            unmask (QUIC.Client.run clientConfig runClient)
              `catch` onClientException

      streamOrError <- takeMVar streamMVar
      pure $ (doneMVar,) <$> streamOrError
  where
    -- Wait for one control message on the stream.
    -- Any close notification fills 'doneMVar'; a 'CloseEndPoint' also runs
    -- 'onCloseForcibly'.
    -- Any other result (including a receive error) is thrown as an 'IOError'.
    listenForClose :: QUIC.Stream -> MVar () -> IO ()
    listenForClose stream doneMVar =
      receiveMessage stream >>= \case
        Right StreamClosed -> putMVar doneMVar ()
        Right (CloseConnection _) -> putMVar doneMVar ()
        Right CloseEndPoint -> do
          -- NOTE(review): filling doneMVar lets the client thread leave
          -- 'withAsync', which cancels this listener. 'onCloseForcibly' may
          -- therefore be interrupted or never run. Confirm whether it should
          -- run before signalling.
          putMVar doneMVar ()
          onCloseForcibly
        other ->
          throwIO . userError $
            "Unexpected incoming message to client: " <> show other