
Using Monads for State

March 30, 2010

One of the things monads are very useful for is state. Often a computation needs some intermediate structure that is not part of the result but is required to produce the final output. I was lucky to find Wadler’s paper, which explains monads by example and inspired me to create an example of my own. Conveniently, I remembered a function I wrote a while back for Project Euler, where every number that is not the sum of two abundant numbers is added up. The function takes a number to check, together with the running result and the list of abundant numbers found so far.

-- Guard results follow the description in the text; extending the list with (:) is an assumption.
accumIfNotSumOfTwoAbundants (result,abundants) num
  | isAbundant num       && isSumOfTwoAbundants abundants num       = (result,       num:abundants)
  | not (isAbundant num) && isSumOfTwoAbundants abundants num       = (result,       abundants)
  | isAbundant num       && not (isSumOfTwoAbundants abundants num) = (result + num, num:abundants)
  | otherwise                                                       = (result + num, abundants)
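The guards rely on two helpers, isAbundant and isSumOfTwoAbundants, which this post does not show. Here is a minimal sketch of what they might look like. It assumes the usual definition of an abundant number (its proper divisors sum to more than the number itself) and assumes that a pair may use the same abundant twice, so that 24 = 12 + 12 counts. The name divisorSum is introduced here purely for illustration.

```haskell
-- Sum of the proper divisors of n (every divisor except n itself).
-- Plain trial division: fine for illustration, slow for large n.
divisorSum :: Integer -> Integer
divisorSum n = sum [d | d <- [1 .. n `div` 2], n `mod` d == 0]

-- A number is abundant when its proper divisors sum to more than itself.
isAbundant :: Integer -> Bool
isAbundant n = divisorSum n > n

-- True when num = a + b for some a and b in the known abundants.
-- Assumption: a and b may be the same element (24 = 12 + 12).
isSumOfTwoAbundants :: [Integer] -> Integer -> Bool
isSumOfTwoAbundants abundants num =
  any (\a -> (num - a) `elem` abundants) abundants
```

With these definitions, isAbundant 12 is True (1 + 2 + 3 + 4 + 6 = 16), while 13 and 14 are not abundant.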

The function has a "side effect": besides checking whether the number is the sum of two previously known abundants, it checks whether the number itself is abundant. If it is, the number must be added to the list of abundants for the next round of calculation. Along with the list, the function returns the result, which is the accumulated result so far plus the number if the number is not the sum of two abundants, and the accumulated result alone otherwise. So how do we go about finding the final result? To gain an appreciation of monads, let's not use a fold but write the successive calls out by hand, which looks something like

main = let (a,x) = accumIfNotSumOfTwoAbundants (0,[]) 12
           (b,y) = accumIfNotSumOfTwoAbundants (a,x)  13
           (c,z) = accumIfNotSumOfTwoAbundants (b,y)  14
        in putStrLn (show (c,z))

where c is the result and z is the intermediate list. Now the question is: can we abstract the list into a type, since it is essentially global state? The answer is yes, we can! The original function needs to be modified slightly so that the list "state" is the last parameter, rather than being passed in a tuple along with the result. Let's define a new type now

type StateTrans s a = s -> (a,s)

As you may have guessed, s is a generic state and a is the result. Now it is time to define the two famous functions, composition (bind) and return.

(>>>=) :: StateTrans s a -> (a -> StateTrans s b) -> StateTrans s b
p >>>= k = \s0 -> let (a,s1) = p s0
                      (b,s2) = k a s1
                   in (b,s2)

unit :: a -> StateTrans s a
unit a = \s -> (a,s)
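To see the state threading in isolation, here is a small sketch with a hypothetical action tick (not part of the original code) that returns the current counter and increments it. The definitions of StateTrans, >>>= and unit are repeated so that the block stands on its own.

```haskell
type StateTrans s a = s -> (a, s)

(>>>=) :: StateTrans s a -> (a -> StateTrans s b) -> StateTrans s b
p >>>= k = \s0 ->
  let (a, s1) = p s0
  in  k a s1

unit :: a -> StateTrans s a
unit a = \s -> (a, s)

-- Hypothetical action: return the current counter, then increment the state.
tick :: StateTrans Int Int
tick = \n -> (n, n + 1)

-- Run tick twice and add the two values that were read.
twoTicks :: StateTrans Int Int
twoTicks = tick >>>= \a ->
           tick >>>= \b ->
           unit (a + b)
```

Applying twoTicks to 5 gives (11,7): the first tick reads 5, the second reads 6, and the counter ends at 7 without ever being mentioned in twoTicks itself.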

Notice that I have defined composition as >>>= rather than >>=, and return as unit, to avoid clashing with the Prelude's definitions. Now comes the main driver that pipes everything together.

main = let eval   = accumIfNotSumOfTwoAbundants 0 12 >>>=
             \a  -> accumIfNotSumOfTwoAbundants a 13 >>>=
             \b  -> accumIfNotSumOfTwoAbundants b 14 >>>=
             \c  -> unit c
        in putStrLn . show $ eval []
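The driver relies on the reshaped accumIfNotSumOfTwoAbundants, which is described above but not shown. Here is a sketch of one way to write it, with the running result and the number first and the list of abundants last. The helper bodies and the use of (:) to extend the list are assumptions made for this example, not the original Project Euler code.

```haskell
type StateTrans s a = s -> (a, s)

(>>>=) :: StateTrans s a -> (a -> StateTrans s b) -> StateTrans s b
p >>>= k = \s0 ->
  let (a, s1) = p s0
  in  k a s1

unit :: a -> StateTrans s a
unit a = \s -> (a, s)

-- Assumed helper: proper divisors sum to more than n.
isAbundant :: Integer -> Bool
isAbundant n = sum [d | d <- [1 .. n `div` 2], n `mod` d == 0] > n

-- Assumed helper: num is a sum of two (possibly equal) known abundants.
isSumOfTwoAbundants :: [Integer] -> Integer -> Bool
isSumOfTwoAbundants abundants num =
  any (\a -> (num - a) `elem` abundants) abundants

-- State-last version: result so far, then the number, then the list.
accumIfNotSumOfTwoAbundants :: Integer -> Integer -> StateTrans [Integer] Integer
accumIfNotSumOfTwoAbundants result num abundants = (result', abundants')
  where
    result'    | isSumOfTwoAbundants abundants num = result
               | otherwise                         = result + num
    abundants' | isAbundant num = num : abundants
               | otherwise      = abundants

-- The same chain as in main, given a name and a type.
eval :: StateTrans [Integer] Integer
eval = accumIfNotSumOfTwoAbundants 0 12 >>>= \a ->
       accumIfNotSumOfTwoAbundants a 13 >>>= \b ->
       accumIfNotSumOfTwoAbundants b 14 >>>= \c ->
       unit c
```

Under these assumptions, eval [] evaluates to (39,[12]): none of 12, 13 and 14 is a sum of two earlier abundants, so all three are added, and only 12 is abundant.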

As you can see, the intermediate list no longer needs to be passed around at every step, only the result. The initial empty list is passed once, to eval. But eval was defined without any parameters, so how can it accept one? Thanks to partial application, accumIfNotSumOfTwoAbundants 0 12 is itself a function of type [Integer] -> (Integer, [Integer]). That is exactly the type we defined earlier, and it is also the type of the first argument of >>>=, so p in this case is the expression accumIfNotSumOfTwoAbundants 0 12.

Since >>>= returns a lambda, s0 is not needed while the chain is being built; it is supplied only when eval is finally applied to []. The function k is the anonymous function that begins \a -> accumIfNotSumOfTwoAbundants a 13, and its parameter a is the result of the first step. s1 is the intermediate state produced by the first step, so k a s1 runs the rest of the chain, starting with accumIfNotSumOfTwoAbundants a 13, on that state. Any number of functions can be strung together this way, much like Unix pipes.

Now for the function unit. It is not strictly necessary in the example above: writing main without the final unit call yields the same result. It is nevertheless good practice, since it states exactly which value is being returned. Note that unit may also return an intermediate value, such as

main = let eval   = accumIfNotSumOfTwoAbundants 0 12 >>>=
             \a  -> accumIfNotSumOfTwoAbundants a 13 >>>=
             \b  -> accumIfNotSumOfTwoAbundants b 14 >>>=
             \c  -> unit a
        in putStrLn . show $ eval []

The main thing to realize is that unit takes a plain value and lifts it into the monad, which in this case is a function from a state to a (result,state) pair, leaving the state untouched. From here on, an actual instance of Monad can be used, so that main may be rewritten in do notation.
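Both claims about unit can be checked on a smaller example. The sketch below uses a hypothetical action step (not from the original code) with the same shape as the reshaped function: it adds a number to the running total and logs that number in the state.

```haskell
type StateTrans s a = s -> (a, s)

(>>>=) :: StateTrans s a -> (a -> StateTrans s b) -> StateTrans s b
p >>>= k = \s0 ->
  let (a, s1) = p s0
  in  k a s1

unit :: a -> StateTrans s a
unit a = \s -> (a, s)

-- Hypothetical action: add n to the running total and log n in the state.
step :: Integer -> Integer -> StateTrans [Integer] Integer
step total n = \seen -> (total + n, n : seen)

-- Chain ending in unit c.
withUnit :: StateTrans [Integer] Integer
withUnit = step 0 12 >>>= \a ->
           step a 13 >>>= \b ->
           step b 14 >>>= \c ->
           unit c

-- Same chain with the final unit dropped: the last step's result is returned.
withoutUnit :: StateTrans [Integer] Integer
withoutUnit = step 0 12 >>>= \a ->
              step a 13 >>>= \b ->
              step b 14

-- Returning the first intermediate value; the state still reflects all steps.
earlier :: StateTrans [Integer] Integer
earlier = step 0 12 >>>= \a ->
          step a 13 >>>= \b ->
          step b 14 >>>= \_ ->
          unit a
```

Here withUnit [] and withoutUnit [] both give (39,[14,13,12]), while earlier [] gives (12,[14,13,12]): unit changes which value is returned but never touches the state.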

In order to create an instance of Monad, we have to use a newtype; a type synonym will no longer do.

newtype StateTrans s a = ST( s -> (a,s) )

The new monad instance looks like this

instance Monad (StateTrans s) where
    (ST p) >>= k = ST( \s0 -> let (a,s1) = p s0
                                  (ST q) = k a
                              in q s1 )
    return a = ST( \s -> (a,s) )
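Since GHC 7.10, Applicative is a superclass of Monad, so a current compiler rejects a Monad instance unless Functor and Applicative instances exist as well. A minimal sketch that derives both from the monad operations, using liftM and ap from Control.Monad, could look like this. The action demo is a hypothetical example added only to exercise the instances.

```haskell
import Control.Monad (ap, liftM)

newtype StateTrans s a = ST (s -> (a, s))

-- fmap expressed through the Monad instance.
instance Functor (StateTrans s) where
  fmap = liftM

-- pure plays the role of unit; <*> is expressed through >>=.
instance Applicative (StateTrans s) where
  pure a = ST (\s -> (a, s))
  (<*>)  = ap

instance Monad (StateTrans s) where
  ST p >>= k = ST $ \s0 ->
    let (a, s1) = p s0
        ST q    = k a
    in  q s1

-- Hypothetical action: read a counter twice, incrementing it each time.
demo :: StateTrans Int Int
demo = do
  x <- ST (\n -> (n, n + 1))
  y <- ST (\n -> (n, n + 1))
  return (x + y)
```

With these instances in place, return needs no definition of its own, because it defaults to pure.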

The ST constructor essentially acts as a lifting function that wraps a state-transforming function in the StateTrans type. One more function is needed, applyST, which unwraps a StateTrans and applies it to an initial state, giving back a plain (result,state) pair that other functions can use.

applyST :: StateTrans s a -> s -> (a,s)
applyST (ST p) s = p s            
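A common alternative, shown here only as a sketch, is to let record syntax generate the unwrapping function, so that applyST does not have to be written by hand. The action tick is a hypothetical example used to exercise the accessor.

```haskell
-- The field name doubles as the accessor:
-- applyST :: StateTrans s a -> s -> (a, s)
newtype StateTrans s a = ST { applyST :: s -> (a, s) }

-- Hypothetical action: return the current counter, then increment the state.
tick :: StateTrans Int Int
tick = ST (\n -> (n, n + 1))
```

With this declaration, applyST tick 5 evaluates to (5,6), exactly as with the hand-written version.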

Now we are able to write the main portion of our code in an imperative style that resembles C.

main = let eval = do 
                a <- ST (accumIfNotSumOfTwoAbundants 0 12);
                b <- ST (accumIfNotSumOfTwoAbundants a 13);
                c <- ST (accumIfNotSumOfTwoAbundants b 14);
                return c
       in putStrLn . show $ applyST eval []
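Once StateTrans is a Monad, the three hand-written steps can also be replaced by foldM from Control.Monad. It threads the running result through any list of numbers while the monad threads the list of abundants. The sketch below puts all the pieces together. The helper bodies, the Functor and Applicative instances (required by GHC 7.10 and later) and the name sumOver are assumptions added for this example.

```haskell
import Control.Monad (ap, foldM, liftM)

newtype StateTrans s a = ST (s -> (a, s))

instance Functor (StateTrans s) where
  fmap = liftM

instance Applicative (StateTrans s) where
  pure a = ST (\s -> (a, s))
  (<*>)  = ap

instance Monad (StateTrans s) where
  ST p >>= k = ST $ \s0 ->
    let (a, s1) = p s0
        ST q    = k a
    in  q s1

applyST :: StateTrans s a -> s -> (a, s)
applyST (ST p) = p

-- Assumed helper: proper divisors sum to more than n.
isAbundant :: Integer -> Bool
isAbundant n = sum [d | d <- [1 .. n `div` 2], n `mod` d == 0] > n

-- Assumed helper: num is a sum of two (possibly equal) known abundants.
isSumOfTwoAbundants :: [Integer] -> Integer -> Bool
isSumOfTwoAbundants abundants num =
  any (\a -> (num - a) `elem` abundants) abundants

-- State-last version of the function from the post (bodies assumed).
accumIfNotSumOfTwoAbundants :: Integer -> Integer -> [Integer] -> (Integer, [Integer])
accumIfNotSumOfTwoAbundants result num abundants = (result', abundants')
  where
    result'    | isSumOfTwoAbundants abundants num = result
               | otherwise                         = result + num
    abundants' | isAbundant num = num : abundants
               | otherwise      = abundants

-- Same computation as the do block, for an arbitrary list of numbers.
sumOver :: [Integer] -> StateTrans [Integer] Integer
sumOver = foldM (\acc n -> ST (accumIfNotSumOfTwoAbundants acc n)) 0
```

Under these assumptions, applyST (sumOver [12,13,14]) [] gives (39,[12]), the same pair the do block produces.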

Reference: Haskell and Monads