{-# LANGUAGE TypeApplications #-}
module ArrayFire.StatisticsSpec where
import ArrayFire hiding (not)
import Data.Complex
import Test.Hspec
-- | Example-based checks for the statistics functions exported by "ArrayFire".
--
-- Every case builds a small array, applies a single statistics function and
-- compares the result against a hand-computed expectation.
spec :: Spec
spec = describe "Statistics spec" $ do
  it "Should find the mean" $
    mean (doubles 10 [1..]) 0 `shouldBe` 5.5

  it "Should find the weighted-mean" $
    -- Weights equal the values: (1^2 + ... + 10^2) / (1 + ... + 10) = 385 / 55.
    meanWeighted (doubles 10 [1..]) (doubles 10 [1..]) 0 `shouldBe` 7.0

  it "Should find the variance" $
    -- Squared deviations of 1..8 sum to 42; the expectation is 42 / 7.
    var (doubles 8 [1..8]) False 0 `shouldBe` 6.0

  it "Should find the weighted variance" $
    -- Same data with unit weights; the expectation is 42 / 8.
    varWeighted (doubles 8 [1..]) (doubles 8 (repeat 1)) 0 `shouldBe` 5.25

  it "Should find the standard deviation" $
    stdev (doubles 10 (cycle [1,-1])) 0 `shouldBe` 1.0

  it "Should find the covariance" $
    cov (doubles 10 (repeat 1)) (doubles 10 (repeat 1)) False `shouldBe` 0.0

  it "Should find the median" $
    median (doubles 10 [1..]) 0 `shouldBe` 5.5

  it "Should find the mean of all elements across all dimensions" $
    fst (meanAll tenGrid) `shouldBe` 10

  it "Should find the weighted mean of all elements across all dimensions" $
    fst (meanAllWeighted tenGrid tenGrid) `shouldBe` 10

  it "Should find the variance of all elements across all dimensions" $
    fst (varAll tens False) `shouldBe` 0

  it "Should find the weighted variance of all elements across all dimensions" $
    fst (varAllWeighted tens tens) `shouldBe` 0

  it "Should find the stdev of all elements across all dimensions" $
    fst (stdevAll tens) `shouldBe` 0

  it "Should find the median of all elements across all dimensions" $
    fst (medianAll (doubles 10 [1..])) `shouldBe` 5.5

  it "Should find the correlation coefficient" $
    -- An ascending sequence against a descending one.
    fst (corrCoef (vector @Int 10 [1..]) (vector @Int 10 [10,9..])) `shouldBe` (-1.0)

  it "Should find the top k elements" $ do
    let (largest, positions) = topk (doubles 10 [1..]) 3 TopKDefault
    largest `shouldBe` doubles 3 [10,9,8]
    positions `shouldBe` doubles 3 [9,8,7]
  where
    -- Shorthand for a 'Double' vector of the given length. Callers may pass
    -- an infinite list such as @[1..]@ together with a finite length.
    doubles len = vector @Double len

    -- Ten-element vector in which every element is 10.
    tens = doubles 10 (repeat 10)

    -- 2x2 matrix in which every element is 10.
    tenGrid = matrix @Double (2,2) [[10,10],[10,10]]