diff --git a/src/Data/Array/Accelerate.hs b/src/Data/Array/Accelerate.hs index 929f2fe21..7936caa68 100644 --- a/src/Data/Array/Accelerate.hs +++ b/src/Data/Array/Accelerate.hs @@ -1,3 +1,5 @@ +{-# LANGUAGE CPP #-} +{-# LANGUAGE ExplicitNamespaces #-} {-# LANGUAGE PatternSynonyms #-} {-# LANGUAGE ScopedTypeVariables #-} {-# LANGUAGE TypeApplications #-} @@ -411,7 +413,10 @@ module Data.Array.Accelerate ( -- --------------------------------------------------------------------------- -- * Useful re-exports - (.), ($), (&), flip, error, undefined, const, otherwise, + (.), ($), (&), flip, error, undefined, const, id, otherwise, +#if __GLASGOW_HASKELL__ >= 904 + type (~), +#endif Show, Generic, HasCallStack, fromString, -- -XOverloadedStrings fromListN, -- -XOverloadedLists @@ -463,7 +468,10 @@ import qualified Data.Array.Accelerate.Sugar.Array as S import qualified Data.Array.Accelerate.Sugar.Shape as S import Data.Function ( (&) ) -import Prelude ( (.), ($), Char, Show, flip, undefined, error, const, otherwise ) +#if __GLASGOW_HASKELL__ >= 904 +import Data.Type.Equality ( type (~) ) +#endif +import Prelude ( (.), ($), Char, Show, flip, undefined, error, const, id, otherwise ) import GHC.Exts ( fromListN, fromString ) import GHC.Generics ( Generic ) diff --git a/src/Data/Array/Accelerate/Classes/Eq.hs b/src/Data/Array/Accelerate/Classes/Eq.hs index 6985facdd..ca5ece850 100644 --- a/src/Data/Array/Accelerate/Classes/Eq.hs +++ b/src/Data/Array/Accelerate/Classes/Eq.hs @@ -119,6 +119,15 @@ instance Eq Z where _ == _ = True_ _ /= _ = False_ +instance (Shape sh) => Eq (Any sh) where + _ == _ = True_ + _ /= _ = False_ + +instance Eq All where + _ == _ = True_ + _ /= _ = False_ + + -- Instances of 'Prelude.Eq' don't make sense with the standard signatures as -- the return type is fixed to 'Bool'. This instance is provided to provide -- a useful error message. @@ -203,7 +212,7 @@ runQ $ do ts <- mapM mkTup [2..16] return $ concat (concat [is,fs,ns,cs,ts]) -instance Eq sh => Eq (sh :. Int) where +instance (Eq sh, Eq i) => Eq (sh :. i) where x == y = indexHead x == indexHead y && indexTail x == indexTail y x /= y = indexHead x /= indexHead y || indexTail x /= indexTail y diff --git a/src/Data/Array/Accelerate/Control/Monad.hs b/src/Data/Array/Accelerate/Control/Monad.hs index 89a72ca9f..a461bf45d 100644 --- a/src/Data/Array/Accelerate/Control/Monad.hs +++ b/src/Data/Array/Accelerate/Control/Monad.hs @@ -25,6 +25,7 @@ module Data.Array.Accelerate.Control.Monad ( -- ** Basic functions (=<<), (>>), (>=>), (<=<), + join, -- ** Conditional execution of monadic expressions when, unless, @@ -40,7 +41,7 @@ import Data.Array.Accelerate.Language import Data.Array.Accelerate.Sugar.Elt import Data.Array.Accelerate.Smart -import Prelude ( Bool, flip ) +import Prelude ( Bool, flip, id ) -- | The 'Monad' class is used for scalar types which can be sequenced. @@ -134,6 +135,19 @@ infixr 1 <=< -> (Exp a -> Exp (m c)) (<=<) = flip (>=>) +-- | The 'join' function is the conventional monad join operator. It +-- is used to remove one level of monadic structure, projecting its +-- bound argument into the outer level. +-- +-- \'@'join' bss@\' can be understood as the @do@ expression +-- +-- @ +-- do bs <- bss +-- bs +-- @ +-- +join :: (Monad m, Elt a, Elt (m a), Elt (m (m a))) => Exp (m (m a)) -> Exp (m a) +join = (>>= id) -- | Conditional execution of a monadic expression -- diff --git a/src/Data/Array/Accelerate/Data/Maybe.hs b/src/Data/Array/Accelerate/Data/Maybe.hs index 305da9a71..60bcea2df 100644 --- a/src/Data/Array/Accelerate/Data/Maybe.hs +++ b/src/Data/Array/Accelerate/Data/Maybe.hs @@ -130,7 +130,7 @@ instance Ord a => Ord (Maybe a) where go Nothing_ Just_{} = LT_ go Just_{} Nothing_{} = GT_ -instance (Monoid (Exp a), Elt a) => Monoid (Exp (Maybe a)) where +instance (Semigroup (Exp a), Elt a) => Monoid (Exp (Maybe a)) where mempty = Nothing_ instance (Semigroup (Exp a), Elt a) => Semigroup (Exp (Maybe a)) where