Memoization & dynamic programming

Memoization

A way to trade space for time and make recursion more efficient.

The basic idea

If our recursive function is pure (i.e. its value depends only on the values of its immutable arguments), we don’t need to recompute values if we’ve computed them before.

So the first time we compute our function for a particular set of arguments, we remember (memoize) the value so we don’t have to compute it again later.

A simple, general approach

Make a Map to store already computed values with the arguments as the key.

Each time the function is called, check the map to see if the value for that particular set of arguments has already been computed.

If it hasn’t, compute it and add it to the map.

Return the value from the map, either the one that was already there or the one we just added.

Memoized fibonacci

private static Map<Integer, Long> memo = new HashMap<>();

public static long fibonacci(int n) {
  if (n < 2) {
    return n;
  } else {
    // Only compute the value the first time we see this particular n.
    if (!memo.containsKey(n)) {
      memo.put(n, fibonacci(n - 1) + fibonacci(n - 2));
    }
    return memo.get(n);
  }
}

Making a key

When there’s only one argument, it can be the key.

But if you are memoizing a function with multiple arguments, you need to combine them into a single key.

Records are great for this.

As long as the arguments are all unchanging values with good equals and hashCode methods, a record containing those arguments will work as a key.

Using a record key

private record Key(int n) {}

private static Map<Key, Long> memo = new HashMap<>();

public static long fibonacci(int n) {
  if (n < 2) {
    return n;
  } else {
    Key k = new Key(n);
    if (!memo.containsKey(k)) {
      memo.put(k, fibonacci(n - 1) + fibonacci(n - 2));
    }
    return memo.get(k);
  }
}
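
The record really pays off when there is more than one argument. As a rough sketch (gridPaths, GridKey, and pathMemo are made-up names for illustration, not part of the fibonacci code above, and the same java.util.Map and HashMap imports are assumed), a two-argument function could be memoized the same way:

// Hypothetical two-argument example: count the paths through a rows-by-cols
// grid moving only right or down. The record bundles both arguments into one key.
private record GridKey(int rows, int cols) {}

private static Map<GridKey, Long> pathMemo = new HashMap<>();

public static long gridPaths(int rows, int cols) {
  if (rows == 0 || cols == 0) {
    return 1;
  } else {
    GridKey k = new GridKey(rows, cols);
    if (!pathMemo.containsKey(k)) {
      pathMemo.put(k, gridPaths(rows - 1, cols) + gridPaths(rows, cols - 1));
    }
    return pathMemo.get(k);
  }
}

Because a record’s equals and hashCode are generated from its components, two keys built from the same rows and cols values always land on the same map entry.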

Some choices

Do you make a single map that lives forever, which can be used whenever the function is called?

Or do you put the actual computation in a helper function that takes the map as an argument, and have the top-level function create a fresh map and call the helper?

What are the trade-offs?

Temporary memo table

public static long fibonacci(int n) {
  return memoFib(n, new HashMap<>(Map.of(0, 0L, 1, 1L)));
}

public static long memoFib(int n, Map<Integer, Long> memo) {
  if (!memo.containsKey(n)) {
    memo.put(n, memoFib(n - 1, memo) + memoFib(n - 2, memo));
  }
  return memo.get(n);
}

Make sure the helper method recurses using the helper, not the top-level method; otherwise every recursive call starts over with a brand-new map.
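
For example, a version like this (a hypothetical mistake, not something from these notes) still computes the right answer, but it rebuilds an empty map for every subproblem, so nothing is ever reused:

// Broken: calling the top-level fibonacci creates a fresh memo table for
// every recursive step, so the memoization does no good.
public static long memoFib(int n, Map<Integer, Long> memo) {
  if (!memo.containsKey(n)) {
    memo.put(n, fibonacci(n - 1) + fibonacci(n - 2)); // should be memoFib(n - 1, memo) + memoFib(n - 2, memo)
  }
  return memo.get(n);
}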

Less general approach

Sometimes there’s a lighter-weight data structure we can use for memoization.

Array memoized fibonacci

public static long fibonacci(int n) {
  return memoFib(n, new long[n + 1]);
}

private static long memoFib(int n, long[] memo) {
  if (n < 2) {
    return n;
  } else {
    // A 0 entry marks "not yet computed"; that's safe because every
    // Fibonacci number from index 2 on is positive.
    if (memo[n] == 0) {
      memo[n] = memoFib(n - 1, memo) + memoFib(n - 2, memo);
    }
    return memo[n];
  }
}

Dynamic programming

Memoization still relies on recursion and thus still requires a call stack as deep as the path from the starting problem to the base cases.

Memoizing just saves us from going back down paths of the tree that we’ve already computed.

Dynamic programming is a technique we can use if we can describe how to compute the values from the bottom up.

Dynamic programming fibonacci

public static long fibonacci(int n) {
  if (n < 2) {
    return n; // guard the small cases so table[1] exists before we write to it
  }
  var table = new long[n + 1];
  table[1] = 1;
  for (int i = 2; i < table.length; i++) {
    table[i] = table[i - 1] + table[i - 2];
  }
  return table[n];
}

Optimizing space usage

Once we’ve organized our computation this way, sometimes we can then see that we don’t actually need to keep all the old answers.

In fibonacci, computing each new table entry only needs the two entries immediately before it, so the older entries can be forgotten.

So we can often run in constant space by taking advantage of that fact.

Space optimized dynamic programming fibonacci

public static long fibonacci(int n) {
  // Three slots suffice, since each value depends only on the two before it.
  var table = new long[] { 0, 1, 0 };
  for (int i = 2; i <= n; i++) {
    table[i % 3] = table[(i - 1) % 3] + table[(i - 2) % 3];
  }
  return table[n % 3];
}

This is basically the same as our old iterative fibonacci, just with a single array rather than separate variables.

Using % on the indices allows the actual indices to cycle around the array.
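
For comparison, here is a minimal sketch of that separate-variables version (the exact earlier code may have differed in its details):

public static long fibonacci(int n) {
  long previous = 0; // fib(i - 1) going into each iteration
  long current = 1;  // fib(i)
  for (int i = 1; i < n; i++) {
    long next = previous + current;
    previous = current;
    current = next;
  }
  return n == 0 ? 0 : current;
}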