{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE LambdaCase #-}

{- |
Module      : Web.Scotty.Session
Copyright   : (c) 2025 Tushar Adhatrao,
              (c) 2025 Marco Zocca

License     : BSD-3-Clause
Maintainer  :
Stability   : experimental
Portability : GHC

This module provides session management functionality for Scotty web applications.

== Example usage

@
\{\-\# LANGUAGE OverloadedStrings \#\-\}

import Web.Scotty
import Web.Scotty.Session
import Control.Monad.IO.Class (liftIO)
import qualified Data.Text.Lazy as TL

main :: IO ()
main = do
    -- Create a session jar
    sessionJar <- createSessionJar
    scotty 3000 $ do
        -- Route to create a session (Nothing = use the default lifetime)
        get "/create" $ do
            sess <- createUserSession sessionJar Nothing "user data"
            html $ "Session created with ID: " <> TL.fromStrict (sessId sess)
        -- Route to read a session
        get "/read" $ do
            eSession <- getUserSession sessionJar
            case eSession of
                Left _ -> html "No session found or session expired."
                Right sess -> html $ "Session content: " <> sessContent sess
@
-}
module Web.Scotty.Session (
    Session (..),
    SessionId,
    SessionJar,
    -- Constructors are exported so callers can tell a missing session
    -- (SessionNotFound) from an expired one (SessionExpired).
    SessionStatus (..),

    -- * Create Session Jar
    createSessionJar,

    -- * Create session
    createUserSession,
    createSession,

    -- * Read session
    readUserSession,
    readSession,
    getUserSession,
    getSession,

    -- * Add session
    addSession,

    -- * Delete session
    deleteSession,

    -- * Helper functions
    maintainSessions,
) where

import Control.Concurrent
import Control.Concurrent.STM
import Control.Monad
import Control.Monad.IO.Class (MonadIO (..))
import qualified Data.HashMap.Strict as HM
import qualified Data.Text as T
import Data.Time (NominalDiffTime, UTCTime, addUTCTime, getCurrentTime)
import System.Random (randomRIO)
import Web.Scotty.Action (ActionT)
import Web.Scotty.Cookie

-- | Identifier of a session. This is also the value that
-- 'createUserSession' stores in the @sess_id@ cookie.
type SessionId = T.Text

-- | Why a session lookup did not produce a usable session.
data SessionStatus
    = SessionNotFound
    -- ^ No session could be found: either no session ID was available, or no
    -- session is stored under it.
    | SessionExpired
    -- ^ A session is stored under the ID, but its expiration time has passed.
    deriving (Eq, Show)

-- | One server-side session: its identifier, the time it stops being valid,
-- and the application data it carries.
data Session a = Session
    { sessId :: SessionId
    -- ^ Identifier under which the session is stored in a 'SessionJar'.
    , sessExpiresAt :: UTCTime
    -- ^ Point in time after which the session counts as expired.
    , sessContent :: a
    -- ^ Application data attached to the session.
    }
    deriving (Eq, Show)

-- | Shared, mutable session store: an STM 'TVar' holding the sessions
-- currently stored, keyed by their 'SessionId'.
type SessionJar a = TVar (HM.HashMap SessionId (Session a))

-- | Allocates an empty session jar. It also forks a background thread running
-- 'maintainSessions', which keeps purging expired sessions from the jar.
--
-- NOTE(review): the forked thread's 'ThreadId' is discarded and
-- 'maintainSessions' loops forever, so each jar created this way keeps its
-- maintenance thread alive for the rest of the process.
createSessionJar :: IO (SessionJar a)
createSessionJar = do
    jar <- newTVarIO HM.empty
    void . forkIO $ maintainSessions jar
    pure jar

-- | Runs forever, removing expired sessions from the jar about once per
-- second.
--
-- A session is kept while its expiration time is not before the current time.
-- This is the complement of the expiry test in 'getSession' (expired when
-- @sessExpiresAt < now@).
maintainSessions :: SessionJar a -> IO ()
maintainSessions sessionJar = forever $ do
    now <- getCurrentTime
    let stillValid sess = sessExpiresAt sess >= now
    -- Use the strict modifyTVar'. With the lazy modifyTVar the filtered map is
    -- stored as an unevaluated thunk. If no request reads the jar, those thunks
    -- chain up (one per iteration), keeping every expired session reachable.
    atomically $ modifyTVar' sessionJar (HM.filter stillValid)
    threadDelay 1000000 -- one second, in microseconds
        

-- | Stores a session in the jar under its 'sessId'. Any session already
-- stored under the same ID is replaced.
addSession :: SessionJar a -> Session a -> IO ()
addSession sessionJar sess =
    -- Strict modify, so the TVar holds an evaluated map rather than a chain
    -- of pending inserts.
    atomically $ modifyTVar' sessionJar (HM.insert (sessId sess) sess)

-- | Looks up a session by its ID in the session jar.
--
-- * @Left SessionNotFound@ when no session is stored under the ID.
-- * @Left SessionExpired@ when the stored session's expiration time is before
--   the current time. The stale entry is removed from the jar.
-- * @Right session@ otherwise.
getSession :: (MonadIO m) => SessionJar a -> SessionId -> ActionT m (Either SessionStatus (Session a))
getSession sessionJar sId = liftIO $ do
    now <- getCurrentTime
    -- Check and evict inside a single STM transaction. If these were separate
    -- steps, a session re-added under the same ID in between could be
    -- deleted by mistake.
    atomically $ do
        store <- readTVar sessionJar
        case HM.lookup sId store of
            Nothing -> pure (Left SessionNotFound)
            Just sess
                | sessExpiresAt sess < now -> do
                    writeTVar sessionJar $! HM.delete sId store
                    pure (Left SessionExpired)
                | otherwise -> pure (Right sess)

-- | Removes the session stored under the given ID from the jar. If no
-- session is stored under that ID, the jar is left unchanged.
deleteSession :: (MonadIO m) => SessionJar a -> SessionId -> ActionT m ()
deleteSession sessionJar sId =
    -- Strict modify, so the TVar holds an evaluated map rather than a thunk.
    liftIO . atomically $ modifyTVar' sessionJar (HM.delete sId)

-- | Looks up the current user's session using the ID in the request's
-- @sess_id@ cookie.
--
-- Returns @Left SessionNotFound@ when the cookie is absent. Otherwise it
-- returns whatever 'getSession' yields for the cookie's value.
getUserSession :: (MonadIO m) => SessionJar a -> ActionT m (Either SessionStatus (Session a))
getUserSession sessionJar =
    getCookie "sess_id" >>= maybe (pure (Left SessionNotFound)) (getSession sessionJar)

-- | Like 'getSession', but yields only the session's content.
readSession :: (MonadIO m) => SessionJar a -> SessionId -> ActionT m (Either SessionStatus a)
readSession sessionJar sId = fmap sessContent <$> getSession sessionJar sId

-- | Like 'getUserSession', but yields only the content of the session
-- identified by the @sess_id@ cookie.
readUserSession :: (MonadIO m) => SessionJar a -> ActionT m (Either SessionStatus a)
readUserSession sessionJar = do
    lookupResult <- getUserSession sessionJar
    pure (either Left (Right . sessContent) lookupResult)

-- | Default session lifetime, used by 'createSession' when no explicit
-- expiration is given: 10 hours (36000 seconds).
sessionTTL :: NominalDiffTime
sessionTTL = 10 * 60 * 60

-- | Creates a session with 'createSession' and sends its ID to the client in
-- the @sess_id@ cookie, so later requests can find it via 'getUserSession'.
--
-- NOTE(review): 'setSimpleCookie' presumably sets only the cookie's name and
-- value (no Path, Expires, HttpOnly or Secure attributes). Confirm against
-- Web.Scotty.Cookie that this is intended. Without a Path, for example,
-- browsers scope the cookie to the current request's directory.
createUserSession :: (MonadIO m) =>
    SessionJar a -- ^ SessionJar, which can be created by createSessionJar
    -> Maybe Int  -- ^ Optional expiration time (in seconds)
    -> a          -- ^ Content
    -> ActionT m (Session a)
createUserSession sessionJar mbExpirationTime content = do
    newSess <- liftIO (createSession sessionJar mbExpirationTime content)
    setSimpleCookie "sess_id" (sessId newSess)
    pure newSess

-- | Creates a session holding the given content, stores it in the jar with
-- 'addSession', and returns it.
--
-- The session ID is 32 random characters in the range @'a'..'z'@. The session
-- expires after the given number of seconds, or after 'sessionTTL' when
-- 'Nothing' is passed.
--
-- NOTE(review): 'randomRIO' draws from System.Random's global generator. The
-- random package documents that generator as unsuitable for cryptographic
-- use, so these IDs may be predictable. Session IDs should come from a
-- CSPRNG. Left as-is to avoid adding a dependency.
createSession :: SessionJar a -> Maybe Int -> a -> IO (Session a)
createSession sessionJar mbExpirationTime content = do
    sId <- T.pack <$> replicateM 32 (randomRIO ('a', 'z'))
    now <- getCurrentTime
    let ttl = maybe sessionTTL fromIntegral mbExpirationTime
        sess = Session
            { sessId = sId
            , sessExpiresAt = addUTCTime ttl now
            , sessContent = content
            }
    addSession sessionJar sess
    pure sess