Friday, April 4, 2014

Memoization

Caching

When programming, we sometimes reach a point where a program runs slowly because the computer performs the same computation over and over. To understand when caching can make your program faster, let us first be a little more precise about what it really means and when it can be used.

First and foremost, caching means that before running a function that may take a long time, we check in a lookup table, e.g. a dictionary, whether the result has been computed previously and thus need not be computed again. If so, the result is returned immediately from the lookup table.

In order for caching to make sense, we need the function in question to satisfy certain criteria:

  • It should be referentially transparent, i.e. it has to return the same result for the same input, at least for as long as the cache is valid. Good examples are mathematical functions like prime number checking, hashing, etc., but also things like reading the contents of a text file into memory, if we know the file is not going to change for the lifetime of the cache. Typical examples of functions that do not satisfy this property are a random number generator, getting the current balance of your account, etc.
  • It should be expensive to compute. This is not a precise criterion, but since we add extra work, like the lookup in a table, we make each individual call slower. So we need to gain enough from caching for this overhead to be justified.
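
The idea from above can be sketched directly: before running an expensive function, consult a dictionary. This is a minimal illustration; `slowSquare` is a hypothetical expensive (but referentially transparent) function of my own.

```fsharp
open System.Collections.Generic

// A hypothetical expensive, referentially transparent function.
let slowSquare (x: int) =
    System.Threading.Thread.Sleep 100   // simulate a long computation
    x * x

// Before running the function, consult a lookup table.
let table = Dictionary<int, int>()
let cachedSquare x =
    match table.TryGetValue x with
    | true, value -> value              // hit: return immediately
    | _ ->
        let value = slowSquare x        // miss: compute ...
        table.[x] <- value              // ... and store for next time
        value
```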

Memoization

Memoization extends the idea of a cache from single values to whole functions. While it may at first sound like just a performance-tuning detail, it can in fact have a deep impact on the time and space complexity of algorithms. Consider this simple function:

let rec fib = function
    | 0 | 1 -> 1
    | n -> fib (n-1) + fib (n-2)

The function as written here has exponential time complexity, \(O(2^n)\), because each call branches into two further recursive calls.

On the other hand, every value \(fib(n)\) only depends on the values \(fib(i)\) for \(i < n\).

Now, what would happen if \(fib\) were memoized, i.e. if each value were computed at most once?

To calculate \(fib(n)\), we need to calculate \(fib(n-1)\), which in turn depends on \(fib(n-2)\). That means that by the time we are done calculating \(fib(n-1)\), we already have \(fib(n-2)\) in the cache. So we do not branch in this case, but return the result immediately.

Well, it turns out that the algorithm all of a sudden becomes \(O(n)\). We traded time complexity for some space complexity.
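
To see why, here is the same dependency structure made explicit in a bottom-up sketch (not the memoized version itself, just an illustration): if each value is stored after its first computation, computing \(fib(n)\) touches each smaller index exactly once.

```fsharp
// Fill a table bottom-up, computing each value exactly once:
// O(n) time, O(n) space.
let fibLinear n =
    let cache = Array.create (n + 1) 1   // fib 0 = fib 1 = 1
    for i in 2 .. n do
        cache.[i] <- cache.[i-1] + cache.[i-2]
    cache.[n]
```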

Implementing Memoization in F#

We can implement memoization using a mutable dictionary of some sort. A purely functional implementation without a mutable dictionary is something I might explore in a later blog post.

We need a cache and a way to get or add a function value. Our memoized function can then be written thus:

open System.Collections.Generic
let memo f =
    let dict = Dictionary()
    fun x -> match dict.TryGetValue x with
             | true, value -> value
             | _ -> let value = f x
                    dict.Add(x, value)
                    value

This implementation, however, has several caveats:

  • it explodes if you use null (or ()) as the function argument
  • it is not thread-safe; the Add operation explodes if the key is already present.

We can do better:

open System.Collections.Generic
let memo f =
    let dict = Dictionary()
    fun x -> match dict.TryGetValue (Some x) with
             | true, value -> value
             | _ -> let value = f x
                    dict.[Some x] <- value
                    value

This implementation is better: it does not explode with null input, and we might call it semi-thread-safe. By that I mean it won’t explode in case of a race condition, but it might compute a value unnecessarily and then overwrite the cached value. On the other hand, it works without locks.
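
Usage looks the same for both versions. As a quick sanity check, we can count how often the wrapped function actually runs (the counter is my own addition; the definition from above is repeated so the snippet stands alone):

```fsharp
open System.Collections.Generic
let memo f =
    let dict = Dictionary()
    fun x -> match dict.TryGetValue (Some x) with
             | true, value -> value
             | _ -> let value = f x
                    dict.[Some x] <- value
                    value

let mutable calls = 0
let square x = calls <- calls + 1; x * x

let squareMemo = memo square
squareMemo 3 |> ignore   // first call computes: calls = 1
squareMemo 3 |> ignore   // second call is a cache hit: calls is still 1
```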

Supporting recursive functions

Unfortunately, this approach does not work for recursive functions, because the memoization is only added to the outermost call; internally, the function still calls the non-memoized version. A very elegant solution is to use a memoizing \(Y\)-combinator. I shall explore the \(Y\)-combinator in a later post; in short, it is a way of writing a recursive function without recursion. Instead, another argument is added that stands in for the recursive call.

With this approach, our \(fib\) function now becomes

let fib fib = function
    | 0 | 1 -> 1
    | n -> fib (n-1) + fib (n-2)

Here, I used the same name for the recursive argument as for the function itself, so that the code still looks like a recursive function.

To memoize all the versions of \(fib\), we need to be able to reuse the cache. So we really need two things: a way to create a cache, and a getOrCache function that either gets the value or computes it and caches it for later use. Both can be implemented with a single function:

open System.Collections.Generic
let createCache () =
    let dict = Dictionary()
    fun f x -> match dict.TryGetValue (Some x) with
               | true, value -> value
               | _ -> let value = f x
                      dict.[Some x] <- value
                      value

Our memo function then becomes

let memo cache f =
    let cache = cache()
    cache f

and the memoizing \(Y\)-combinator looks like this:

let memoFix cache f =
    let cache = cache()
    let rec fn x = cache (f fn) x
    fn

The point is that we reuse the same cache for all the recursive invocations of f.
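
Tying the knot for \(fib\) then looks like this (the definitions from above are repeated so the snippet stands alone):

```fsharp
open System.Collections.Generic
let createCache () =
    let dict = Dictionary()
    fun f x -> match dict.TryGetValue (Some x) with
               | true, value -> value
               | _ -> let value = f x
                      dict.[Some x] <- value
                      value

let memoFix cache f =
    let cache = cache ()
    let rec fn x = cache (f fn) x
    fn

// fib written with open recursion, as above
let fib fib = function
    | 0 | 1 -> 1
    | n -> fib (n-1) + fib (n-2)

// every recursive call now goes through the same cache
let fibMemo = memoFix createCache fib
```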

Different kinds of caches

We can use any kind of cache, as long as it fulfils the function signature of createCache. For instance, we can use the truly thread-safe ConcurrentDictionary instead.

open System.Collections.Concurrent
let createCache () =
    let dict = ConcurrentDictionary()
    fun f x -> dict.GetOrAdd(Some x, lazy(f x)).Value

We can now create a small module for each kind of cache, so that we do not need to pass the createCache function each time.
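
Such a module might look like the following sketch (the module and function names are my own invention):

```fsharp
open System.Collections.Concurrent

// One small module per kind of cache, so call sites
// need not pass createCache around.
module ConcurrentMemo =
    let private createCache () =
        let dict = ConcurrentDictionary()
        fun f x -> dict.GetOrAdd(Some x, lazy (f x)).Value
    let memo f =
        let cache = createCache ()
        cache f

// Usage: ConcurrentMemo.memo wraps any expensive function.
let slowDouble (x: int) =
    System.Threading.Thread.Sleep 50    // simulate expensive work
    x * 2

let fastDouble = ConcurrentMemo.memo slowDouble
```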

Multiple curried arguments

And lastly, we might want to memoize functions of more than one argument. For tupled arguments we already have this, since a tuple is a single value; for curried arguments, we can just apply memo repeatedly.

let memo f = memo createCache f
let memo2 f = memo (memo << f)
let memo3 f = memo (memo2 << f)
let memo4 f = memo (memo3 << f)
let memo5 f = memo (memo4 << f)
let memo6 f = memo (memo5 << f)
let memo7 f = memo (memo6 << f)
let memo8 f = memo (memo7 << f)
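
For instance, memoizing a hypothetical curried two-argument function works like this (the building blocks from above are repeated so the snippet stands alone):

```fsharp
open System.Collections.Generic
let createCache () =
    let dict = Dictionary()
    fun f x -> match dict.TryGetValue (Some x) with
               | true, value -> value
               | _ -> let value = f x
                      dict.[Some x] <- value
                      value

let memo f =
    let cache = createCache ()
    cache f

// memo2 memoizes on the first argument, and each partially
// applied function is memoized on the second.
let memo2 f = memo (memo << f)

// a hypothetical curried function of two arguments
let add (a: int) (b: int) = a + b
let addMemo = memo2 add
```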

The whole snippet is available on fssnip.net.
