Skip to content

Commit 8312227

Browse files
committed
Generate matrices that can be added and multiplied
1 parent f89c98c commit 8312227

File tree

1 file changed

+31
-7
lines changed

1 file changed

+31
-7
lines changed

src/SensStaticHMatrix.hs

Lines changed: 31 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818

1919
module SensStaticHMatrix where
2020

21-
import GHC.TypeLits (Nat)
21+
import GHC.TypeLits (Nat, SomeNat (SomeNat))
2222
import GHC.TypeLits qualified as TL
2323
import GHC.TypeNats (KnownNat, SNat)
2424
import Numeric.LinearAlgebra.Static
@@ -45,6 +45,11 @@ import Test.QuickCheck.Test (test)
4545
newtype SensStaticHMatrix (x :: Nat) (y :: Nat) (m :: CMetric) (n :: NMetric) (s :: SEnv) =
4646
SensStaticHMatrixUNSAFE {unSensStaticHMatrix :: L x y}
4747

48+
instance (KnownNat x, KnownNat y) => Eq (SensStaticHMatrix x y m n s) where
49+
-- What on earth is this supposed to be? L doesn't even have an Eq
50+
-- instance!
51+
(==) = error "Eq for SensStaticHMatrix unimplemented"
52+
4853

4954
instance (forall senv. KnownNat x, KnownNat y) => Arbitrary (SensStaticHMatrix x y cmetric nmetric s1) where
5055
arbitrary = do
@@ -210,12 +215,31 @@ exampleTwo = do
210215
elems2 <- replicateM ( fromInteger x' * fromInteger y') (arbitrary @Double)
211216
pure (SomeMatrix @x @y $ SensStaticHMatrixUNSAFE $ matrix elems1, SomeMatrix @x @y $ SensStaticHMatrixUNSAFE $ matrix elems2)
212217

218+
exampleThree ::
219+
forall x y c n s1.
220+
(KnownNat x, KnownNat y) =>
221+
Gen (SensStaticHMatrix x y c n s1)
222+
exampleThree = do
223+
elems1 <- replicateM (fromInteger (reflect (Proxy @x))) (arbitrary @Double)
224+
pure (SensStaticHMatrixUNSAFE $ matrix elems1)
213225

214-
test = do
215-
(SomeMatrix m1, SomeMatrix m2) <- generate $ exampleTwo @L2 @Diff
226+
arbitraryKnownNat :: Gen SomeNat
227+
arbitraryKnownNat = do
228+
x' <- arbitrary
229+
reifyNat x' $ \(Proxy @x) ->
230+
pure (SomeNat @x Proxy)
231+
232+
test = generate $ do
233+
SomeNat @x _ <- arbitraryKnownNat
234+
SomeNat @y _ <- arbitraryKnownNat
235+
m1 <- exampleThree @x @y @L2 @Diff
236+
m2 <- exampleThree @x @y @L2 @Diff
216237
pure $ (plus m1 m2) == (plus m1 m2)
217-
-- ^^^
218-
-- Couldn't match type ‘y1’ with ‘y’
219-
-- Expected: SensStaticHMatrix x y L2 Diff s20
220-
-- Actual: SensStaticHMatrix x1 y1 L2 Diff s20
221238

239+
test2 = generate $ do
240+
SomeNat @x _ <- arbitraryKnownNat
241+
SomeNat @y _ <- arbitraryKnownNat
242+
SomeNat @z _ <- arbitraryKnownNat
243+
m1 <- exampleThree @x @y @L2 @Diff
244+
m2 <- exampleThree @y @z @L2 @Diff
245+
pure $ (mult m1 m2) == (mult m1 m2)

0 commit comments

Comments
 (0)