Notes for Math 221, Lecture 7, Sep 17 2009
Goal: continue with Matmul, then GE (Gaussian elimination), done efficiently
Last time we proved:
Theorem: no matter how we reorganize matmul, as long as we
do the usual 2*n^3 arithmetic operations, but possibly
just doing the sums in different orders, then
#words_moved >= Omega(n^3/sqrt(M)) where M is the
fast memory size.
We also showed that "blocked matrix multiplication"
attained this lower bound
Here is another algorithm that also attains the lower bound,
that will use a pattern that will appear later: recursion
function C = RMM(A,B) ... recursive matrix multiplication
... assume for simplicity that A and B are square of size n
... assume for simplicity that n is a power of 2
if n=1
C = A*B ... scalar multiplication
else
... write A = [ A11 A12 ], where each Aij is n/2 by n/2
[ A21 A22 ]
... write B and C similarly
C11 = RMM(A11,B11) + RMM(A12,B21)
C12 = RMM(A11,B12) + RMM(A12,B22)
C21 = RMM(A21,B11) + RMM(A22,B21)
C22 = RMM(A21,B12) + RMM(A22,B22)
endif
return
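In Python/NumPy this recursion might look like the following (a minimal sketch assuming, as above, that n is a power of 2; the explicit slicing is mine):

    import numpy as np

    def RMM(A, B):
        # recursive matrix multiplication; assumes A, B square of size n,
        # with n a power of 2, as in the pseudocode above
        n = A.shape[0]
        if n == 1:
            return A * B                      # scalar multiplication
        h = n // 2
        A11, A12, A21, A22 = A[:h,:h], A[:h,h:], A[h:,:h], A[h:,h:]
        B11, B12, B21, B22 = B[:h,:h], B[:h,h:], B[h:,:h], B[h:,h:]
        C = np.empty((n, n))
        C[:h,:h] = RMM(A11,B11) + RMM(A12,B21)
        C[:h,h:] = RMM(A11,B12) + RMM(A12,B22)
        C[h:,:h] = RMM(A21,B11) + RMM(A22,B21)
        C[h:,h:] = RMM(A21,B12) + RMM(A22,B22)
        return C

    # quick check against the usual algorithm
    A = np.random.rand(8, 8); B = np.random.rand(8, 8)
    assert np.allclose(RMM(A, B), A @ B)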
Correctness follows by induction on n
Cost: let A(n) = #arithmetic operations on matrices of dimension n
A(n) = 8*A(n/2) ... 8 recursive calls
+ n^2 ... 4 additions of n/2 x n/2 matrices
and A(1) = 1 is the base case
Solve by changing variables from n=2^m to m:
a(m) = 8*a(m-1) + 2^(2m), a(0) = 1
Divide by 8^m to get
a(m)/8^m = a(m-1)/8^(m-1) + (1/2)^m
Change variables again to b(m) = a(m)/8^m
b(m) = b(m-1) + (1/2)^m, b(0) = a(0)/8^0 = 1
= sum_{k=1 to m} (1/2)^k + b(0)
a simple geometric sum
b(m) = 2 - (1/2)^m -> 2
or
A(n) = a(log2 n) = b(log2 n)*8^(log2 n)
= b(log2 n)*n^3
-> 2n^3 as expected
Ask: Why isn't it exactly 2n^3?
A(n) = b(log2 n)*n^3
= (2 - 1/n)*n^3
= 2n^3 - n^2
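A quick numerical check of this count, using the recurrence directly (a sketch; the helper name flops is mine):

    def flops(n):
        # A(n) = 8*A(n/2) + n^2, A(1) = 1, from the recurrence above
        if n == 1:
            return 1
        return 8 * flops(n // 2) + n * n

    for n in (2, 4, 8, 16, 1024):
        assert flops(n) == 2 * n**3 - n**2   # matches 2n^3 - n^2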
Let W(n) = #words moved between slow memory and cache of size M
W(n) = 8*W(n/2) ... 8 recursive calls
+ 12(n/2)^2 ... up to 8 reads and 4 writes of n/2 x n/2 blocks
= 8*W(n/2) + 3n^2
oops - looks bigger than A(n), rather than A(n)/sqrt(M)
But what is the base case?
W(b) = 3*b^2 once all three b x b blocks fit in fast memory, i.e. 3*b^2 <= M, or b = sqrt(M/3)
so the base case is W(sqrt(M/3)) = M, not W(1) = something
Solve as before: change variables from n=2^m to m:
w(m) = 8*w(m-1) + 3*2^(2m), w(log2 sqrt(M/3)) = M
Let mbar = log2 sqrt(M/3)
Divide by 8^m to get v(m) = w(m)/8^m
v(m) = v(m-1) + 3*(1/2)^m,
v(mbar) = M/8^mbar = M/(M/3)^(3/2) = 3^(3/2)/M^(1/2)
This is again a geometric sum but just down to m = mbar, not 1
v(m) = sum_{k=mbar+1 to m} 3*(1/2)^k + v(mbar)
<= 3* (1/2)^(mbar) + v(mbar)
= 3/sqrt(M/3) + 3^(3/2)/sqrt(M)
= 2*3^(3/2)/sqrt(M)
so
W(n) = w(log2 n) = 8^(log2 n)*v(log2 n) <= 2*3^(3/2)*n^3/sqrt(M)
as desired.
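The same kind of check works for the communication recurrence (a sketch; the base-case cutoff and the particular M are mine, chosen so sqrt(M/3) is a power of 2):

    def words_moved(n, M):
        # W(n) = 8*W(n/2) + 3*n^2 until the three blocks fit in fast memory
        if 3 * n * n <= M:
            return 3 * n * n                 # base case: read A and B, write C
        return 8 * words_moved(n // 2, M) + 3 * n * n

    M = 3 * 16**2                            # fast memory holding three 16 x 16 blocks
    for n in (64, 256, 1024):
        assert words_moved(n, M) <= 2 * 3**1.5 * n**3 / M**0.5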
Here is another application of the theorem to the parallel case,
where "load" means receive data from another processor's memory,
and "store" means send data to another processor's memory.
Suppose we have P processors working in parallel to multiply 2 matrices.
We would like each processor to do 1/P-th of the arithmetic,
i.e. 2n^3/P flops. We want each processor to store at most 1/P-th
of all the data (A, B and C), i.e. M = 3*n^2/P. Then the same proof
of our theorem says the number of load/stores is at least
Omega((2*n^3/P) / sqrt(3*n^2/P)) = Omega(n^2/sqrt(P))
Ask: Why does the proof work, even if a processor is only doing
part of the work in matmul?
Is there an algorithm that attains this? yes: Cannon's Algorithm
(picture of layout and processor grid)
(picture of algorithm)
assume s = sqrt(P) and n/s are integers
let each A(i,j) be an n/s x n/s subblock, numbered from A(0,0) to A(s-1,s-1)
(1) for all i=0 to s-1, left-circular-shift row i of A by i (parallel)
... so A(i,j) overwritten by A(i, i+j mod s)
(2) for all j=0 to s-1, up-circular-shift column j of B by j (parallel)
... so B(i,j) overwritten by B(i+j mod s, j)
for k= 0 to s-1 (sequential)
for all i=0 to s-1 and all j=0 to s-1 (all processors in parallel)
(3) C(i,j) = C(i,j) + A(i,j)*B(i,j)
(4) left-circular-shift each row of A by 1
(5) up-circular-shift each column of B by 1
... C(i,j) = C(i,j) + A(i, i+j+k mod s)*B(i+j+k mod s, j)
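Here is a serial simulation of Cannon's algorithm in Python/NumPy, storing processor (i,j)'s block as entry [i][j] of a nested list (a sketch, not a real parallel code; the bookkeeping details are mine):

    import numpy as np

    def cannon(A, B, s):
        # serial simulation: A, B are n x n with s dividing n;
        # A_blk[i][j] plays the role of processor (i,j)'s local block
        n = A.shape[0]; b = n // s
        A_blk = [[A[i*b:(i+1)*b, j*b:(j+1)*b].copy() for j in range(s)] for i in range(s)]
        B_blk = [[B[i*b:(i+1)*b, j*b:(j+1)*b].copy() for j in range(s)] for i in range(s)]
        C_blk = [[np.zeros((b, b)) for j in range(s)] for i in range(s)]
        # (1) skew A: left-circular-shift row i by i
        A_blk = [[A_blk[i][(i + j) % s] for j in range(s)] for i in range(s)]
        # (2) skew B: up-circular-shift column j by j
        B_blk = [[B_blk[(i + j) % s][j] for j in range(s)] for i in range(s)]
        for k in range(s):
            # (3) every "processor" multiplies its current blocks
            for i in range(s):
                for j in range(s):
                    C_blk[i][j] += A_blk[i][j] @ B_blk[i][j]
            # (4) left-circular-shift each row of A by 1
            A_blk = [[A_blk[i][(j + 1) % s] for j in range(s)] for i in range(s)]
            # (5) up-circular-shift each column of B by 1
            B_blk = [[B_blk[(i + 1) % s][j] for j in range(s)] for i in range(s)]
        return np.block(C_blk)

    A = np.random.rand(12, 12); B = np.random.rand(12, 12)
    assert np.allclose(cannon(A, B, 3), A @ B)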
(Show slides)
Proof of correctness:
the inner loop accumulates C(i,j) = sum_k A(i,k)*B(k,j) for every (i,j),
just adding the terms in a "rotated" order of k
Arithmetic Cost: from line (3): s*(2*(n/s)^3) = 2*n^3/s^2 = 2*n^3/P,
so each processor does 1/P-th of the work as desired
Communication Cost (#words moved)
from line (1): s*(n/s)^2 = n^2/s words moved
from line (2): same
from line (4): s*(n/s)^2 = n^2/s words moved
from line (5): same
total = 4*n^2/s = 4*n^2/sqrt(P), meeting the lower bound
Let's consider a more complicated timing model:
time = #flops*gamma + #words_moved*beta + #messages*alpha
where one message is a submatrix:
from each of lines (1), (2), (4), (5): s messages,
so 4*s = 4*sqrt(P) altogether
Since #words_moved <= #messages * maximum_message_size
a lower bound on #messages is
#messages >= #words_moved / maximum_message_size
Here we assume that each processor only stores O(1/P-th) of all data,
i.e. memory size is O(n^2/P). This is clearly an upper bound on
the maximum_message_size too, so
#messages >= #words_moved / O(n^2/P)
= Omega(n^2/sqrt(P)) / O(n^2/P) = Omega(sqrt(P))
and we see that Cannon's Alg also minimizes #messages (up to a constant factor)
Note: lots of assumptions (P a perfect square, s divides n,
starting data in the right place, double memory),
so the practical algorithm (SUMMA, in the PBLAS library!) is more complicated;
it sends log(P) times as much data as Cannon
One more matmul topic: one can go asymptotically faster than n^3
Strassen (1967): O(n^2.81) = O(n^(log2 7))
Recursive, based on remarkable identities: for multiplying 2 x 2 matrices
Write A = [ A11 A12 ], each block of size n/2 x n/2, same for B and C
[ A21 A22 ]
P1 = (A12-A22)*(B21+B22)
P2 = (A11+A22)*(B11+B22)
P3 = (A11-A21)*(B11+B12)
P4 = (A11+A12)*B22
P5 = A11*(B12-B22)
P6 = A22*(B21-B11)
P7 = (A21+A22)*B11
C11 = P1+P2-P4+P6
C12 = P4+P5
C21 = P6+P7
C22 = P2-P3+P5-P7
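These identities are easy to check numerically, e.g. on random blocks (a short sketch):

    import numpy as np

    h = 4                                    # any block size works
    A11, A12, A21, A22 = (np.random.rand(h, h) for _ in range(4))
    B11, B12, B21, B22 = (np.random.rand(h, h) for _ in range(4))

    P1 = (A12 - A22) @ (B21 + B22)
    P2 = (A11 + A22) @ (B11 + B22)
    P3 = (A11 - A21) @ (B11 + B12)
    P4 = (A11 + A12) @ B22
    P5 = A11 @ (B12 - B22)
    P6 = A22 @ (B21 - B11)
    P7 = (A21 + A22) @ B11

    assert np.allclose(P1 + P2 - P4 + P6, A11 @ B11 + A12 @ B21)   # C11
    assert np.allclose(P4 + P5,           A11 @ B12 + A12 @ B22)   # C12
    assert np.allclose(P6 + P7,           A21 @ B11 + A22 @ B21)   # C21
    assert np.allclose(P2 - P3 + P5 - P7, A21 @ B12 + A22 @ B22)   # C22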
Changes naturally into recursive algorithm:
function C = Strassen(A,B)
... assume for simplicity that A and B are square of size n
... assume for simplicity that n is a power of 2
if n=1
C = A*B ... scalar multiplication
else
P1 = Strassen(A12-A22,B21+B22) ... and 6 more similar lines
C11 = P1+P2-P4+P6 ... and 3 more similar lines
end
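Filling in the elided lines, a runnable Python/NumPy version might look like this (a sketch under the same power-of-2 assumption; the slicing and np.block assembly are mine):

    import numpy as np

    def strassen(A, B):
        # Strassen's recursive matmul; assumes A, B square of size n, n a power of 2
        n = A.shape[0]
        if n == 1:
            return A * B                      # scalar multiplication
        h = n // 2
        A11, A12, A21, A22 = A[:h,:h], A[:h,h:], A[h:,:h], A[h:,h:]
        B11, B12, B21, B22 = B[:h,:h], B[:h,h:], B[h:,:h], B[h:,h:]
        P1 = strassen(A12 - A22, B21 + B22)
        P2 = strassen(A11 + A22, B11 + B22)
        P3 = strassen(A11 - A21, B11 + B12)
        P4 = strassen(A11 + A12, B22)
        P5 = strassen(A11, B12 - B22)
        P6 = strassen(A22, B21 - B11)
        P7 = strassen(A21 + A22, B11)
        return np.block([[P1 + P2 - P4 + P6, P4 + P5],
                         [P6 + P7, P2 - P3 + P5 - P7]])

    A = np.random.rand(16, 16); B = np.random.rand(16, 16)
    assert np.allclose(strassen(A, B), A @ B)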
A(n) = #arithmetic operations
= 7*A(n/2) ... 7 recursive calls
+ 18*(n/2)^2 ... 18 additions of n/2 x n/2 matrices
solve as before: change variables from n=2^m to m, A(n) to a(m):
a(m) = 7*a(m-1) + (9/2)*2^(2m), a(0)=A(1)=1
divide by 7^m, change again to b(m) = a(m)/7^m
b(m) = b(m-1) + (9/2)*(4/7)^m, b(0)=a(0)=1
= O(1) ... a convergent geometric sum
so
A(n) = a(log2 n) = b(log2 n)*7^(log2 n) = O(7^(log2 n)) = O(n^(log2 7))
Strassen is not (yet) used in practice
can pay off for n not too large (a few hundred)
error analysis slightly worse than usual algorithm:
Usual Matmul (Q1.10): |fl(A*B) - (A*B)| <= n*eps*|A|*|B|
Strassen: ||fl(A*B) - (A*B)|| <= O(eps)*||A||*||B||
which is good enough for lots of purposes, but when can
Strassen's be much worse? (suppose one row of A, or one column of B, very tiny)
Not allowed to be used in "LINPACK Benchmark" (www.top500.org)
which ranks machines by speed = (2/3)*n^3 / time_for_GEPP(n)
and if you use Strassen (we'll see how) in GEPP, it does << (2/3)n^3 flops
so the computed "speed" could exceed the actual peak machine speed
As a result, vendors stopped optimizing Strassen, even though it is a good idea!
Current world's record
Coppersmith & Winograd (1990): O(n^2.376)
impractical: would need huge n to pay off
Rederived by Cohn, Umans, Kleinberg, Szegedy (2003,2005)
using group representation theory
All such algorithms shown numerically stable (in norm-sense)
by D., Dumitriu, Holtz, Kleinberg (2007)