{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TupleSections #-}
{-# LANGUAGE TypeApplications #-}

module Network.Transport.QUIC.Internal.Client (
  streamToEndpoint,
)
where

import Control.Concurrent (forkIOWithUnmask, newEmptyMVar)
import Control.Concurrent.MVar (MVar, putMVar, takeMVar, tryPutMVar)
import Control.Exception (SomeException, bracket, catch, mask, mask_, throwIO)
import Data.List.NonEmpty (NonEmpty)

import Control.Concurrent.Async (withAsync)
import Network.QUIC qualified as QUIC
import Network.QUIC.Client qualified as QUIC.Client
import Network.Transport (ConnectErrorCode (ConnectFailed, ConnectNotFound), EndPointAddress, TransportError (..))

import Network.Transport.QUIC.Internal.Configuration (Credential, mkClientConfig)
import Network.Transport.QUIC.Internal.Messaging (MessageReceived (..), handshake, receiveMessage)
import Network.Transport.QUIC.Internal.QUICAddr (QUICAddr (QUICAddr), decodeQUICAddr)

-- | Connect to a remote endpoint over QUIC, run the transport handshake on a
-- fresh stream, and hand that stream back to the caller.
--
-- The QUIC client is driven by a background thread. This function blocks
-- until the outcome of the handshake is known and then returns:
--
--   * 'Left' when @theirAddress@ cannot be decoded ('ConnectNotFound'), when
--     the handshake reports failure ('ConnectNotFound'), or when an exception
--     is raised before the stream could be handed over ('ConnectFailed');
--
--   * 'Right' @(doneMVar, stream)@ otherwise. Putting @()@ into @doneMVar@
--     makes the background thread close the stream and return from
--     'QUIC.Client.run'.
--
-- Any exception escaping the background thread is also passed to @onExc@.
streamToEndpoint ::
  NonEmpty Credential ->
  -- | Validate credentials
  Bool ->
  -- | Our address
  EndPointAddress ->
  -- | Their address
  EndPointAddress ->
  -- | On exception
  (SomeException -> IO ()) ->
  -- | On a message to forcibly close the connection
  IO () ->
  -- | On success: an MVar to put @()@ into to close the stream, and the stream
  IO (Either (TransportError ConnectErrorCode) (MVar (), QUIC.Stream))
streamToEndpoint creds validateCreds ourAddress theirAddress onExc onCloseForcibly =
  case decodeQUICAddr theirAddress of
    Left errmsg -> pure $ Left (TransportError ConnectNotFound errmsg)
    Right (QUICAddr hostname servicename _) -> do
      clientConfig <- mkClientConfig hostname servicename creds validateCreds

      -- One-shot handoff of the handshake outcome from the client thread.
      streamMVar <- newEmptyMVar
      -- Filled by the caller (or by 'listenForClose') to end the connection.
      doneMVar <- newEmptyMVar

      let runClient :: QUIC.Connection -> IO ()
          runClient conn = mask $ \restore -> do
            QUIC.waitEstablished conn
            restore $
              bracket (QUIC.stream conn) QUIC.closeStream $ \stream ->
                handshake (ourAddress, theirAddress) stream >>= \case
                  -- The caller receives 'Left' and never gets 'doneMVar', so
                  -- return immediately: 'bracket' closes the stream and the
                  -- client finishes instead of lingering on a dead handshake.
                  Left _ ->
                    putMVar streamMVar . Left $
                      TransportError ConnectNotFound "handshake failed"
                  Right _ -> do
                    putMVar streamMVar (Right stream)
                    -- NOTE(review): 'withAsync' does not re-throw exceptions
                    -- from 'listenForClose'; if it fails, this thread keeps
                    -- waiting until the caller fills 'doneMVar'.
                    withAsync (listenForClose stream doneMVar) $ \_ ->
                      takeMVar doneMVar

          -- Handler for any exception escaping the client. If the failure
          -- happened before the stream was handed over, the caller is still
          -- blocked on 'streamMVar': release it with an error instead of
          -- leaving it blocked forever. After the handoff the write is
          -- harmless: either the MVar is still full (no-op) or the caller
          -- has already taken its result and never reads it again.
          onClientExc :: SomeException -> IO ()
          onClientExc e = do
            _ <- tryPutMVar streamMVar . Left $ TransportError ConnectFailed (show e)
            onExc e

      _ <-
        mask_ $
          forkIOWithUnmask $ \unmask ->
            unmask (QUIC.Client.run clientConfig runClient) `catch` onClientExc

      streamOrError <- takeMVar streamMVar
      pure $ (doneMVar,) <$> streamOrError
 where
  -- Wait for one control message on the stream. For each recognised close
  -- message, fill @doneMVar@ so the client thread stops waiting; anything else
  -- (including a receive error) is raised as an 'IOError'.
  listenForClose :: QUIC.Stream -> MVar () -> IO ()
  listenForClose stream doneMVar =
    receiveMessage stream >>= \case
      Right StreamClosed -> putMVar doneMVar ()
      Right (CloseConnection _) -> putMVar doneMVar ()
      Right CloseEndPoint -> do
        putMVar doneMVar ()
        -- NOTE(review): once 'doneMVar' is filled the client thread leaves
        -- 'withAsync', which cancels this thread, so 'onCloseForcibly' may be
        -- interrupted. Confirm whether it should run first (beware of a
        -- deadlock if it waits for this connection to finish).
        onCloseForcibly
      other ->
        throwIO . userError $ "Unexpected incoming message to client: " <> show other