module TestPartitions (testPartitions) where

import Data.Maybe (fromJust)
import qualified Data.Vector as V
import GHC.Natural (wordToNatural)
import RandomCycle.List as RL
import RandomCycle.Vector as RV
import System.Random.Stateful (mkStdGen, runStateGen_)
import Test.Tasty (TestTree)
import Test.Tasty.QuickCheck

{- Top-level -}

-- | Every partition property from 'prop_partitions', grouped under a single
-- Tasty test node.
testPartitions :: TestTree
testPartitions = testProperties groupName prop_partitions
  where
    groupName = "Partitions"

{- Partition Properties -}

-- | Named partition properties. Plain partition checks come first, followed by
-- the properties that exercise thinning (rule-filtered partitions).
prop_partitions :: [(String, Property)]
prop_partitions = plain ++ thinning
  where
    plain =
      [ ("isPartitionV", property prop_isPartitionV)
      , ("isPartitionL", property prop_isPartitionL)
      , ("breaksCorrectL", property prop_breaksCorrectL)
      , ("breaksCorrectV", property prop_breaksCorrectV)
      ]
    thinning =
      [ ("isPartitionThinL", property prop_isPartitionThinL)
      ]

-- Vectors

-- | Concatenating the parts returned by 'RV.uniformPartition' gives back the
-- original input, in the original order.
prop_isPartitionV :: [Int] -> Property
prop_isPartitionV xs = flattened === xs
  where
    parts = run (RV.uniformPartition (V.fromList xs))
    flattened = concatMap V.toList parts

-- | Are the breaks in the right places? The chunk lengths produced by
-- 'RV.partitionFromBits' for the vector @[1 .. n]@ must agree with the
-- list-based 'RL.partitionLengths' for the same bits.
prop_breaksCorrectV :: Word -> NBits -> Property
prop_breaksCorrectV bs (NBits n) = expected === actual
  where
    chunks = RV.partitionFromBits (wordToNatural bs) (V.fromList [1 .. n])
    actual = length <$> chunks
    expected = RL.partitionLengths bs n

-- Lists

-- | Is it a partition? Concatenating the parts returned by
-- 'RL.uniformPartition' gives back the original list, in order.
prop_isPartitionL :: [Int] -> Property
prop_isPartitionL xs = flattened === xs
  where
    parts = run (RL.uniformPartition xs)
    flattened = concat parts

-- | Are the breaks in the right places? The chunk lengths produced by
-- 'RL.partitionFromBits' for @[1 .. n]@ must agree with
-- 'RL.partitionLengths' for the same bits.
prop_breaksCorrectL :: Word -> NBits -> Property
prop_breaksCorrectL bs (NBits n) = expected === actual
  where
    chunks = RL.partitionFromBits (wordToNatural bs) [1 .. n]
    actual = length <$> chunks
    expected = RL.partitionLengths bs n

-- | Is it a partition, and does every part follow the rule?
--
-- The input is always @[0 .. i + 2]@. Its sum is at least @0 + 1 + 2 = 3@,
-- so keeping the whole list as one part already satisfies the rule
-- (each part sums to at least 2). A rule-satisfying partition therefore
-- always exists; 'NonNegative' only varies the list length.
prop_isPartitionThinL :: NonNegative Int -> Property
prop_isPartitionThinL (NonNegative i) =
  case run (RL.uniformPartitionThin maxit r xs) of
    -- Testing the "no match found" path is not the goal of this property.
    -- If it happens, report an ordinary QuickCheck failure rather than
    -- crashing the whole suite via 'fromJust'.
    Nothing ->
      counterexample
        "uniformPartitionThin found no partition satisfying the rule"
        False
    Just ps ->
      let propP = concat ps === xs
          -- Show the offending partition, not just "False /= True".
          propR =
            counterexample ("a part violates the rule: " ++ show ps) (all r ps)
       in propP .&&. propR
  where
    xs = [0 .. (i + 2)]
    -- Rule: a part is acceptable when its elements sum to at least 2.
    r = (>= 2) . sum

{- Utilities -}

-- | Max iterations for all thin tests: the iteration limit passed to
-- 'RL.uniformPartitionThin'.
-- NOTE(review): no type signature, so under the monomorphism restriction its
-- numeric type is fixed by that use site. Add a signature once the parameter
-- type of uniformPartitionThin is confirmed.
maxit = 1000

-- | Run a stateful-generator action purely via 'runStateGen_'. Every call
-- starts from the same fixed seed, so the randomness inside each property is
-- deterministic; only QuickCheck's generated inputs vary.
run = runStateGen_ fixedGen
  where
    fixedGen = mkStdGen 1305

-- | An 'Int' bit count restricted to the range 1..64, i.e. at most the width
-- of 'Word' on 64-bit platforms.
newtype NBits = NBits Int deriving (Show)

instance Arbitrary NBits where
  arbitrary = NBits <$> choose (1, 64)

  -- Without 'shrink', failing cases are reported at whatever width happened
  -- to be generated. Shrink toward smaller widths, but never outside the
  -- 1..64 range that 'arbitrary' guarantees.
  shrink (NBits n) = [NBits m | m <- shrink n, m >= 1, m <= 64]