{-# LANGUAGE NamedFieldPuns #-}
{-# LANGUAGE RankNTypes #-}

-- | This module provides a tripwire abstraction. A tripwire is a detection mechanism for checking whether a
-- code path was executed. Trip a tripwire with 'trip' at the place you would like to detect being reached.
-- The tripwire can then be checked elsewhere in the code using, for example, 'isTripped' or
-- 'assertNotTripped'.
module Hedgehog.Extras.Test.Tripwire
  (
  -- * Create a tripwire
    Tripwire
  , makeTripwire
  , makeTripwireWithLabel
  -- * Tripwire operations
  , trip
  , trip_
  , isTripped
  , getTripSite
  , resetTripwire
  -- * Assertions
  , assertNotTripped
  , assertTripped
  ) where

import           Control.Monad.IO.Class
import           GHC.Stack

import           Control.Concurrent.MVar
import           Control.Monad
import           Data.IORef
import           Data.Maybe
import           Hedgehog (MonadTest)
import qualified Hedgehog.Extras.Test.Base as H
import qualified Hedgehog.Internal.Property as H
import           Prelude
import           System.IO.Unsafe (unsafePerformIO)

-- | Counter used to allocate consecutive IDs to tripwires. It starts at @0@.
--
-- This is a single top-level mutable cell created with 'unsafePerformIO'. The @NOINLINE@ pragma below is
-- required for that idiom: without it GHC may inline the definition at its use sites, and each inlined
-- copy could allocate its own 'IORef', breaking the "one shared counter" intent.
tripwireCounter :: IORef Int
tripwireCounter = unsafePerformIO $ newIORef 0
{-# NOINLINE tripwireCounter #-}

-- | Represents a tripwire. It can be used to detect if a particular code path was reached.
--
-- The tripped state is held in 'tripSite': an empty 'MVar' means the tripwire has not been tripped, a full
-- one holds the call stack that was stored when it was tripped. Only the first trip is recorded; later
-- trips leave the stored call stack untouched until the tripwire is reset with @resetTripwire@.
data Tripwire = Tripwire
  { tripwireId :: !String -- ^ a label for identifying the tripwire
  , tripSite :: MVar CallStack -- ^ call stack of the trip site
  }

-- | Renders a tripwire as @Tripwire \<id\>@, where @\<id\>@ is its 'tripwireId'. The trip site is not
-- included in the output.
instance Show Tripwire where
  show :: Tripwire -> String
  show Tripwire{tripwireId} = "Tripwire " <> tripwireId

-- | Creates a new tripwire.
--
-- The tripwire is labelled with the next number taken from 'tripwireCounter': the counter is incremented
-- atomically and the incremented value is used as the label. The new tripwire starts in the non-tripped
-- state (its trip site is an empty 'MVar').
makeTripwire :: MonadIO m => m Tripwire
makeTripwire = liftIO $ do
  -- Store the incremented value and return that same value as the ID.
  id' <- atomicModifyIORef' tripwireCounter $ \current ->
    let next = current + 1
     in (next, next)
  Tripwire (show id') <$> newEmptyMVar

-- | Creates a new tripwire with a label, which is visible when 'show'ed: @Tripwire mylabel@
--
-- The label is used as given; 'tripwireCounter' is not consulted or advanced. The new tripwire starts in
-- the non-tripped state.
makeTripwireWithLabel :: MonadIO m
                      => String
                      -> m Tripwire
makeTripwireWithLabel label = liftIO $ Tripwire label <$> newEmptyMVar

-- | Triggers the tripwire and registers the place of the first trigger. Idempotent with respect to the
-- stored trip site: if the tripwire is already tripped, the stored call stack is left unchanged.
--
-- Prints the information in the test log about tripping the tripwire. The note is written on every call,
-- including calls made after the tripwire has already been tripped.
trip :: HasCallStack
     => MonadIO m
     => MonadTest m
     => Tripwire
     -> m ()
trip t@Tripwire{tripSite} = withFrozenCallStack $ do
  H.note_ $ show t <> " has been tripped"
  -- 'tryPutMVar' only succeeds on an empty 'MVar', so the first recorded call stack wins; its 'Bool'
  -- result is deliberately discarded.
  void . liftIO $ tryPutMVar tripSite callStack

-- | Triggers the tripwire and registers the place of the first trigger. Idempotent: if the tripwire is
-- already tripped, the stored call stack is left unchanged. A silent variant of 'trip' which does not
-- require 'MonadTest', but also does not log the information about tripping.
trip_ :: HasCallStack
      => MonadIO m
      => Tripwire
      -> m ()
trip_ Tripwire{tripSite} = withFrozenCallStack $
  -- 'tryPutMVar' only succeeds on an empty 'MVar', so the first recorded call stack wins; its 'Bool'
  -- result is deliberately discarded.
  void . liftIO $ tryPutMVar tripSite callStack

-- | Restore the tripwire to its initial, non-triggered state.
--
-- Any stored trip site is removed and discarded. Resetting a tripwire that was not tripped does nothing.
resetTripwire :: MonadIO m
              => Tripwire
              -> m ()
resetTripwire Tripwire{tripSite} =
  -- 'tryTakeMVar' empties the 'MVar' without blocking when it is already empty.
  liftIO . void $ tryTakeMVar tripSite

-- | Return the call stack of the place where the tripwire was tripped, or 'Nothing' if it has not been
-- tripped. This is a non-blocking read that leaves the tripwire state unchanged.
getTripSite :: MonadIO m
            => Tripwire
            -> m (Maybe CallStack)
getTripSite Tripwire{tripSite} = liftIO $ tryReadMVar tripSite

-- | Check if the tripwire was tripped. Returns 'True' when a trip site is currently stored, and 'False'
-- otherwise.
isTripped :: MonadIO m
          => Tripwire
          -> m Bool
isTripped Tripwire{tripSite} = liftIO $ not <$> isEmptyMVar tripSite

-- | Fails the test if the tripwire was triggered. Prints the call stack where the tripwire was triggered.
--
-- When the tripwire has not been tripped this does nothing.
assertNotTripped :: HasCallStack
                 => MonadTest m
                 => MonadIO m
                 => Tripwire
                 -> m ()
assertNotTripped tripwire = withFrozenCallStack $ do
  mTripSite <- getTripSite tripwire
  -- The body runs only when a trip site is present.
  forM_ mTripSite $ \cs -> do
    H.note_ $ show tripwire <> " has been tripped at: " <> prettyCallStack cs
    H.failure

-- | Fails the test if the tripwire was not triggered yet, after logging a note saying so.
--
-- When the tripwire has been tripped this does nothing.
assertTripped :: HasCallStack
              => MonadTest m
              => MonadIO m
              => Tripwire
              -> m ()
assertTripped tripwire = withFrozenCallStack $ do
  mTripSite <- getTripSite tripwire
  when (isNothing mTripSite) $ do
    H.note_ $ show tripwire <> " was not tripped"
    H.failure