{-# LANGUAGE NamedFieldPuns #-}
{-# LANGUAGE RankNTypes #-}
module Hedgehog.Extras.Test.Tripwire
(
Tripwire
, makeTripwire
, makeTripwireWithLabel
, trip
, trip_
, isTripped
, getTripSite
, resetTripwire
, assertNotTripped
, assertTripped
) where
import Control.Monad.IO.Class
import GHC.Stack
import Control.Concurrent.MVar
import Control.Monad
import Data.IORef
import Data.Maybe
import Hedgehog (MonadTest)
import qualified Hedgehog.Extras.Test.Base as H
import qualified Hedgehog.Internal.Property as H
import Prelude
import System.IO.Unsafe (unsafePerformIO)
-- | Process-wide source of tripwire identifiers, starting at 0.
--
-- A top-level mutable reference built with 'unsafePerformIO' is only sound
-- because of the @NOINLINE@ pragma below: without it GHC could inline the
-- definition and allocate a separate 'IORef' at every use site, breaking the
-- "one shared counter" assumption. Within this module the reference is only
-- updated with 'atomicModifyIORef'' (in 'makeTripwire'), so concurrent
-- allocations do not lose increments.
tripwireCounter :: IORef Int
tripwireCounter = unsafePerformIO (newIORef 0)
{-# NOINLINE tripwireCounter #-}
-- | A tripwire used to detect whether a particular code path was reached.
-- Only the first trip records a call stack; later trips leave it unchanged
-- until the tripwire is reset.
data Tripwire = Tripwire
  { tripwireId :: !String
    -- ^ Label identifying the tripwire; this is what the 'Show' instance
    -- displays.
  , tripSite :: !(MVar CallStack)
    -- ^ Empty while untripped; holds the call stack of the first trip once
    -- tripped. Strict, like 'tripwireId', so a 'Tripwire' value never carries
    -- an unevaluated thunk in place of its 'MVar'.
  }

-- | Renders as @Tripwire \<label\>@, e.g. @Tripwire 3@ for a tripwire created
-- by 'makeTripwire'.
instance Show Tripwire where
  show (Tripwire label _) = "Tripwire " <> label
-- | Creates a fresh, untripped tripwire. Its label is the next value of the
-- shared 'tripwireCounter', so the first tripwire created is labelled @"1"@.
makeTripwire :: MonadIO m => m Tripwire
makeTripwire = liftIO $ do
  -- Atomically bump the shared counter and return the bumped value.
  newId <- atomicModifyIORef' tripwireCounter $ \current ->
    let bumped = current + 1 in (bumped, bumped)
  site <- newEmptyMVar
  pure (Tripwire (show newId) site)
-- | Creates a fresh, untripped tripwire identified by the given label, which
-- is what 'show' displays (e.g. @Tripwire myLabel@). The shared counter used
-- by 'makeTripwire' is not touched, and labels are not checked for
-- uniqueness.
makeTripwireWithLabel :: MonadIO m
                      => String
                      -> m Tripwire
makeTripwireWithLabel label =
  liftIO (fmap (Tripwire label) newEmptyMVar)
-- | Trips the tripwire. On the first trip, records the caller's 'CallStack'.
-- Later trips keep the originally recorded site. Every call, first or not,
-- adds a @"\<tripwire\> has been tripped"@ note to the test log.
trip :: HasCallStack
     => MonadIO m
     => MonadTest m
     => Tripwire
     -> m ()
trip t@Tripwire{tripSite} = withFrozenCallStack $ do
  let message = show t <> " has been tripped"
  H.note_ message
  -- 'tryPutMVar' returns False when a site is already recorded; that result
  -- is deliberately ignored so only the first trip site is kept.
  _ <- liftIO (tryPutMVar tripSite callStack)
  pure ()
-- | Silent variant of 'trip'. It records the caller's 'CallStack' on the first
-- trip but writes no note, which is why it needs no 'MonadTest' constraint.
-- Later trips have no effect.
trip_ :: HasCallStack
      => MonadIO m
      => Tripwire
      -> m ()
trip_ Tripwire{tripSite} =
  withFrozenCallStack (void (liftIO (tryPutMVar tripSite callStack)))
-- | Returns the tripwire to the untripped state by discarding any recorded
-- trip site. Resetting a tripwire that is not tripped does nothing and does
-- not block.
resetTripwire :: MonadIO m
              => Tripwire
              -> m ()
resetTripwire Tripwire{tripSite = site} = liftIO $ do
  _ <- tryTakeMVar site
  pure ()
-- | Returns the 'CallStack' recorded by the first trip, or 'Nothing' if the
-- tripwire is not currently tripped. Does not block and leaves the tripwire
-- unchanged.
getTripSite :: MonadIO m
            => Tripwire
            -> m (Maybe CallStack)
getTripSite Tripwire{tripSite = site} = liftIO (tryReadMVar site)
-- | Reports whether the tripwire currently holds a trip site, i.e. whether it
-- has been tripped since it was created or last reset. Does not block.
isTripped :: MonadIO m
          => Tripwire
          -> m Bool
isTripped Tripwire{tripSite = site} = liftIO $ do
  untripped <- isEmptyMVar site
  pure (not untripped)
-- | Fails the test if the tripwire has been tripped. Before failing, it adds a
-- note containing the pretty-printed call stack of the trip site. Does
-- nothing when the tripwire is untripped.
assertNotTripped :: HasCallStack
                 => MonadTest m
                 => MonadIO m
                 => Tripwire
                 -> m ()
assertNotTripped tripwire = withFrozenCallStack $ do
  site <- getTripSite tripwire
  case site of
    Nothing -> pure ()
    Just cs -> do
      H.note_ $ show tripwire <> " has been tripped at: " <> prettyCallStack cs
      H.failure
-- | Fails the test, after adding a @"\<tripwire\> was not tripped"@ note, when
-- no trip site is recorded. Does nothing when the tripwire has been tripped.
assertTripped :: HasCallStack
              => MonadTest m
              => MonadIO m
              => Tripwire
              -> m ()
assertTripped tripwire = withFrozenCallStack $ do
  site <- getTripSite tripwire
  case site of
    Just _ -> pure ()
    Nothing -> do
      H.note_ $ show tripwire <> " was not tripped"
      H.failure