Notes for Math 221, Lecture 7, Sep 17 2009

Goal: continue with matmul, then how to do GE (Gaussian elimination) efficiently

   Last time we proved:

        Theorem: no matter how we reorganize matmul, as long as we
                 do the usual 2*n^3 arithmetic operations, but possibly
                 just doing the sums in different orders, then
                 #words_moved >= Omega(n^3/sqrt(M)) where M is the
                 fast memory size.

        We also showed that "blocked matrix multiplication" 
        attained this lower bound

        Here is another algorithm that also attains the lower bound,
        using a pattern that will appear again later: recursion

       function C = RMM(A,B)  ... recursive matrix multiplication
           ... assume for simplicity that A and B are square of size n
           ... assume for simplicity that n is a power of 2
           if n=1 
              C = A*B  ... scalar multiplication
           else
              ... write A = [ A11  A12 ], where each Aij n/2 by n/2
                            [ A21  A22 ]
              ... write B and C similarly
              C11 = RMM(A11,B11) + RMM(A12,B21)
              C12 = RMM(A11,B12) + RMM(A12,B22)
              C21 = RMM(A21,B11) + RMM(A22,B21)
              C22 = RMM(A21,B12) + RMM(A22,B22)
           endif
           return

       Correctness follows by induction on n
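
        Here is a minimal runnable sketch of RMM in Python/NumPy (not from
        the lecture; the NumPy slicing and the final check against A @ B are
        illustrative choices, assuming square inputs with n a power of 2):

           import numpy as np

           def RMM(A, B):
               # recursive matrix multiplication; A, B are n x n, n a power of 2
               n = A.shape[0]
               if n == 1:
                   return A * B                      # 1 x 1 (scalar) multiplication
               h = n // 2
               A11, A12, A21, A22 = A[:h, :h], A[:h, h:], A[h:, :h], A[h:, h:]
               B11, B12, B21, B22 = B[:h, :h], B[:h, h:], B[h:, :h], B[h:, h:]
               C = np.empty((n, n))
               C[:h, :h] = RMM(A11, B11) + RMM(A12, B21)
               C[:h, h:] = RMM(A11, B12) + RMM(A12, B22)
               C[h:, :h] = RMM(A21, B11) + RMM(A22, B21)
               C[h:, h:] = RMM(A21, B12) + RMM(A22, B22)
               return C

           n = 8
           A, B = np.random.rand(n, n), np.random.rand(n, n)
           assert np.allclose(RMM(A, B), A @ B)   # agrees with ordinary matmul
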
        Cost:  let A(n) = #arithmetic operations on matrices of dimension n
              A(n) = 8*A(n/2)   ... 8 recursive calls
                     + n^2      ... 4 additions of n/2 x n/2 matrices
              and A(1) = 1 is the base case

              Solve by changing variables from n=2^m to m:
                a(m) = 8*a(m-1) + 2^(2m),    a(0) = 1
              Divide by 8^m to get 
                a(m)/8^m = a(m-1)/8^(m-1) + (1/2)^m
              Change variables again to b(m) = a(m)/8^m
                b(m) = b(m-1) + (1/2)^m,   b(0) = a(0)/8^0 = 1
                      = sum_{k=1 to m} (1/2)^k + b(0)
              a simple geometric sum
                b(m) = 2 - (1/2)^m  -> 2
              or
                A(n) = a(log2 n) = b(log2 n)*8^(log2 n) 
                                 = b(log2 n)*n^3
                                 -> 2n^3 as expected
              Ask: Why isn't it exactly 2n^3?
                   A(n) = b(log2 n)*n^3
                        = (2 - 1/n)*n^3
                        = 2n^3 - n^2
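
               A quick check of this closed form against the recurrence
               (illustrative Python, not from the lecture):

                  def A(n):            # arithmetic-operation count from the recurrence
                      return 1 if n == 1 else 8 * A(n // 2) + n * n

                  for n in [1, 2, 4, 8, 16, 32]:
                      assert A(n) == 2 * n**3 - n**2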

             Let W(n) = #words moved between slow memory and cache of size M
             W(n) = 8*W(n/2)   ... 8 recursive calls
                    + 12(n/2)^2  ... up to 8 reads and 4 writes of n/2 x n/2 blocks
                  = 8*W(n/2) + 3n^2
               oops - this looks as big as the recurrence for A(n), not like A(n)/sqrt(M)
               But what is the base case? Once 3 blocks fit in fast memory,
               we just read A and B and write C once:
                   W(b) = 3*b^2 if 3*b^2 <= M,  i.e. b = sqrt(M/3)
               so the base case is not W(1) = something
              Solve as before: change variables from n=2^m to m:
                  w(m) = 8*w(m-1) + 3*2^(2m),   w(log2 sqrt(M/3)) = M
              Let mbar = log2 sqrt(M/3)
              Divide by 8^m to get v(m) = w(m)/8^m
                  v(m) = v(m-1) + 3*(1/2)^m,    
                      v(mbar) = M/8^mbar = M/(M/3)^(3/2) = 3^(3/2)/M^(1/2)
              This is again a geometric sum but just down to m = mbar, not 1
                  v(m) = sum_{k=mbar+1 to m} 3*(1/2)^k + v(mbar)
                       <= 3* (1/2)^(mbar) + v(mbar)
                        = 3/sqrt(M/3) + 3^(3/2)/sqrt(M)
                        = 2*3^(3/2)/sqrt(M)
              so 
                  W(n) = w(log2 n) = 8^(log2 n)*v(log2 n) <= 2*3^(3/2)*n^3/sqrt(M)
              as desired.
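
              The same kind of check for W(n), assuming (illustratively) a fast
              memory that holds exactly three 8 x 8 blocks, i.e. M = 3*8^2:

                 from math import sqrt

                 M = 3 * 8**2                  # illustrative fast memory size

                 def W(n):                     # words moved, from the recurrence
                     if 3 * n * n <= M:
                         return 3 * n * n      # base case: read A, B and write C once
                     return 8 * W(n // 2) + 3 * n * n

                 for n in [8, 16, 32, 64, 128]:
                     assert W(n) <= 2 * 3**1.5 * n**3 / sqrt(M)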

     Here is another application of the theorem to the parallel case,
     where "load" means receive data from another processor's memory,
     and "store" means send data to another processor's memory.

     Suppose we have P processors working in parallel to multiply 2 matrices.
     We would like each processor to do 1/P-th of the arithmetic,
     i.e. 2n^3/P flops. We want each processor to store at most 1/P-th
     of all the data (A, B and C), i.e. M = 3*n^2/P. Then the same proof
     of our theorem says the number of load/stores is at least
     Omega((2*n^3/P) / sqrt(3*n^2/P)) = Omega(n^2/sqrt(P))

     Ask: Why does the proof work, even if a processor is only doing
     part of the work in matmul?

     Is there an algorithm that attains this? yes: Cannon's Algorithm 
     (picture of layout and processor grid)
     (picture of algorithm)
         assume s = sqrt(P) and n/s are integers
          let each A(i,j) be an n/s x n/s subblock, numbered from A(0,0) to A(s-1,s-1);
             B and C are blocked the same way, with each C(i,j) initialized to 0
(1)       for all i=0 to s-1, left-circular-shift row i of A by i (parallel)
             ... so A(i,j) overwritten by A(i, i+j mod s)
(2)       for all j=0 to s-1, up-circular-shift column j of B by j (parallel)
             ... so B(i,j) overwritten by B(i+j mod s, j)
          for k= 0 to s-1 (sequential)
             for all i=0 to s-1 and all j=0 to s-1  (all processors in parallel)
(3)             C(i,j) = C(i,j) + A(i,j)*B(i,j)
(4)             left-circular-shift each row of A by 1
(5)             up-circular-shift each column of B by 1
                ... C(i,j) = C(i,j) + A(i, i+j+k mod s)*B(i+j+k mod s, j)
     (Show slides)
     Proof of correctness:
         inner loop adds all C(i,j) = C(i,j) + A(i,k)*B(k,j), 
           just in "rotated" order
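
      A serial NumPy simulation of Cannon's algorithm may make the shifts
      concrete (one b x b block per simulated "processor"; np.roll stands in
      for the circular shifts -- an illustrative sketch, not from the lecture):

         import numpy as np

         def cannon(A, B, s):
             # simulate an s x s processor grid; n must be divisible by s
             n = A.shape[0]
             b = n // s
             # view A, B, C as s x s grids of b x b blocks; [i, j] is block (i,j)
             Ab = A.reshape(s, b, s, b).swapaxes(1, 2).copy()
             Bb = B.reshape(s, b, s, b).swapaxes(1, 2).copy()
             Cb = np.zeros_like(Ab)
             for i in range(s):                       # (1) left-shift row i of A by i
                 Ab[i] = np.roll(Ab[i], -i, axis=0)
             for j in range(s):                       # (2) up-shift column j of B by j
                 Bb[:, j] = np.roll(Bb[:, j], -j, axis=0)
             for k in range(s):                       # s sequential steps
                 for i in range(s):
                     for j in range(s):
                         Cb[i, j] += Ab[i, j] @ Bb[i, j]   # (3) local multiply-add
                 Ab = np.roll(Ab, -1, axis=1)         # (4) left-shift each row of A by 1
                 Bb = np.roll(Bb, -1, axis=0)         # (5) up-shift each column of B by 1
             return Cb.swapaxes(1, 2).reshape(n, n)

         n, s = 8, 4
         A, B = np.random.rand(n, n), np.random.rand(n, n)
         assert np.allclose(cannon(A, B, s), A @ B)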
     
     Arithmetic Cost: from line (3): s*(2*(n/s)^3) = 2*n^3/s^2 = 2*n^3/P,
         so each processor does 1/P-th of the work as desired
     Communication Cost (#words moved)
         from line (1): s*(n/s)^2 = n^2/s words moved
         from line (2): same
         from line (4): s*(n/s)^2 = n^2/s words moved
         from line (5): same
         total = 4*n^2/s = 4*n^2/sqrt(P),  meeting the lower bound

     Let's consider more complicated timing model:
         time = #flops*gamma + #words_moved*beta + #messages*alpha
         where one message is a submatrix:
         from each of lines (1), (2), (4), (5): s messages, 
         so 4*s = 4*sqrt(P) altogether

     Since #words_moved <= #messages * maximum_message_size
     a lower bound on #messages is
           #messages >= #words_moved / maximum_message_size
     Here we assume that each processor only stores O(1/P-th) of all data,
     i.e. memory size is O(n^2/P). This is clearly an upper bound on
     the maximum_message_size too, so
          #messages >= #words_moved / O(n^2/P)
                      = Omega(n^2/sqrt(P)) / O(n^2/P) = Omega(sqrt(P))
     and we see that Cannon's Alg also minimizes #messages  (up to a constant factor)

     Note: lots of assumptions (P a perfect square, s|n, 
           starting data in right place, double memory)
           so practical algorithm (SUMMA) more complicated (in PBLAS library!)
             sends log(P) times as much data as Cannon

      One more matmul topic: can go asymptotically faster than n^3

      Strassen (1969): O(n^2.81) = O(n^(log2 7))

     Recursive, based on remarkable identities: for multiplying 2 x 2 matrices
        Write A = [ A11 A12 ], each block of size n/2 x n/2, same for B and C
                  [ A21 A22 ]
        P1 = (A12-A22)*(B21+B22)
        P2 = (A11+A22)*(B11+B22)
        P3 = (A11-A21)*(B11+B12)
        P4 = (A11+A12)*B22
        P5 = A11*(B12-B22)
        P6 = A22*(B21-B11)
        P7 = (A21+A22)*B11
        C11 = P1+P2-P4+P6
        C12 = P4+P5
        C21 = P6+P7
        C22 = P2-P3+P5-P7
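
         These identities are easy to check numerically on random blocks
         (illustrative Python/NumPy, not from the lecture):

            import numpy as np

            h = 4                                   # illustrative block size
            A11, A12, A21, A22 = (np.random.rand(h, h) for _ in range(4))
            B11, B12, B21, B22 = (np.random.rand(h, h) for _ in range(4))

            P1 = (A12 - A22) @ (B21 + B22)
            P2 = (A11 + A22) @ (B11 + B22)
            P3 = (A11 - A21) @ (B11 + B12)
            P4 = (A11 + A12) @ B22
            P5 = A11 @ (B12 - B22)
            P6 = A22 @ (B21 - B11)
            P7 = (A21 + A22) @ B11

            assert np.allclose(P1 + P2 - P4 + P6, A11 @ B11 + A12 @ B21)  # C11
            assert np.allclose(P4 + P5,           A11 @ B12 + A12 @ B22)  # C12
            assert np.allclose(P6 + P7,           A21 @ B11 + A22 @ B21)  # C21
            assert np.allclose(P2 - P3 + P5 - P7, A21 @ B12 + A22 @ B22)  # C22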

        Changes naturally into recursive algorithm:
           function C = Strassen(A,B)
           ... assume for simplicity that A and B are square of size n
           ... assume for simplicity that n is a power of 2
           if n=1 
              C = A*B  ... scalar multiplication
           else
              P1 = Strassen(A12-A22,B21+B22) ... and 6 more similar lines
              C11 = P1+P2-P4+P6 ... and 3 more similar lines
           end
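
         Filling in the elided lines gives a runnable sketch (Python/NumPy;
         the cutoff at n = 1 and the check against A @ B are illustrative
         choices, not from the lecture):

            import numpy as np

            def strassen(A, B):
                n = A.shape[0]
                if n == 1:
                    return A * B                          # scalar multiplication
                h = n // 2
                A11, A12, A21, A22 = A[:h, :h], A[:h, h:], A[h:, :h], A[h:, h:]
                B11, B12, B21, B22 = B[:h, :h], B[:h, h:], B[h:, :h], B[h:, h:]
                P1 = strassen(A12 - A22, B21 + B22)
                P2 = strassen(A11 + A22, B11 + B22)
                P3 = strassen(A11 - A21, B11 + B12)
                P4 = strassen(A11 + A12, B22)
                P5 = strassen(A11, B12 - B22)
                P6 = strassen(A22, B21 - B11)
                P7 = strassen(A21 + A22, B11)
                C = np.empty((n, n))
                C[:h, :h] = P1 + P2 - P4 + P6
                C[:h, h:] = P4 + P5
                C[h:, :h] = P6 + P7
                C[h:, h:] = P2 - P3 + P5 - P7
                return C

            n = 16
            A, B = np.random.rand(n, n), np.random.rand(n, n)
            assert np.allclose(strassen(A, B), A @ B)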

        A(n) = #arithmetic operations
             = 7*A(n/2)  ... 7 recursive calls
               + 18*(n/2)^2  ... 18 additions of n/2 x n/2 matrices

        solve as before: change variables from n=2^m to m, A(n) to a(m):
           a(m) = 7*a(m-1) + (9/2)*2^(2m),    a(0)=A(1)=1
        divide by 7^m, change again to b(m) = a(m)/7^m
           b(m) = b(m-1) + (9/2)*(4/7)^m,   b(0)=a(0)=1
                = O(1)  ... a convergent geometric sequence
        so
           A(n) = a(log2 n) = b(log2 n)*7^(log2 n) = O(7^(log2 n)) = O(n^(log2 7))
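
         A small check, from the recurrence, that the number of scalar
         multiplications is exactly 7^(log2 n) = n^(log2 7) (illustrative Python):

            from math import log2

            def mults(n):                 # scalar multiplications in the recursion
                return 1 if n == 1 else 7 * mults(n // 2)

            for n in [2, 4, 8, 16, 32]:
                assert mults(n) == 7 ** int(log2(n)) == round(n ** log2(7))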

        Strassen is not (yet) used in practice
           can pay off for n not too large (a few hundred)
           error analysis slightly worse than usual algorithm:
                Usual Matmul (Q1.10):  |fl(A*B) - (A*B)| <= n*eps*|A|*|B|
                Strassen   ||fl(A*B) - (A*B)|| <= O(eps)*||A||*||B||
             which is good enough for lots of purposes, but when can
             Strassen's be much worse? (suppose one row of A, or one column of B, very tiny)
           Not allowed to be used in "LINPACK Benchmark" (www.top500.org)
             which ranks machines by speed = (2/3)*n^3 / time_for_GEPP(n)
             and if you use Strassen (we'll see how) in GEPP, it does << (2/3)n^3 flops
              so computed "speed" could exceed the actual peak machine speed
           As a result, vendors stopped optimizing Strassen, even though it is a good idea!

       Current world's record
          Coppersmith & Winograd (1990): O(n^2.376)
            impractical: would need huge n to pay off
          Rederived by Cohn, Umans, Kleinberg, Szegedy (2003,2005)
             using group representation theory
        All such algorithms shown numerically stable (in a norm sense)
             by D., Dumitriu, Holtz, Kleinberg (2007)