{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE RankNTypes, LambdaCase, TupleSections #-}

{-| converting between partial functions and maps.

@(for doctest)@

>>> :set +m
>>> :set -XLambdaCase
>>> :{
let uppercasePartial :: (MonadThrow m) => Char -> m Char  -- :: Partial Char Char
    uppercasePartial = \case
     'a' -> return 'A'
     'b' -> return 'B'
     'z' -> return 'Z'
     _   -> failed "uppercasePartial"
:}

a (safely-)partial function is isomorphic with a @Map@:

@
'fromFunctionM' . 'toFunctionM' = 'id'
'toFunctionM' . 'fromFunctionM' = 'id'
@

modulo the error thrown.

-}
module Enumerate.Function.Map where

import Enumerate.Types
import Enumerate.Function.Extra
import Enumerate.Function.Types
import Enumerate.Function.Reify
import Enumerate.Function.Invert

import Control.Monad.Catch (MonadThrow(..))
-- import GHC.TypeLits (Nat, type (^))
import qualified Data.Map as Map
import Data.Map (Map)
import Control.Exception (PatternMatchFail(..))
import Data.Proxy
import Numeric.Natural
import Data.Maybe (fromJust)

{- | convert a map to a function, if the map is total.

returns 'Nothing' unless 'isMapTotal' accepts the map.

>>> let (Just not_) = toFunction (Map.fromList [(False,True),(True,False)])
>>> not_ False
True

>>> fmap ($ False) (toFunction (Map.fromList [(True,False)]))
Nothing

-}
toFunction :: (Enumerable a, Ord a) => Map a b -> Maybe (a -> b)
toFunction m
  | isMapTotal m = Just (unsafeToFunction m) -- the 'fromJust' inside is safe once the map is total
  | otherwise    = Nothing
{-# INLINABLE toFunction #-}

{- | convert a map to a (safely-)partial function.

lookup failures are 'throwM'n as a 'PatternMatchFail'.

>>> let idPartial = toFunctionM (Map.fromList [(True,True)])
>>> idPartial True
True
>>> idPartial False
*** Exception: toFunctionM

-}
toFunctionM :: (Enumerable a, Ord a) => Map a b -> (Partial a b)
toFunctionM m = f
  where
    -- a missing key becomes a thrown 'PatternMatchFail' in the caller's 'MonadThrow'
    f x = maybe (throwM (PatternMatchFail "toFunctionM")) return (Map.lookup x m)
{-# INLINABLE toFunctionM #-}

{-| wraps 'Map.lookup', unwrapping the result with 'fromJust'.

partial: applying the result to an input that is not a key of the map is an error.
-}
unsafeToFunction :: (Ord a) => Map a b -> (a -> b)
unsafeToFunction m x = fromJust (Map.lookup x m)
{-# INLINABLE unsafeToFunction #-}

{-| refines the partial function, if total.

tabulates the partial function with 'fromFunctionM', then checks totality with 'toFunction'.

>>> :{
let myNotM :: Monad m => Bool -> m Bool
    myNotM False = return True
    myNotM True  = return False
:}

>>> let (Just myNot) = isTotalM myNotM
>>> myNot False
True

-}
isTotalM :: (Enumerable a, Ord a) => (Partial a b) -> Maybe (a -> b)
isTotalM f = toFunction (fromFunctionM f)
{-# INLINABLE isTotalM #-}

--------------------------------------------------------------------------------

{-| wraps 'Map.lookup': builds a 'Map' from the association list
(via 'Map.fromList', so the last value for a repeated key wins),
then looks inputs up with 'unsafeToFunction'.

partial: errors on an input that is not a key of the list.

>>> (unsafeFromList [(False,True),(True,False)]) False
True
>>> (unsafeFromList [(False,True),(True,False)]) True
False

-}
unsafeFromList :: (Ord a) => [(a,b)] -> (a -> b)
unsafeFromList = unsafeToFunction . Map.fromList
{-# INLINABLE unsafeFromList #-}

{-| enumerates functions from @a@ to @b@: one function per mapping that
'mappingEnumeratedAt' produces over the 'enumerated' domain and codomain.

see 'mappingEnumeratedAt'
-}
functionEnumerated :: (Enumerable a, Enumerable b, Ord a, Ord b) => [a -> b]
functionEnumerated = unsafeFromList <$> mappingEnumeratedAt enumerated enumerated

-- | @|b| ^ |a|@, i.e. the number of (total) functions from @a@ to @b@.
functionCardinality :: forall a b proxy. (Enumerable a, Enumerable b) => proxy (a -> b) -> Natural
functionCardinality _ = cardinality (Proxy :: Proxy b) ^ cardinality (Proxy :: Proxy a)
{-# INLINABLE functionCardinality #-}

-- | are all pairs of outputs the same for the same input? (short-circuits).
extensionallyEqual :: (Enumerable a, Eq b) => (a -> b) -> (a -> b) -> Bool
extensionallyEqual f g = all ((==) <$> f <*> g) enumerated
{-# INLINABLE extensionallyEqual #-}

-- | is any pair of outputs different for the same input? (short-circuits).
extensionallyUnequal :: (Enumerable a, Eq b) => (a -> b) -> (a -> b) -> Bool
extensionallyUnequal f g = any ((/=) <$> f <*> g) enumerated
{-# INLINABLE extensionallyUnequal #-}

-- | show all inputs and their outputs, as @unsafeFromList [...]@.
functionShowsPrec :: (Enumerable a, Show a, Show b) => Int -> (a -> b) -> ShowS
functionShowsPrec = showsPrecWith "unsafeFromList" reifyFunction
{-# INLINABLE functionShowsPrec #-}

-- | show all inputs and their outputs, as @\\case ...@ (one @ input -> output@ line per input).
displayFunction :: (Enumerable a, Show a, Show b) => (a -> b) -> String
displayFunction
  =   reifyFunction
  >>> fmap displayCase
  >>> ("\\case":)
  >>> intercalate "\n"
  where
    displayCase (x,y) = intercalate " " ["", show x, "->", show y]

-- displayPartialFunction
--   :: (Enumerable a, Show a, Show b)
--   => (Partial a b)
--   -> String

{-| show the inverse of an injective function, as a @\\case@ whose lines read
@ output <- Just input@, followed by a final @ _ <- Nothing@ line.

'Nothing' when 'isInjective' returns 'Nothing' for the function.
-}
displayInjective :: (Enumerable a, Ord a, Ord b, Show a, Show b) => (a -> b) -> Maybe String
displayInjective f = go f <$ isInjective f -- only the success of 'isInjective' matters, not its payload
  where
    go = reifyFunction
      >>> fmap displayCase
      >>> (["\\case"]++)
      >>> (++[" _ <- Nothing"])
      >>> intercalate "\n"
    displayCase (x,y) = intercalate " " ["", show y, "<-", show (Just x)]

-- displayInjective f = go <$> isInjective f
--
--  where
--  go = reifyFunction
--   >>> fmap displayCase
--   >>> ("\\case":)
--   >>> intercalate "\n"
--  displayCase = \case
--   (y, Nothing) ->
--   (y, Just x) -> intercalate " " ["", show y, " <- ", show x]

{-| @[(a,b)]@ is a mapping, @[[(a,b)]]@ is a list of mappings.

every mapping pairs each element of the first list with one element of the second,
so there are @|b| ^ |a|@ of them (compare 'functionCardinality').

>>> let orderingPredicates = mappingEnumeratedAt [LT,EQ,GT] [False,True]
>>> print $ length orderingPredicates
8
>>> printMappings $ orderingPredicates
<BLANKLINE>
(LT,False)
(EQ,False)
(GT,False)
<BLANKLINE>
(LT,False)
(EQ,False)
(GT,True)
<BLANKLINE>
(LT,False)
(EQ,True)
(GT,False)
<BLANKLINE>
(LT,False)
(EQ,True)
(GT,True)
<BLANKLINE>
(LT,True)
(EQ,False)
(GT,False)
<BLANKLINE>
(LT,True)
(EQ,False)
(GT,True)
<BLANKLINE>
(LT,True)
(EQ,True)
(GT,False)
<BLANKLINE>
(LT,True)
(EQ,True)
(GT,True)

where the (total) mapping:

@
(LT,False)
(EQ,False)
(GT,True)
@

is equivalent to the function:

@
\\case
 LT -> False
 EQ -> False
 GT -> True
@

an empty domain has exactly one (empty) mapping, since @|b| ^ 0 = 1@;
an empty codomain with a non-empty domain has none:

>>> mappingEnumeratedAt ([] :: [Ordering]) [False,True]
[[]]
>>> mappingEnumeratedAt [LT,EQ,GT] ([] :: [Bool])
[]

-}
mappingEnumeratedAt :: [a] -> [b] -> [[(a,b)]]  -- TODO diagonalize? performance?
mappingEnumeratedAt domain codomain = go (crossProduct domain codomain)
  where
    -- picks one pair from each row of the cross product (like 'sequence' in the list monad).
    -- the empty product has exactly one element, the empty mapping; returning @[]@ here
    -- would drop the only function out of an empty domain.
    go [] = [[]]
    go (somePairs:theProduct) = do
      pair        <- somePairs
      theExponent <- go theProduct
      return (pair : theExponent)

{-| pairs every element of the first list with every element of the second,
grouped by the element of the first list.

>>> let crossOrderingBoolean = crossProduct [LT,EQ,GT] [False,True]
>>> printMappings $ crossOrderingBoolean
<BLANKLINE>
(LT,False)
(LT,True)
<BLANKLINE>
(EQ,False)
(EQ,True)
<BLANKLINE>
(GT,False)
(GT,True)

the length of the outer list is the size of the first set,
and the length of the inner list is the size of the second set.

>>> print $ length crossOrderingBoolean
3
>>> print $ length (head crossOrderingBoolean)
2

-}
crossProduct :: [a] -> [b] -> [[(a,b)]]
crossProduct theDomain theCodomain =
  [ (aValue,) <$> theCodomain | aValue <- theDomain ]