{-# LANGUAGE TypeApplications #-}
module ArrayFire.StatisticsSpec where

import ArrayFire    hiding (not)

import Data.Complex
import Test.Hspec

-- | Example-based checks for the ArrayFire statistics bindings.
--
-- Reductions that take a dimension argument are applied along dimension 0
-- and compared against a numeric literal; the @*All@ reductions return a
-- pair, and only its first component is compared.
spec :: Spec
spec =
  describe "Statistics spec" $ do
    -- Reductions along dimension 0.
    it "Should find the mean" $
      mean (ascending 10) 0 `shouldBe` 5.5
    it "Should find the weighted-mean" $
      -- values double as weights: sum of squares / sum = 385 / 55
      meanWeighted (ascending 10) (ascending 10) 0 `shouldBe` 7.0
    it "Should find the variance" $
      -- squared deviations of 1..8 sum to 42; expected value is 42 / 7
      var (ascending 8) False 0 `shouldBe` 6.0
    it "Should find the weighted variance" $
      -- unit weights; expected value is 42 / 8
      varWeighted (ascending 8) (filled 8 1) 0 `shouldBe` 5.25
    it "Should find the standard deviation" $
      stdev alternating 0 `shouldBe` 1.0
    it "Should find the covariance" $
      cov (filled 10 1) (filled 10 1) False `shouldBe` 0.0
    it "Should find the median" $
      median (ascending 10) 0 `shouldBe` 5.5

    -- Whole-array reductions returning a pair.
    it "Should find the mean of all elements across all dimensions" $
      fst (meanAll tens) `shouldBe` 10
    it "Should find the weighted mean of all elements across all dimensions" $
      fst (meanAllWeighted tens tens) `shouldBe` 10
    it "Should find the variance of all elements across all dimensions" $
      fst (varAll (filled 10 10) False) `shouldBe` 0
    it "Should find the weighted variance of all elements across all dimensions" $
      fst (varAllWeighted (filled 10 10) (filled 10 10)) `shouldBe` 0
    it "Should find the stdev of all elements across all dimensions" $
      fst (stdevAll (filled 10 10)) `shouldBe` 0
    it "Should find the median of all elements across all dimensions" $
      fst (medianAll (ascending 10)) `shouldBe` 5.5
    it "Should find the correlation coefficient" $
      -- a strictly increasing series against its mirror image
      fst (corrCoef (vector @Int 10 [1 ..]) (vector @Int 10 [10, 9 ..]))
        `shouldBe` (-1.0)

    it "Should find the top k elements" $ do
      let (topVals, topIdx) = topk (ascending 10) 3 TopKDefault
      topVals `shouldBe` vector @Double 3 [10, 9, 8]
      topIdx `shouldBe` vector @Double 3 [9, 8, 7]
  where
    -- 1, 2, .., n as a Double vector.
    ascending n = vector @Double n [1 ..]
    -- n copies of x as a Double vector.
    filled n x = vector @Double n (repeat x)
    -- Ten elements alternating 1, -1, 1, -1, ...
    alternating = vector @Double 10 (cycle [1, -1])
    -- 2x2 Double matrix with every entry equal to 10.
    tens = matrix @Double (2, 2) [[10, 10], [10, 10]]