In chapter 2, you saw that you’ll overflow the stack if you try to compose a huge number of functions. The reason is the same as for recursion: because composing functions results in methods calling methods. Having to compose more than 7,000 functions may be something you don’t expect to do soon. On the other hand, there’s no reason not to make it possible. If it’s possible, someone will eventually find something useful to do with it. And if it’s not useful, someone will certainly find something fun to do with it.
Let's write a function, composeAll, taking as its argument a list of functions from T to T and returning the result of composing all the functions in the list (Exercise 4.6). To get the result you want, you can use a right fold, taking as its arguments the list of functions, the identity function (obtained by a call to the statically imported Function.identity() method), and the compose method written in chapter 2:
- static <T> Function<T, T> composeAll(List<Function<T, T>> list) {
- return foldRight(list, identity(), x -> y -> x.compose(y));
- }
- package test.fp.utils;
- import java.util.ArrayList;
- import java.util.List;
- import org.junit.Test;
- import fp.utils.Function;
- import static fp.utils.Function.*;
- import static org.junit.Assert.*;
- public class TestFunction {
- @Test
- public void testComposeAll() {
- List<Function<Integer, Integer>> list = new ArrayList<>();
- for (int i = 0; i < 500; i++) {
- list.add(x -> x + 1);
- }
- int result = composeAll(list).apply(0);
- assertEquals(500, result); // JUnit expects (expected, actual)
- }
- }
Let's fix this problem so you can compose an (almost) unlimited number of functions (Exercise 4.7). The solution to this problem is simple. Instead of composing the functions by nesting them, you have to compose their results, always staying at the higher level. This means that between each call to a function, you’ll return to the original caller. If this isn’t clear, imagine the imperative way to do this:
- T y = identity;
- for (Function<T, T> f : list) {
- y = f.apply(y);
- }
- static <T> Function<T, T> composeAll(List<Function<T, T>> list) {
- return x -> {
- T y = x; // A copy of x is made while it must be effectively final.
- for(Function<T,T> f:list)
- {
- y = f.apply(y);
- }
- return y;
- };
- }
- static <T> Function<T, T> composeAllViaFoldLeft(List<Function<T, T>> list) {
- return x -> foldLeft(list, x, a -> b -> b.apply(a));
- }
- static <T> Function<T, T> composeAllViaFoldRight(List<Function<T, T>> list) {
- return x -> foldRight(list, x, a -> a::apply);
- }
- static <T> Function<T, T> composeAllViaFoldRight(List<Function<T, T>> list) {
- return x -> FoldRight.foldRight(list, x, a -> b -> a.apply(b));
- }
To see this in action, refer to the unit tests in the code available from the book’s site (https://github.com/fpinjava/fpinjava).
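As a rough illustration of the difference between nesting functions and applying them one after another, here is a minimal sketch that uses the JDK's java.util.function.Function instead of the fp.utils.Function class used in this post. The class name StackSafeComposeSketch and the helpers applyAllInLoop, nestAll, and increments are made up for this example. Whether the nested version actually overflows depends on the JVM stack size, so the demo catches the error rather than relying on it.

```java
import java.util.ArrayList;
import java.util.List;
import java.util.function.Function;

final class StackSafeComposeSketch {

    // Applies the functions one after another in a loop. Each call returns
    // before the next one starts, so the stack depth stays constant.
    static <T> Function<T, T> applyAllInLoop(List<Function<T, T>> list) {
        return x -> {
            T y = x;
            for (Function<T, T> f : list) {
                y = f.apply(y);
            }
            return y;
        };
    }

    // Nests the functions with compose. Applying the result needs a stack
    // depth proportional to the number of functions.
    static <T> Function<T, T> nestAll(List<Function<T, T>> list) {
        Function<T, T> result = Function.identity();
        for (Function<T, T> f : list) {
            result = result.compose(f);
        }
        return result;
    }

    // Hypothetical helper: a list of n functions that each add 1.
    static List<Function<Integer, Integer>> increments(int n) {
        List<Function<Integer, Integer>> list = new ArrayList<>();
        for (int i = 0; i < n; i++) {
            list.add(x -> x + 1);
        }
        return list;
    }

    public static void main(String[] args) {
        List<Function<Integer, Integer>> list = increments(100_000);
        System.out.println(applyAllInLoop(list).apply(0));
        try {
            // May or may not overflow, depending on the JVM stack size.
            System.out.println(nestAll(list).apply(0));
        } catch (StackOverflowError e) {
            System.out.println("nested composition overflowed the stack");
        }
    }
}
```

The loop version is the same idea as the lambda-with-a-loop shown above: the composed function never holds more than one pending call on the stack.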
The code has two problems, and you fixed only one. Can you see another problem and fix it (Exercise 4.8)? The second problem isn’t visible in the result because the functions you’re composing are specific. They are, in fact, a single function from integer to integer. The order in which they’re composed is irrelevant. Try to use the composeAll method with the following function list:
- package fp.ch4;
- import fp.utils.Function;
- import static fp.utils.Function.*;
- import static fp.utils.CollectionUtilities.*;
- public class TestP22_1 {
- public static void main(String args[])
- {
- Function<String, String> f1 = x -> "(a" + x + ")";
- Function<String, String> f2 = x -> "{b" + x + "}";
- Function<String, String> f3 = x -> "[c" + x + "]";
- System.out.println(composeAllViaFoldLeft(list(f1, f2, f3)).apply("x"));
- System.out.println(composeAllViaFoldRight(list(f1, f2, f3)).apply("x"));
- }
- }
The fold-left version applies f1 first, then f2, then f3, and prints [c{b(ax)}]. In other words, it implements andThenAll rather than composeAll! To get the correct result, you first have to reverse the list:
- static <T> Function<T, T> composeAllViaFoldLeft(List<Function<T, T>> list) {
- return x -> foldLeft(reverse(list), x, a -> b -> b.apply(a));
- }
- static <T> Function<T, T> composeAllViaFoldRight(List<Function<T, T>> list) {
- return x -> foldRight(list, x, a -> a::apply);
- }
- static <T> Function<T, T> andThenAllViaFoldLeft(List<Function<T, T>> list) {
- return x -> foldLeft(list, x, a -> b -> b.apply(a));
- }
- static <T> Function<T, T> andThenAllViaFoldRight(List<Function<T, T>> list) {
- return x -> foldRight(reverse(list), x, a -> a::apply);
- }
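To make the ordering difference concrete, here is a small self-contained sketch. It uses java.util.function.Function and plain loops rather than the foldLeft, foldRight, and reverse helpers from this post, and the class name CompositionOrderSketch and the sampleFunctions helper are invented for the example.

```java
import java.util.Arrays;
import java.util.List;
import java.util.function.Function;

final class CompositionOrderSketch {

    // andThen order: the first function in the list is applied first.
    static <T> Function<T, T> andThenAll(List<Function<T, T>> list) {
        return x -> {
            T y = x;
            for (Function<T, T> f : list) {
                y = f.apply(y);
            }
            return y;
        };
    }

    // compose order: the last function in the list is applied first.
    static <T> Function<T, T> composeAll(List<Function<T, T>> list) {
        return x -> {
            T y = x;
            for (int i = list.size() - 1; i >= 0; i--) {
                y = list.get(i).apply(y);
            }
            return y;
        };
    }

    // The three string-wrapping functions from the example above.
    static List<Function<String, String>> sampleFunctions() {
        Function<String, String> f1 = x -> "(a" + x + ")";
        Function<String, String> f2 = x -> "{b" + x + "}";
        Function<String, String> f3 = x -> "[c" + x + "]";
        return Arrays.asList(f1, f2, f3);
    }

    public static void main(String[] args) {
        System.out.println(andThenAll(sampleFunctions()).apply("x")); // [c{b(ax)}]
        System.out.println(composeAll(sampleFunctions()).apply("x")); // (a{b[cx]})
    }
}
```

With functions that all do the same thing, such as x -> x + 1, both orders give the same result, which is why the earlier test could not reveal the problem.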
In the previous section, you implemented a function to display a series of Fibonacci numbers. One problem with this implementation of the Fibonacci series is that you want to print the string representing the series up to f(n), which means you have to compute f(1), f(2), and so on, until f(n). But to compute f(n), you have to recursively compute the function for all preceding values. Eventually, to create the series up to n, you’ll have computed f(1) n times, f(2) n – 1 times, and so on. The total number of computations will then be the sum of the integers 1 to n. Can you do better? Could you possibly keep the computed values in memory so you don’t have to compute them again if they’re needed several times?
Memoization in imperative programming
In imperative programming, you wouldn’t even have this problem, because the obvious way to proceed would be as follows:
- package fp.ch4;
- import java.math.BigInteger;
- public class TestP23_1 {
- public static String fibo(int limit)
- {
- switch(limit)
- {
- case 0:
- return "0";
- case 1:
- return "0, 1";
- case 2:
- return "0, 1, 1";
- default:
- BigInteger fibo1 = BigInteger.ONE;
- BigInteger fibo2 = BigInteger.ONE;
- BigInteger fibonacci;
- StringBuilder builder = new StringBuilder("0, 1, 1");
- for(int i=3; i<=limit; i++)
- {
- fibonacci = fibo1.add(fibo2);
- builder.append(", ").append(fibonacci);
- fibo1 = fibo2; fibo2 = fibonacci;
- }
- return builder.toString();
- }
- }
- public static void main(String args[])
- {
- System.out.printf("Fibo(10)=%s\n", fibo(10));
- }
- }
Memoization might seem incompatible with functional principles, because a memoized function maintains state. But it isn’t, because the function still returns the same result when it’s called with the same argument. (You could even argue that it’s more the same, because it isn’t computed again!) The side effect of storing the results must not be visible from outside the function. In imperative programming, maintaining state is the universal way of computing results, so memoization goes unnoticed.
Memoization in recursive functions
Recursive functions often use memoization implicitly. In your example of the recursive Fibonacci function, you wanted to return the series, so you calculated each number in the series, leading to unnecessary recalculations. A simple solution is to rewrite the function in order to directly return the string representing the series.
Let's write a stack-safe tail-recursive function taking an integer n as its argument and returning a string representing the values of the Fibonacci numbers from 0 to n, separated by a comma and a space (Exercise 4.9). One solution is to use StringBuilder as the accumulator. StringBuilder isn’t a functional structure because it’s mutable, but this mutation won’t be visible from the outside. Another solution is to return a list of numbers and then transform it into a String. This solution is easier, because you can abstract the problem of the separators by first returning a list and then writing a function to turn the list into a comma-separated string.
The following listing shows the solution using List as the accumulator:
Listing 4.5. Recursive Fibonacci with implicit memoization
- public static String fibo(int number)
- {
- List<BigInteger> list = fibo_(list(BigInteger.ZERO), BigInteger.ONE, BigInteger.ZERO, BigInteger.valueOf(number)).eval();
- return MakeString(list, ", ");
- }
- public static TailCall<List<BigInteger>> fibo_(List<BigInteger> acc, BigInteger acc1, BigInteger acc2, BigInteger x)
- {
- return x.equals(BigInteger.ZERO)
- ? ret(acc)
- : x.equals(BigInteger.ONE)
- ? ret(append(acc, acc1.add(acc2)))
- : sus(() -> fibo_(append(acc, acc1.add(acc2)), acc2, acc1.add(acc2), x.subtract(BigInteger.ONE)));
- }
- public static <T> String MakeString(List<T> list, String sep)
- {
- return list.isEmpty()
- ? ""
- : tail(list).isEmpty()
- ? head(list).toString()
- : head(list) + foldLeft(tail(list), "", x -> y -> x + sep + y);
- }
This example demonstrates the use of implicit memoization. Don’t conclude that this is the best way to solve the problem. Many problems are much easier to solve when twisted. So let’s twist this one. Instead of a sequence of numbers, you could see the Fibonacci series as a sequence of pairs (tuples). Instead of trying to generate this,
- 0, 1, 1, 2, 3, 5, 8, 13, 21, ...
you could try to produce this:
- (0, 1), (1, 1), (1, 2), (2, 3), (3, 5), (5, 8), (8, 13), (13, 21), ...
In this series, each tuple can be constructed from the previous one. The second element of tuple n becomes the first element of tuple n + 1. The second element of tuple n + 1 is equal to the sum of the two elements of tuple n. In Java, you can write a function for this:
- x -> new Tuple<>(x._2, x._1.add(x._2));
- public static String fiboCorecursive(int number) {
- Tuple<BigInteger, BigInteger> seed = new Tuple<>(BigInteger.ZERO, BigInteger.ONE);
- Function<Tuple<BigInteger, BigInteger>, Tuple<BigInteger, BigInteger>> f = x -> new Tuple<>(x._2, x._1.add(x._2));
- List<BigInteger> list = map(fp.utils.List.iterate(seed, f, number), x -> x._1);
- return MakeString(list, ", ");
- }
- public static <B> List<B> iterate(B seed, Function<B, B> f, int n)
- package fp.utils;
- import java.util.ArrayList;
- public class List {
- public static <T> java.util.List<T> iterate(T seed, Function<T, T> fun, int loop)
- {
- java.util.List<T> list = new ArrayList<T>(); list.add(seed);
- T tmp = seed;
- for(int i=0; i<loop-1; i++)
- {
- tmp = fun.apply(tmp); list.add(tmp);
- }
- return list;
- }
- }
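The same pair-based idea can be sketched with the JDK alone. The following is only an illustration, not the code from this post: it replaces the Tuple class with a two-element BigInteger array and the iterate, map, and MakeString helpers with Stream.iterate, map, and Collectors.joining. The class name FiboPairsSketch and the count parameter (the number of values to produce) are assumptions made for the example.

```java
import java.math.BigInteger;
import java.util.stream.Collectors;
import java.util.stream.Stream;

final class FiboPairsSketch {

    // Produces the first `count` Fibonacci numbers as a comma-separated string.
    static String fiboCorecursive(int count) {
        return Stream.iterate(
                        // Seed pair (0, 1).
                        new BigInteger[] {BigInteger.ZERO, BigInteger.ONE},
                        // Next pair: (second, first + second).
                        p -> new BigInteger[] {p[1], p[0].add(p[1])})
                .limit(count)
                // Keep only the first element of each pair.
                .map(p -> p[0].toString())
                .collect(Collectors.joining(", "));
    }

    public static void main(String[] args) {
        System.out.println(fiboCorecursive(10)); // 0, 1, 1, 2, 3, 5, 8, 13, 21, 34
    }
}
```

Each pair is built from the previous one only, so no value is ever computed twice and no explicit cache is needed.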
Memoization isn’t mainly used for recursive functions. It can be used to speed up any function. Think about how you perform multiplication. If you need to multiply 234 by 686, you’ll probably need a pen and some paper, or a calculator. But if you’re asked to multiply 9 by 7, you can answer immediately, without doing any computation. This is because you use a memoized multiplication. A memoized function works the same way, although it needs to make the computation only once to retain the result.
Imagine you have a functional method doubleValue that multiplies its argument by 2:
- Integer doubleValue(Integer x) {
- return x * 2;
- }
- Map<Integer, Integer> cache = new ConcurrentHashMap<>();
- Integer doubleValue(Integer x)
- {
- if(cache.containsKey(x)) return cache.get(x);
- else
- {
- Integer rst = x * 2;
- cache.put(x, rst);
- return rst;
- }
- }
- Map<Integer, Integer> cache = new ConcurrentHashMap<>();
- Integer doubleValue(Integer x) {
- return cache.computeIfAbsent(x, y -> y * 2);
- }
- Function<Integer, Integer> doubleValue =
- x -> cache.computeIfAbsent(x, y -> y * 2);
The second problem is easy to address. You can put the method or the function in a separate class, including the map, with private access. Here’s an example in the case of a method:
- public class Doubler {
- private static Map<Integer, Integer> cache = new ConcurrentHashMap<>();
- public static Integer doubleValue(Integer x) {
- return cache.computeIfAbsent(x, y -> y * 2);
- }
- }
- Integer y = Doubler.doubleValue(x);
- class Doubler {
- private static Map<Integer, Integer> cache = new ConcurrentHashMap<>();
- public static Function<Integer, Integer> doubleValue =
- x -> cache.computeIfAbsent(x, y -> y * 2);
- }
- Integer y = Doubler.doubleValue.apply(x);
- map(range(1, 10), Doubler.doubleValue);
- map(range(1, 10), Doubler::doubleValue);
What you need is a way to do the following:
- Function<Integer, Integer> f = x -> x * 2;
- Function<Integer, Integer> g = Memoizer.memoize(f);
Then you can use the memoized function as a drop-in replacement for the original one. All values returned by function g will be calculated through the original function f the first time, and returned from the cache for all subsequent accesses. By contrast, if you create a third function,
- Function<Integer, Integer> f = x -> x * 2;
- Function<Integer, Integer> g = Memoizer.memoize(f);
- Function<Integer, Integer> h = Memoizer.memoize(f);
the values cached by g won’t be returned by h; g and h will use separate caches.
Implementation
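The Memoizer class is simple and is shown in the following listing. (The listing names the class Memorizer and the method memorize; the surrounding text and the shorter snippets use Memoizer and memoize for the same thing.)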
- package fp.utils;
- import java.util.Map;
- import java.util.concurrent.ConcurrentHashMap;
- public class Memorizer<T, U> {
- private final Map<T, U> cache = new ConcurrentHashMap<>();
- private Memorizer(){}
- public static <T, U> Function<T, U> memorize(Function<T, U> fun)
- {
- return new Memorizer<T, U>().doMemorize(fun);
- }
- private Function<T, U> doMemorize(Function<T, U> fun)
- {
- return input -> cache.computeIfAbsent(input, fun::apply);
- }
- }
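If you want to check the caching behavior without timing anything, one option is to count how often the original function runs. The sketch below is an assumption-laden variant, not the listing above: it uses java.util.function.Function, a static memoize helper that closes over its own map, and a made-up class name MemoizeCountSketch.

```java
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.function.Function;

final class MemoizeCountSketch {

    // Each call to memoize creates its own private cache.
    static <T, U> Function<T, U> memoize(Function<T, U> f) {
        Map<T, U> cache = new ConcurrentHashMap<>();
        return x -> cache.computeIfAbsent(x, f);
    }

    public static void main(String[] args) {
        AtomicInteger calls = new AtomicInteger();
        Function<Integer, Integer> f = x -> {
            calls.incrementAndGet(); // counts real computations
            return x * 2;
        };
        Function<Integer, Integer> g = memoize(f);
        Function<Integer, Integer> h = memoize(f);

        g.apply(21); // computed by f
        g.apply(21); // served from g's cache
        h.apply(21); // computed again: h has a separate cache

        System.out.println(calls.get()); // prints 2
    }
}
```

The count of 2 reflects the point made earlier: g and h wrap the same function but do not share cached values.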
Listing 4.7. Demonstrating the memoizer
- package fp.ch4;
- import fp.utils.Function;
- import fp.utils.Memorizer;
- public class List4_7 {
- private static Integer longCalculation(Integer x)
- {
- try
- {
- Thread.sleep(1000);
- }
- catch(InterruptedException ignored){}
- return x * 2;
- }
- private static Function<Integer, Integer> f = List4_7::longCalculation;
- private static Function<Integer, Integer> g = Memorizer.memorize(f);
- public static void main(String args[])
- {
- long startTime = System.currentTimeMillis();
- Integer result1 = g.apply(1);
- long time1 = System.currentTimeMillis() - startTime;
- startTime = System.currentTimeMillis();
- Integer result2 = g.apply(1);
- long time2 = System.currentTimeMillis() - startTime;
- System.out.println(result1);
- System.out.println(result2);
- System.out.println(time1);
- System.out.println(time2);
- }
- }
Note that the exact result will depend on the speed of your computer. You can now make memoized functions out of ordinary ones by calling a single method, but to use this technique in production, you’d have to handle potential memory problems. This code is acceptable if the number of possible inputs is low, so you can keep all results in memory without causing memory overflow. Otherwise, you can use soft references or weak references to store memoized values.
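As one possible way to follow the soft-reference suggestion, here is a minimal sketch. It is not part of the original code: the class name SoftMemoizerSketch is invented, it uses java.util.function.Function, and it makes simplifying assumptions (two threads may compute the same value at the same time, a null result is never cached, and keys whose values were collected stay in the map).

```java
import java.lang.ref.SoftReference;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import java.util.function.Function;

final class SoftMemoizerSketch {

    // Memoizes f, holding results through soft references so the garbage
    // collector may reclaim them when memory runs low.
    static <T, U> Function<T, U> memoize(Function<T, U> f) {
        Map<T, SoftReference<U>> cache = new ConcurrentHashMap<>();
        return x -> {
            SoftReference<U> ref = cache.get(x);
            U cached = (ref == null) ? null : ref.get();
            if (cached != null) {
                return cached; // still in memory: no recomputation
            }
            // Never computed, or reclaimed by the collector: compute again.
            U computed = f.apply(x);
            cache.put(x, new SoftReference<>(computed));
            return computed;
        };
    }

    public static void main(String[] args) {
        Function<Integer, Integer> g = memoize(x -> x * 2);
        System.out.println(g.apply(3));
        System.out.println(g.apply(3));
    }
}
```

The trade-off is that a cached value can disappear at any time, so the function may occasionally be recomputed. That is harmless for a pure function, because the result is the same either way.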
Memoization of “multiargument” functions
As I said before, there’s no such thing in this world as a function with several arguments. Functions are applications of one set (the source set) to another set (the target set). They can’t have several arguments. Functions that appear to have several arguments are one of these:
* Functions of tuples
* Functions returning functions returning functions ... returning a result
In either case, you’re concerned only with functions of one argument, so you can easily use your Memoizer class. Using functions of tuples is probably the simplest choice. You could use the Tuple class written in previous chapters, but to store tuples in maps, you’d have to implement equals and hashCode. Besides this, you’d have to define tuples for two elements (pairs), tuples for three elements, and so on. Who knows where to stop?
The second option is much easier. You have to use the curried version of the functions, as you did in previous chapters. Memoizing curried functions is easy, although you can’t use the same simple form as previously. You have to memoize each function:
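- Function<Integer, Function<Integer, Integer>> mhc =
- Memoizer.memoize(x ->
- Memoizer.memoize(y -> x + y));
- Function<Integer, Function<Integer, Function<Integer, Integer>>> f3 =
- x -> y -> z -> x + y - z;
- Function<Integer, Function<Integer, Function<Integer, Integer>>> f3m =
- Memoizer.memoize(x ->
- Memoizer.memoize(y ->
- Memoizer.memoize(z -> x + y - z)));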
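To see what memoizing each level of a curried function buys you, the following sketch counts how often the innermost body actually runs. It assumes a simple memoize helper built on java.util.function.Function and ConcurrentHashMap, and the names CurriedMemoSketch and evaluations are made up for the example.

```java
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.function.Function;

final class CurriedMemoSketch {

    // Counts how many times the innermost computation is really performed.
    static final AtomicInteger evaluations = new AtomicInteger();

    static <T, U> Function<T, U> memoize(Function<T, U> f) {
        Map<T, U> cache = new ConcurrentHashMap<>();
        return x -> cache.computeIfAbsent(x, f);
    }

    // Every level is memoized: the outer cache stores the function for x,
    // the middle cache the function for (x, y), the inner cache the result.
    static final Function<Integer, Function<Integer, Function<Integer, Integer>>> f3m =
            memoize(x -> memoize(y -> memoize(z -> {
                evaluations.incrementAndGet();
                return x + y - z;
            })));

    public static void main(String[] args) {
        System.out.println(f3m.apply(2).apply(3).apply(4)); // 1
        System.out.println(f3m.apply(2).apply(3).apply(4)); // 1, from the caches
        System.out.println(evaluations.get());              // 1
    }
}
```

If only the outer function were memoized, each call to f3m.apply(2) would return the same intermediate function, but that function would recompute its result on every application.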
Listing 4.8. Testing a memoized function of three arguments for performance
- package fp.ch4;
- import fp.utils.Function;
- import fp.utils.Memorizer;
- public class List4_8 {
- public static Integer longCalculation(Integer x)
- {
- try
- {
- Thread.sleep(1000);
- }
- catch(InterruptedException ignored){}
- return x * 2;
- }
- public static void main(String[] args) {
- Function<Integer, Function<Integer, Function<Integer, Integer>>> f3m =
- Memorizer.memorize(x ->
- Memorizer.memorize(y ->
- Memorizer.memorize(z ->
- longCalculation(x) + longCalculation(y) - longCalculation(z))));
- long startTime = System.currentTimeMillis();
- Integer result1 = f3m.apply(2).apply(3).apply(4);
- long time1 = System.currentTimeMillis() - startTime;
- startTime = System.currentTimeMillis();
- Integer result2 = f3m.apply(2).apply(3).apply(4);
- long time2 = System.currentTimeMillis() - startTime;
- System.out.println(result1);
- System.out.println(result2);
- System.out.println(time1);
- System.out.println(time2);
- }
- }
This shows that the first access to the longCalculation method has taken 3,000 milliseconds, and the second has returned immediately. On the other hand, using a function of a tuple may seem easier after you have the Tuple class defined. The following listing shows an example of Tuple3.
Listing 4.9. An implementation of Tuple3
- package fp.utils;
- import java.util.Objects;
- public class Tuple3<T, U, V> {
- public final T _1;
- public final U _2;
- public final V _3;
- public Tuple3(T t, U u, V v) {
- _1 = Objects.requireNonNull(t);
- _2 = Objects.requireNonNull(u);
- _3 = Objects.requireNonNull(v);
- }
- @Override
- public boolean equals(Object o) {
- if (!(o instanceof Tuple3))
- return false;
- else {
- Tuple3 that = (Tuple3) o;
- return _1.equals(that._1) && _2.equals(that._2) && _3.equals(that._3);
- }
- }
- @Override
- public int hashCode() {
- final int prime = 31;
- int result = 1;
- result = prime * result + _1.hashCode();
- result = prime * result + _2.hashCode();
- result = prime * result + _3.hashCode();
- return result;
- }
- }
Listing 4.10. A memoized function of Tuple3
- package fp.ch4;
- import fp.utils.Function;
- import fp.utils.Memorizer;
- import fp.utils.Tuple3;
- public class List4_10 {
- public static Integer longCalculation(Integer x)
- {
- try
- {
- Thread.sleep(1000);
- }
- catch(InterruptedException ignored){}
- return x * 2;
- }
- public static void main(String[] args) {
- Function<Tuple3<Integer, Integer, Integer>, Integer> ft =
- x -> longCalculation(x._1)
- + longCalculation(x._2)
- - longCalculation(x._3);
- Function<Tuple3<Integer, Integer, Integer>, Integer> ftm = Memorizer.memorize(ft);
- long startTime = System.currentTimeMillis();
- Integer result1 = ftm.apply(new Tuple3<>(2, 3, 4));
- long time1 = System.currentTimeMillis() - startTime;
- startTime = System.currentTimeMillis();
- Integer result2 = ftm.apply(new Tuple3<>(2, 3, 4));
- long time2 = System.currentTimeMillis() - startTime;
- System.out.println(result1);
- System.out.println(result2);
- System.out.println(time1);
- System.out.println(time2);
- }
- }
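The reason the tuple needs equals and hashCode is that every call builds a new tuple object, and the cache can only recognize it as a known key if equal contents mean equal keys. The sketch below illustrates this with a hypothetical Triple class and a counting function. The names TupleKeySketch, Triple, and memoize are assumptions for the example, and java.util.function.Function stands in for fp.utils.Function.

```java
import java.util.Map;
import java.util.Objects;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.function.Function;

final class TupleKeySketch {

    // A tiny immutable triple that can be used as a map key.
    static final class Triple {
        final int a;
        final int b;
        final int c;

        Triple(int a, int b, int c) {
            this.a = a;
            this.b = b;
            this.c = c;
        }

        @Override
        public boolean equals(Object o) {
            if (!(o instanceof Triple)) {
                return false;
            }
            Triple t = (Triple) o;
            return a == t.a && b == t.b && c == t.c;
        }

        @Override
        public int hashCode() {
            return Objects.hash(a, b, c);
        }
    }

    static <T, U> Function<T, U> memoize(Function<T, U> f) {
        Map<T, U> cache = new ConcurrentHashMap<>();
        return x -> cache.computeIfAbsent(x, f);
    }

    public static void main(String[] args) {
        AtomicInteger calls = new AtomicInteger();
        Function<Triple, Integer> ftm = memoize(t -> {
            calls.incrementAndGet();
            return t.a + t.b - t.c;
        });
        ftm.apply(new Triple(2, 3, 4));
        ftm.apply(new Triple(2, 3, 4)); // a different object, but an equal key
        System.out.println(calls.get()); // prints 1
    }
}
```

Without the two overrides, the second Triple would be treated as a new key and the function would be computed again on every call.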
Memoizing is about maintaining state between function calls. A memoized function is a function whose behavior is dependent on the current state. But it’ll always return the same value for the same argument. Only the time needed to return the value will be different. So the memoized function is still a pure function if the original function is pure. A variation in time may be a problem. A function like the original Fibonacci function needing many years to complete may be called nonterminating, so an increase in time may create a problem. On the other hand, making a function faster shouldn’t be a problem. If it is, there’s a much bigger problem somewhere else!
Summary
Supplement
* Ch4 - Recursion, corecursion, and memoization - Part1
* Ch4 - Recursion, corecursion, and memoization - Part2