Saturday, January 26, 2008

Safety and Optimisation: joinMaybes'

In a recent blog post, Conal introduced the joinMaybes' function, defined as:


import Control.Monad (MonadPlus, liftM, mzero)
import Data.Maybe (fromJust, isJust)

filterMP' :: MonadPlus m => (a -> Bool) -> m a -> m a
filterMP' p = (>>= f)
  where
    f a | p a       = return a
        | otherwise = mzero

joinMaybes' :: MonadPlus m => m (Maybe a) -> m a
joinMaybes' = liftM fromJust . filterMP' isJust


He laments that using isJust and fromJust means his code will run slower (performing two Just tests per element), and that an automated checker such as Catch won't be able to verify it. Fortunately, Catch can check the code perfectly, and Supero can optimise it perfectly, so this simple definition is fine from both points of view. I'm going to go through the checking with Catch in some detail, and if anyone wants, I'll post another article on the optimisation with Supero.
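To make the "two Just tests" concern concrete, here is a sketch placing Conal's definition next to a single-test formulation. The name joinMaybesOneTest and the use of the Prelude's maybe function are our own choices for illustration, not taken from Conal's post. Both should give the same results in any MonadPlus; load the module into GHCi to try them on lists or Maybe values.

```haskell
import Control.Monad (MonadPlus, liftM, mzero)
import Data.Maybe (fromJust, isJust)

-- Conal's definitions, as quoted above.
filterMP' :: MonadPlus m => (a -> Bool) -> m a -> m a
filterMP' p = (>>= f)
  where
    f a | p a       = return a
        | otherwise = mzero

-- Two tests per element: isJust in the filter, then fromJust afterwards.
joinMaybes' :: MonadPlus m => m (Maybe a) -> m a
joinMaybes' = liftM fromJust . filterMP' isJust

-- Hypothetical single-test variant: `maybe` inspects each element once,
-- discarding Nothing and unwrapping Just in the same step.
joinMaybesOneTest :: MonadPlus m => m (Maybe a) -> m a
joinMaybesOneTest = (>>= maybe mzero return)
```

For example, joinMaybes' [Just 1, Nothing, Just 3] and joinMaybesOneTest [Just 1, Nothing, Just 3] both give [1,3]. The second version never calls fromJust, which is the kind of hand rewrite Conal's concern might otherwise push you towards.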

Checking With Catch

To simplify things, I'm going to work only in the [] monad, so here is a new variant of the code:


import Data.Maybe (fromJust, isJust)

filterMP' :: (a -> Bool) -> [a] -> [a]
filterMP' p = concatMap f
  where
    f a | p a       = [a]
        | otherwise = []

joinMaybes' :: [Maybe a] -> [a]
joinMaybes' = map fromJust . filterMP' isJust


Catch would remove the type class dictionaries before starting, so it would accept the original code unmodified. The first thing Catch would do is reduce this fragment to first-order form. The resulting translation would be:


filterMP_isJust x = concatMap_f x
concatMap_f [] = []
concatMap_f (x:xs) = (if isJust x then [x] else []) ++ concatMap_f xs

joinMaybes x = map_fromJust (filterMP_isJust x)
map_fromJust [] = []
map_fromJust (Just x:xs) = x : map_fromJust xs
map_fromJust (Nothing:xs) = error "Pattern match error"
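As a sanity check on the translation, the sketch below puts the list-specialised definitions next to the first-order ones so they can be compared on sample inputs. This is testing on a few lists, not a proof; the type signatures on the first-order functions and the agree helper are our additions.

```haskell
import Data.Maybe (fromJust, isJust)

-- List-specialised version from the post.
filterMP' :: (a -> Bool) -> [a] -> [a]
filterMP' p = concatMap f
  where
    f a | p a       = [a]
        | otherwise = []

joinMaybes' :: [Maybe a] -> [a]
joinMaybes' = map fromJust . filterMP' isJust

-- First-order translation: no functions are passed as arguments any more.
filterMP_isJust :: [Maybe a] -> [Maybe a]
filterMP_isJust x = concatMap_f x

concatMap_f :: [Maybe a] -> [Maybe a]
concatMap_f [] = []
concatMap_f (x:xs) = (if isJust x then [x] else []) ++ concatMap_f xs

joinMaybes :: [Maybe a] -> [a]
joinMaybes x = map_fromJust (filterMP_isJust x)

map_fromJust :: [Maybe a] -> [a]
map_fromJust [] = []
map_fromJust (Just x:xs) = x : map_fromJust xs
map_fromJust (Nothing:_) = error "Pattern match error"

-- Hypothetical helper: do the two versions agree on this input?
agree :: [Maybe Int] -> Bool
agree xs = joinMaybes xs == joinMaybes' xs
```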


I've also taken the liberty of inlining the otherwise, and of using pattern matching rather than case expressions. Catch takes care of those details itself, but the code is a little easier to follow without them. Now Catch can begin the checking process.

Catch first decides that if map_fromJust is passed a list matching the pattern (Nothing:_), it will crash, and so records the precondition of map_fromJust as: the input list is either [] or (Just _:_). It then spots the recursive call within map_fromJust, which requires the same precondition to hold for the tail, and determines that the revised precondition is that the input is a list, of any length, whose elements are all Just constructed (we call this condition P).
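Condition P can be written as an ordinary predicate. The sketch below runs the first-order map_fromJust on sample inputs and checks whether it crashes; the names conditionP and crashes are hypothetical, introduced only for illustration. Catch reasons about all inputs symbolically, whereas this only observes a handful.

```haskell
import Control.Exception (ErrorCall, evaluate, try)
import Data.Maybe (isJust)

-- The first-order map_fromJust from the translation above.
map_fromJust :: [Maybe a] -> [a]
map_fromJust [] = []
map_fromJust (Just x:xs) = x : map_fromJust xs
map_fromJust (Nothing:_) = error "Pattern match error"

-- Condition P: a list of any length whose elements are all Just.
conditionP :: [Maybe a] -> Bool
conditionP = all isJust

-- Force the whole spine of map_fromJust's result (which forces every
-- pattern match) and report whether the "Pattern match error" was raised.
crashes :: [Maybe Int] -> IO Bool
crashes xs = do
  r <- try (evaluate (length (map_fromJust xs))) :: IO (Either ErrorCall Int)
  return (either (const True) (const False) r)
```

On inputs such as [Just 1, Just 2] it returns False, and on [Just 1, Nothing] it returns True: crashes xs should equal not (conditionP xs).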

Having determined the precondition on map_fromJust, Catch uses it within joinMaybes. It transforms the condition P, trying to find the precondition on filterMP_isJust that ensures its result satisfies the postcondition P. By examining each branch, Catch determines that the postcondition holds in all circumstances: the [] equation returns [], and in the (x:xs) equation the if expression yields [x] only when isJust x holds and [] otherwise, so appending it to the recursive result (which satisfies P by induction) keeps every element Just. The precondition is therefore simply True. Since the result of filterMP_isJust always satisfies the precondition of map_fromJust, joinMaybes never crashes.
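The postcondition argument can also be spot-checked by brute force. The sketch below enumerates every list of up to n elements drawn from Nothing, Just False and Just True, and checks that filterMP_isJust always returns a list satisfying P. The helpers smallLists and postconditionHolds are hypothetical; unlike Catch's result, this covers only the lists it enumerates.

```haskell
import Control.Monad (replicateM)
import Data.Maybe (isJust)

-- First-order filter from the translation above.
filterMP_isJust :: [Maybe a] -> [Maybe a]
filterMP_isJust x = concatMap_f x

concatMap_f :: [Maybe a] -> [Maybe a]
concatMap_f [] = []
concatMap_f (x:xs) = (if isJust x then [x] else []) ++ concatMap_f xs

-- Condition P: every element is Just.
conditionP :: [Maybe a] -> Bool
conditionP = all isJust

-- Every list of length 0..n whose elements come from vals.
smallLists :: Int -> [b] -> [[b]]
smallLists n vals = concatMap (\k -> replicateM k vals) [0 .. n]

-- True if filterMP_isJust establishes P on every enumerated input.
postconditionHolds :: Int -> Bool
postconditionHolds n =
  all (conditionP . filterMP_isJust) (smallLists n [Nothing, Just False, Just True])
```

With three possible elements there are 1 + 3 + 9 = 13 lists of length at most 2, so the enumeration grows quickly; a small n is enough to exercise both branches of the if.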

Catch generates this proof automatically, showing that the function is safe.
