Notes for Math 221, Lecture 8, Sep 22 2009

Summary of questions we have asked about Matmul:

(1) What is the simplest (sequential) algorithm, and how much does it cost
    in #flops?   in #words_moved (between slow memory and cache)?
    Answer:  2n^3 flops, O(n^3) words moved

(2) Is there an algorithm that moves fewer words? How many?
    Answer: blocked matrix multiply, with b x b blocks, 
      where 3 blocks fit in cache of size M:  3b^2 <= M
      Moves just O(n^3/sqrt(M)) words (see the sketch after this list)

(3) Can we do better?
    Answer: No: Omega(n^3/sqrt(M)) is a lower bound

(4) Is there a "cache oblivious" algorithm, i.e. that doesn't
    depend explicitly on M, that minimizes #words_moved?
   Answer: Yes: recursive matmul

(5) How fast can we hope that parallel matmul could run on P processors?
   Answer: Each processor computes as little as possible: 2n^3/P
    and communicates as little as possible: 
      Omega(n^2/sqrt(P)) words_moved in Omega(sqrt(P)) messages

(6) Is there a parallel Matmul algorithm that runs this fast?
   Answer: Yes: Cannon's algorithm

(7) Can we do matmul faster than 2n^3?
   Answer: Yes: Strassen's method takes O(n^(log2 7)) flops
            almost as stably
           The world's record is O(n^2.376), but impractical

(8) Do we know how to minimize #words_moved for algorithms like Strassen?
   Answer: we think so, but work in progress...

(9) Is this all I need to know to write a fast matmul algorithm?
   Answer: No, lots of computer-specific details to get right
    (see graphs at www.cs.berkeley.edu/~volkov/cs267.sp09/hw1
     www.cs.berkeley.edu/~volkov/cs267.sp09/hw1/results
     for evidence)

(10) So are there good implementations I can just use, and not
    worry about the details?
   Answer: Yes: the BLAS and PBLAS libraries are typically
     provided in math libraries tuned for each computer
     If not, ATLAS (www.netlib.org/atlas) will tune BLAS for you automatically.
     Matlab generally uses tuned BLAS libraries.
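
To make the blocked algorithm from question (2) concrete, here is a minimal
sketch in Python/NumPy (an addition to these notes, not the tuned code of
question (10)); it assumes square n-by-n matrices with b dividing n:

    import numpy as np

    def blocked_matmul(A, B, b):
        # C = A*B computed block by block; the three b-by-b blocks touched
        # in the inner loop should fit in cache together: 3*b^2 <= M
        n = A.shape[0]
        C = np.zeros((n, n))
        for i in range(0, n, b):
            for j in range(0, n, b):
                for k in range(0, n, b):
                    C[i:i+b, j:j+b] += A[i:i+b, k:k+b] @ B[k:k+b, j:j+b]
        return C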

We can ask the same set of questions for all the important
operations of linear algebra (solving Ax=b, least squares,
eigenvalues & eigenvectors, SVD), and get very similar answers
for "direct methods" like Gaussian elimination (iterative
methods, like conjugate gradients, are different).
There are also a number of open problems! In particular, the software
still has to catch up with our theoretical understanding of
how to communicate as little as possible.

We won't have time to answer all 10 questions for every linear algebra 
problem in as much detail as for matmul. But we will highlight
the parts that are most different or interesting for each problem.

Outline:
  Lower bound for communication with GE
  GE from BLAS2 to BLAS3
  Recursive GE
  Parallel GE

  Then on to 
     Iterative refinement (to get more accuracy)
     LU for special structures (symmetric, sparse, etc)

  Why does GE have same lower bound on #words moved as Matmul?
      Because it has matmul "inside" it:
        [ I 0 -B ] = [ I 0 0 ] * [ I 0 -B ] = L * U
        [ A I  0 ]   [ A I 0 ]   [ 0 I A*B]
        [ 0 0  I ]   [ 0 0 I ]   [ 0 0  I ]
      In other words, a GE subroutine (no pivoting) has to move at least
      as many words as a matmul subroutine on matrices of 1/3 the size,
      i.e. at least some constant times n^3/sqrt(M) #words_moved.
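
      A quick numerical check of this embedding (a Python/NumPy sketch added
      to these notes): multiply the claimed L and U factors and confirm the
      product A*B appears inside U.

          import numpy as np

          n = 4
          A = np.random.randn(n, n)
          B = np.random.randn(n, n)
          I, Z = np.eye(n), np.zeros((n, n))

          T = np.block([[I, Z, -B], [A, I, Z], [Z, Z, I]])  # 3n-by-3n input
          L = np.block([[I, Z, Z], [A, I, Z], [Z, Z, I]])
          U = np.block([[I, Z, -B], [Z, I, A @ B], [Z, Z, I]])

          assert np.allclose(L @ U, T)   # the product A*B sits inside U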

   Recall existing GEPP on an m x n matrix: it is a BLAS2 code
   
     for i = 1 to n-1
        pivot so |A(i,i)| largest among |A(i:m,i)|
        j=i+1
        A(j:m,i) = A(j:m,i)/A(i,i)   ... BLAS 1 - scale a vector
        A(j:m,j:n) = A(j:m,j:n) - A(j:m,i)*A(i,j:n)   ... BLAS2 - rank-1 update

   picture
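
   For reference, a runnable version of this BLAS2 code (a Python/NumPy
   sketch added to these notes; the pseudocode above is 1-based, NumPy
   is 0-based):

       import numpy as np

       def gepp(A):
           # LU with partial pivoting via rank-1 (BLAS2) updates, as above;
           # L (below the diagonal) and U overwrite a copy of A
           A = np.array(A, dtype=float)
           m, n = A.shape
           piv = np.arange(m)
           for i in range(min(m - 1, n)):
               p = i + np.argmax(np.abs(A[i:, i]))   # largest |A(i:m,i)|
               A[[i, p], :] = A[[p, i], :]
               piv[[i, p]] = piv[[p, i]]
               A[i+1:, i] /= A[i, i]                 # BLAS1: scale a vector
               A[i+1:, i+1:] -= np.outer(A[i+1:, i], A[i, i+1:])  # rank-1
           return A, piv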

   General technique to convert BLAS2 code to BLAS3 code (often works):
       Note that a sequence of k rank-1 updates A = A - u(i)*v(i)^T
       can also be written as A = A - U*V^T where U=[u(1),...,u(k)],
       V=[v(1),...,v(k)], i.e. as one gemm.
   Question: in GEPP, each BLAS2 call appears to depend on the previous one,
       to get the vectors A(j:m,i) = u(i) and A(i,j:n) = v(i)^T, so
       how do we reorganize GEPP to get several at a time?
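
   The same observation in code (a small NumPy check added to these notes):

       import numpy as np

       n, k = 6, 3
       A = np.random.randn(n, n)
       U = np.random.randn(n, k)   # columns u(1),...,u(k)
       V = np.random.randn(n, k)   # columns v(1),...,v(k)

       A1 = A.copy()
       for i in range(k):                    # k BLAS2 rank-1 updates
           A1 -= np.outer(U[:, i], V[:, i])
       A2 = A - U @ V.T                      # one BLAS3 gemm
       assert np.allclose(A1, A2)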


   Getting started:  ( A11 A12 ) = ( L11 0 ) * ( U11 U12 )
                     ( A21 A22 )   ( L21 I )   (  0  S   )

             where  A11, L11, U11 are b-by-b, and S is the Schur complement;
             to compute the block entries of L and U:

                   ( A11 A12 ) = ( L11 0 ) * ( U11 U12 ) = ( L11*U11 L11*U12 )
                   ( A21 A22 )   ( L21 I )   (  0  S   )   ( L21*U11 S+L21*U12)

       => ( A11 ) = ( L11 ) * U11
          ( A21 )   ( L21 )
         this is the LU factorization of the leading b columns of A; use the BLAS2 algorithm above:

             for i=1:b
               pivot
               L(i+1:n,i) = A(i+1:n,i)/A(i,i)   ... L overwrites A
               U(i,i:b) = A(i,i:b)              ... U overwrites A (nothing to do)
               A(i+1:n,i+1:b) -= L(i+1:n,i)*U(i,i+1:b)
         cost = O(n b^2), BLAS2 (rank-1 update)

        => A12 = L11*U12 => U12 = inv(L11)*A12
             a triangular solve with many right-hand sides (BLAS3), O(n b^2) flops

        => S = A22 - L21*U12, gemm with O(n^2 b) flops
             => most flops are BLAS3

   Continuing: Suppose we have finished computing leading columns of L, rows of U;
       write (partial) factorization of A as 
               A = ( A11 A12 A13 ) = ( L11  0   0  ) * ( U11 U12  U13 )
                   ( A21 A22 A23 )   ( L21  I   0  )   (  0 ~A22 ~A23 )
                   ( A31 A32 A33 )   ( L31  0   I  )   (  0 ~A32 ~A33 )

     where A22 is b-by-b; apply same idea to [~A22 ~A23] to get
                                             [~A32 ~A33]

            [ ~A22 ~A23 ] = [ L22   0  ] * [ U22  U23 ]
            [ ~A32 ~A33 ]   [ L32   I  ]   [  0   ~S  ]

     where ~S = ~A33 - L32*U23, so

          A = ( A11 A12 A13 ) = ( L11  0   0  ) * ( U11 U12  U13 )
              ( A21 A22 A23 )   ( L21  I   0  )   (  0 ~A22 ~A23 )
              ( A31 A32 A33 )   ( L31  0   I  )   (  0 ~A32 ~A33 )
            = ( L11  0   0  ) * ( U11    U12     U13              )
              ( L21  I   0  )   (  0 [ L22   0  ] * [ U22  U23 ]  )
              ( L31  0   I  )   (  0 [ L32   I  ]   [  0   ~S  ]  )
            = ( L11  0   0  ) * ( U11    U12  U13 )
              ( L21  L22 0  )   (  0     U22  U23 )
              ( L31  L32 I  )   (  0     0   ~S   )

     and computation has advanced b steps. The overall algorithm is

  for i = 1 to n-1 step b

       Factor A(i:n,i:i+b-1)  = ( L22 ) * U22 using BLAS2 algorithm
                                ( L32 )

       Apply pivot interchanges to rest of matrix

        A(i:i+b-1, i+b:n) = inv(L22) * A(i:i+b-1, i+b:n), where L22 is the
               unit lower triangle of A(i:i+b-1,i:i+b-1)
               ...  U23 = inv(L22)*~A23 , a single BLAS3 call (triangular
               ...  solve with many right-hand sides)

       A(i+b:n, i+b:n) = A(i+b:n, i+b:n) - A(i+b:n,i:i+b-1) * A(i:i+b-1,i+b:n)
              ... ~S = ~A33 - L32*U23 , a single BLAS3 call
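
       Putting the loop together, a runnable sketch in Python/NumPy (an
       illustration added to these notes, not LAPACK's implementation;
       scipy's solve_triangular plays the role of the BLAS3 triangular
       solve):

           import numpy as np
           from scipy.linalg import solve_triangular

           def blocked_lu(A, b):
               # right-looking blocked LU with partial pivoting; L and U
               # overwrite a copy of A, piv records the row interchanges
               A = np.array(A, dtype=float)
               n = A.shape[0]
               piv = np.arange(n)
               for i in range(0, n - 1, b):
                   e = min(i + b, n)
                   # factor A(i:n, i:e) = [L22; L32]*U22 by the BLAS2
                   # algorithm, swapping rows across the whole matrix
                   for j in range(i, min(e, n - 1)):
                       p = j + np.argmax(np.abs(A[j:, j]))
                       A[[j, p], :] = A[[p, j], :]
                       piv[[j, p]] = piv[[p, j]]
                       A[j+1:, j] /= A[j, j]
                       A[j+1:, j+1:e] -= np.outer(A[j+1:, j], A[j, j+1:e])
                   if e < n:
                       # U23 = inv(L22)*~A23 : one BLAS3 triangular solve
                       A[i:e, e:] = solve_triangular(
                           A[i:e, i:e], A[i:e, e:],
                           lower=True, unit_diagonal=True)
                       # ~S = ~A33 - L32*U23 : one BLAS3 gemm
                       A[e:, e:] -= A[e:, i:e] @ A[i:e, e:]
               return A, piv

       As a check, (unit lower triangle of the result) * (upper triangle)
       should reproduce the input matrix with its rows permuted by piv.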


        How do we pick b?
          b=1 corresponds to the classical BLAS2 algorithm.
          Each BLAS3 call has computational intensity
            at most about 2*n^2*b/(2*n*b+2*n^2) = n*b/(n+b), about b when b << n,
          so we want to pick b larger to make the BLAS3 calls faster.
          But of the 2n^3/3 total flops, the number done in BLAS2 is about
                n*b*(n/2 - b/3),
          so as b grows, more flops are done in (slow) BLAS2, fewer in (fast) BLAS3.
          So how do we pick b?
             Theory: depending on the relative sizes of M and n, it may not
                     be possible to pick b to minimize #words_moved
             Practice: some libraries are tuned for a particular architecture;
                       portable ones could (should) be tuned,
                           but the default in LAPACK is b=64 (see ilaenv)
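
        For example (a worked instance of the formulas above, added to these
        notes): with n = 1000 and b = 64, each BLAS3 call has intensity about
        n*b/(n+b) ~ 60, while the BLAS2 part does about n*b*(n/2 - b/3) ~ 3.1e7
        flops, under 5% of the 2n^3/3 ~ 6.7e8 total.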

History: LINPACK and EISPACK, in the 1970s and 1980s, did not block.
         LAPACK, starting in the late 1980s, redid many linear algebra
            algorithms (not just GEPP) to use BLAS3 as much
            as possible
         now part of all major vendor libraries, Matlab, etc.
         all available at www.netlib.org/lapack
            (Fortran, C, Fortran95, other versions)
         parallel version at www.netlib.org/ScaLAPACK
         (over 100M hits to web site)

Just as there was a recursive version of GEMM that minimized
communication without explicitly depending on the cache size M,
we can do the same for LU (due to Toledo, 1997):

   At top level:  "Do LU on left half of matrix
                   Update right half (U at top, Schur complement at bottom)
                   Do LU on Schur complement"

      function [L,U] = RLU(A)  
        ... assume A is n by m with n >= m, m a power of 2
         if m == 1 ... one column
            pivot so largest entry is on the diagonal, update rest of matrix
            L = A/A(1), U = A(1)
       else
            ... write A = [ A11 A12 ], L1 = [ L11 ]
            ...           [ A21 A22 ]       [ L21 ]
            ... where A11, A12, L11, U1 and U2 are m/2 by m/2
            ...       A21, A22 and L21 are n-m/2 by m/2
            [L1,U1] = RLU([ A11 ]) ... LU of left half of A
                          [ A21 ]
            ... (apply the row interchanges from this call to [ A12 ; A22 ] too)
            A12 = L11 \ A12      ... update U, in upper part of right half of A
            A22 = A22 - L21*A12  ... update Schur complement
            [L2,U2] = RLU( A22 ) ... LU on Schur complement
            L = [ L1, [0;L2] ]   ... return complete n by m L factor
            U = [ U1  A12 ]      ... return complete m by m U factor
                [  0  U2  ]
           
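      A runnable version of RLU (a Python/NumPy sketch added to these notes;
      unlike the pseudocode it allows any m >= 1, not just powers of 2, and
      it applies the recursive calls' row interchanges across the full width):

          import numpy as np
          from scipy.linalg import solve_triangular

          def rlu(A):
              # recursive LU with partial pivoting (after Toledo, 1997);
              # A is n-by-m, n >= m; returns L (n-by-m, unit lower
              # trapezoidal), U (m-by-m upper triangular) and piv
              # such that A[piv, :] = L @ U
              A = np.array(A, dtype=float)
              n, m = A.shape
              if m == 1:                       # one column: pivot and scale
                  p = np.argmax(np.abs(A[:, 0]))
                  A[[0, p], :] = A[[p, 0], :]
                  piv = np.arange(n)
                  piv[[0, p]] = piv[[p, 0]]
                  return A / A[0, 0], A[:1, :1].copy(), piv
              h = m // 2
              L1, U1, piv1 = rlu(A[:, :h])     # LU of left half of A
              A = A[piv1, :]                   # apply its pivots to all of A
              U12 = solve_triangular(L1[:h, :], A[:h, h:],
                                     lower=True, unit_diagonal=True)
              S = A[h:, h:] - L1[h:, :] @ U12  # Schur complement
              L2, U2, piv2 = rlu(S)            # LU on Schur complement
              L = np.hstack([np.vstack([L1[:h, :], L1[h:, :][piv2, :]]),
                             np.vstack([np.zeros((h, m - h)), L2])])
              U = np.vstack([np.hstack([U1, U12]),
                             np.hstack([np.zeros((m - h, h)), U2])])
              piv = piv1[np.concatenate([np.arange(h), h + piv2])]
              return L, U, piv
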
     Correct by induction on m
     Analysis: see homework
      Fact: if we do L21*A12 by Strassen, and L11\A12 by 
           (1) inverting L11 by divide-and-conquer:
                inv([ T11 T12 ]) = [ inv(T11) -inv(T11)*T12*inv(T22) ]
                   ([  0  T22 ])   [    0            inv(T22)        ]
           (2) multiplying inv(L11)*A12 by Strassen
     then whole algorithm costs O(n^(log2 7)) like Strassen, but can be less stable 
     than usual O(n^3) version of GEPP
     (see arxiv.org/abs/math.NA/0612264)
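
      A quick NumPy check of the block-inverse formula in (1) (added to
      these notes):

          import numpy as np
          from numpy.linalg import inv

          k = 3
          T11 = np.triu(np.random.randn(k, k)) + 3*np.eye(k)  # invertible
          T22 = np.triu(np.random.randn(k, k)) + 3*np.eye(k)
          T12 = np.random.randn(k, k)
          Z = np.zeros((k, k))

          T = np.block([[T11, T12], [Z, T22]])
          Tinv = np.block([[inv(T11), -inv(T11) @ T12 @ inv(T22)],
                           [Z, inv(T22)]])
          assert np.allclose(Tinv, inv(T))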

Last GE implementation: Parallel GE
           
     (Block cyclic layout slides - slide 27 from Lecture 11 of CS267)
     (Parallel GE slides - slides 28-29 from Lecture 11 of CS267)
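
      Since the slides are not reproduced in these notes, here is a minimal
      sketch (a hypothetical helper added here) of the 1D block-cyclic idea
      they illustrate: b-wide column blocks are dealt round-robin to the P
      processors, so every processor keeps part of the active matrix
      throughout the factorization.

          def block_cyclic_owner(j, b, P):
              # owner of global column j when b-wide column blocks are
              # dealt round-robin to processors 0..P-1
              return (j // b) % P

          # e.g. with b=2, P=3, columns 0..11 are owned by
          # [0, 0, 1, 1, 2, 2, 0, 0, 1, 1, 2, 2]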

     However, this is not quite as good as Cannon's MatMul algorithm:
        #flops per processor: about 2n^3/(3*P) - ok, each processor does 1/Pth of work
        #words_moved per processor: O(n^2*log(P)/sqrt(P)) - larger than
           lower bound, but just a factor of log(P), which grows slowly
        #messages: O(n*log(P)) - can be much larger than lower bound of Omega(sqrt(P))
     Seems impossible to preserve partial pivoting and do better:
        need to compute largest entry in each column, so #messages grows like Omega(n)
     For an alternative to partial pivoting that (nearly) reaches lower bound, 
        and gets interesting speedups, see Grigori, Demmel, and Xiang (SC08)
        http://hal.inria.fr/inria-00277901/fr/