Notes for Math 221, Lecture 5, Sep 10 2009

Goal: continue with Gaussian Elimination

   Overall plan:
      How to implement matmul and then GE efficiently
      Important variations exploiting structure (spd, banded, sparse,...)
      Iterative refinement to improve an approximate answer

   Simple alpha/beta/gamma performance model, Computational intensity

   BLAS history, with computational intensities

   MatMul
     naive, blocked, analysis
     recursive, analysis
     lower bound proof
     parallel case

   GE - same progression

Recall: cost of an algorithm is both arithmetic and moving data
    Picture of 2 level memory hierarchy
    Simplest cost model:
        time = #flops*time_per_flop + #words_moved*time_per_word
        Def: Bandwidth = BW = words/second, so time_per_word = 1/BW
        time_per_word >> time_per_flop, 
          Ex: easily O(100x) for moving data between main memory and on-chip cache
              and orders of magnitude larger for disk, etc
          So #words_moved*time_per_word may be >> #flops*time_per_flop unless we're clever
          Both time_per_flop and time_per_word improving exponentially over time,
             but time_per_flop at ~60%/year vs time_per_word for main memory at ~23%/year, 
             so the gap is growing exponentially too, 
          So minimizing #words_moved is only getting more important
    Picture of simplest parallel processor
       same simple model, 

     In both cases, for simplicity we will assume we (the algorithm designers) decide 
       when to move which data; this is the case when writing parallel code with MPI, 
       but not exactly the case with a cache (the hardware decides).
    
    More accurate model we will sometimes use:
        time = #flops*time_per_flop + #words_moved*time_per_word + #messages*time_per_message
        another word for time_per_message is "latency"
        Ex: Disk: suppose you ask for 1KB of data from consecutive disk locations:
            Need to move disk head to right track, wait for disk to spin to start of data: latency
            Start reading data off disk, time_per_word depends on rotational speed and bit density: BW
            "message" is all data fetched from consecutive locations
        Ex: Latency of highway from City A to City B is distance/speed limit
            Bandwidth of highway is #cars/hour that can get from A to B, depends on
                 how far apart cars drive, #lanes, not distance
        Similar ideas for Main memory:
            flops improving at 60%/year
            memory BW improving at 23%/year
            memory latency improving at 7%/year
        Similar ideas for messages sent between processors on a network
        Notation: time = #flops * gamma + #words_moved * beta + #messages * alpha
    
    Real machines even more complicated
        Picture of multi-level mem hierarchy
        Heterogeneous processors (running at very different speeds)
        Cost model applied to moving data between every level, 
        Sometimes we know how to design algorithms to minimize data movement at every level, 
            but many open questions.
  
   Back to simple model: time  = #flops*gamma + #words_moved*beta

    Simple metric of how much an algorithm communicates, vs how much it computes:
      Computational Intensity = q = #flops / #words_moved
      time = #flops*gamma + #words_moved*beta
           = #flops * gamma * ( 1 + (beta/gamma) /q )
     So time >= #flops*gamma, a lower bound; the actual time is bigger by the
          factor 1 + (beta/gamma)/q, whose numerator beta/gamma is big and
          growing, so our goal must be to maximize q.
    Worst case: 
        for i = 1 to n, read a word, do a flop, write a word
        Ex: for i=1 to n, A(i) = 2*A(i)
        q = 1/2
        time = #flops*gamma * (1 + (beta/gamma)*2) dominated by beta term
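
     Plugging in numbers (a minimal Python sketch; gamma and beta below are
     hypothetical values chosen to reflect the ~100x gap mentioned above):

        gamma = 1.0e-9     # time_per_flop: 1 ns (assumed)
        beta  = 1.0e-7     # time_per_word: 100 ns (assumed, ~100x gamma)

        def model_time(flops, words_moved):
            # time = #flops*gamma + #words_moved*beta
            return flops * gamma + words_moved * beta

        n = 10**6
        print(model_time(n, 2*n))      # the q = 1/2 loop above: ~0.201 s, almost all beta term
        print(model_time(n, n//100))   # same flops, 200x fewer words: ~0.002 s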
   
    Brief history of linear algebra software, measured using q
       First libraries (EISPACK, for eigenvalues of dense matrices, mid 1960s)
            didn't worry about data motion, arithmetic was more expensive,
            mostly just tried to get answer reliably 
       BLAS-1 library (mid 1970s) - Basic Linear Algebra Subroutines
            Standard library of 15 operations (mostly) on vectors:
              (1)  y = alpha*x + y, x and y vectors, alpha scalar ("AXPY" for short)
                   Ex: innermost loop in GE:
                          for k=i+1:n, A(j,k) = A(j,k) - A(j,i)*A(i,k) 
              (2)  dot product
              (3)  2-norm(x)
              (4)  find largest entry in absolute value in a vector (for pivoting in GE)
            Motivations at the time: 
               Commonly used functions to ease programming, readability
               robustness (eg avoid over/underflow in 2-norm(x))
               portable and efficient (if optimized on each architecture)
            But what about q? Not large
              (1) q = (2*n)/(3*n+1) ~ 2/3
              (2) q = (2*n)/(2*n+1) ~ 1
            So if we implement GE by loops calling AXPY, it will have q ~ 2/3
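
        For concreteness, AXPY in NumPy (a sketch; SciPy also exposes the
        actual BLAS routine as scipy.linalg.blas.daxpy):

           import numpy as np

           # AXPY: y = alpha*x + y.  2*n flops; reads x and y and writes y,
           # so ~3*n words moved and q ~ 2/3 no matter how it is implemented.
           n = 5
           alpha = 2.0
           x = np.arange(n, dtype=float)   # [0, 1, 2, 3, 4]
           y = np.ones(n)
           y += alpha * x                  # y is now [1, 3, 5, 7, 9]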
       BLAS-2 library (mid 1980s)
           Standard library of 25 operations (mostly) on matrix-vector pairs
               (1) y = beta*y + alpha*A*x ("GEMV" for short), 
                      lots of variations for different matrix structures 
                            (symmetric, banded, triangular etc)
                      allow A^T, A^* to be used instead of A, 
                      obvious optimizations when alpha and/or beta is +1, -1 or 0
              (2) A = A + alpha*x*y^T (matrix + rank-1 update) ("GER" for short)
                  Ex:  2 innermost loops in GE:
                         A(i+1:n,i+1:n) = A(i+1:n,i+1:n) - A(i+1:n,i)*A(i,i+1:n)
              (3) Solve T*x=b where T triangular (used by GE to solve A*x=b)
                     ("TRSV" for short)
            Similar motivation as for BLAS-1, plus more opportunities to optimize
            But what about q? A little better, but not big enough (assume A is n x n)
              (1) q = (2*n^2 + 3*n)/(n^2+3*n) ~ 2
               (2) q = (3*n^2)/(2*n^2 + 3*n + 1) ~ 3/2
              (3) q = (n^2)/(n^2/2 + 2*n) ~ 2
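
           To make the GER connection concrete, a minimal sketch (Python/NumPy,
           no pivoting, assumes a float array with nonzero pivots) of GE whose
           two innermost loops are exactly the rank-1 update in (2):

              import numpy as np

              def ge_rank1(A):
                  # LU factorization without pivoting, overwriting A with L
                  # (below the diagonal, unit diagonal implicit) and U
                  n = A.shape[0]
                  for i in range(n - 1):
                      A[i+1:, i] /= A[i, i]       # multipliers = column i of L
                      # BLAS-2 step: A(i+1:n,i+1:n) -= A(i+1:n,i)*A(i,i+1:n),
                      # the GER-style rank-1 update, via np.outer
                      A[i+1:, i+1:] -= np.outer(A[i+1:, i], A[i, i+1:])
                  return A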
      BLAS-3 library (late 1980s)
           Standard library of 9 operations on matrix-matrix pairs
              (1) C = beta*C + alpha*A*B ("GEMM" for short)
              (2) C = beta*C + alpha*A*A^T, C symmetric ("SYRK" for short)
              (3) Solve T*X=B where T triangular, X and B matrices ("TRSM" for short)
           What about q? (suppose matrices are n x n)
              (1) q = (2*n^3 + 3*n^2)/(4*n^2) ~ n/2    big!
              (2) q = (n^3 + 1.5*n^2)/(2*n^2) ~ n/2    big!
              (3) q = (n^3 + 1.5*n^2)/(2*n^2) ~ n/2    big!
           So with BLAS-3, we can (for large n) hope to move less data
            (if we are clever)
           So if we use BLAS-3 to build GE (and other routines), we can
            also hope to move less data 
              Not obvious how to do this: we will discuss this
      Note: BLAS-k does O(n^k) operations, for k=1,2,3
      Note: These are supplied as part of the optimized math libraries on 
               essentially all computers: use them as your building blocks!
            Reference implementations and documentation at www.netlib.org/blas
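
       Ex: in Python, np.matmul (the @ operator) dispatches to whatever
           optimized BLAS-3 GEMM NumPy was built against, so a GEMM-style
           update is one line (a sketch):

             import numpy as np

             n = 512
             rng = np.random.default_rng(0)
             A, B, C = (rng.random((n, n)) for _ in range(3))
             alpha, beta = 2.0, 0.5
             C = beta*C + alpha*(A @ B)   # GEMM: @ calls the optimized BLAS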

      (show slide of different GEMM speeds)

      How to implement Matmul to minimize data movement:

       First look at the naive implementation and what's wrong with it, then a good one:

           ... multiply C = C + A*B
           for i = 1:n
              for j = 1:n
                 for k = 1:n
                   C(i,j) = C(i,j) + A(i,k)*B(k,j)
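
         The same loop as runnable Python (a reference sketch only; as the
         analysis below shows, you should not implement matmul this way):

           import numpy as np

           def naive_matmul(A, B, C):
               # C = C + A*B via three nested loops; each innermost iteration
               # touches one entry each of A, B and C
               n = A.shape[0]
               for i in range(n):
                   for j in range(n):
                       for k in range(n):
                           C[i, j] += A[i, k] * B[k, j]
               return C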

        Need to "decorate" code with explicit descriptions of data movement 
            (recall: we assume we control it!)

        Answer will depend on size of fast memory, call it M

	Easy case: A, B and C all fit in fast memory (i.e. 3*n^2 <= M),
	    ... load all of A, B, C into fast memory
            multiply them using above algorithm
            ... store answers

            #words_moved = 4*n^2 (= lower bound = #inputs + #outputs)
            
        Interesting case: when matrices don't all fit (3*n^2 > M)

        Intuition: innermost loop is dot-product, so q ~ 1
                   innermost 2 loops are GEMV: C(i,:) = C(i,:) + A(i,:)*B, so q ~ 2

        Simplest way to do it: load data just before you need it

           for i = 1:n
              for j = 1:n
                 ... load C(i,j)
                 for k = 1:n
                   ... load A(i,k), B(k,j)
                   C(i,j) = C(i,j) + A(i,k)*B(k,j)
                 ... store C(i,j)

             #words_moved = 2*n^3 + 2*n^2, q ~ 1  (like dot product, not good)

          a little better, but not much: 
               assume a row of A and a row of C both fit in fast memory (2*n <= M)

           for i = 1:n
              ... load whole rows A(i,:), C(i,:)
              for j = 1:n
                 for k = 1:n
                   ... load B(k,j)
                   C(i,j) = C(i,j) + A(i,k)*B(k,j)
              ... store row C(i,:)

            #words_moved = n^3 + 3*n^2, q ~ 2  (like GEMV)

         Notation: think of A, B, C as collection of b x b subblocks,
           so that C(i,j) is b x b, numbered from C(1,1) to C(n/b,n/b)

           for i=1:n/b
              for j=1:n/b
                 for k=1:n/b
                    C(i,j) = C(i,j) + A(i,k)*B(k,j)   ... b x b matrix multiplication

         Suppose that 3 bxb blocks fit in fast memory (3*b^2 <= M), then
         we can move data as follows:

           for i=1:n/b
              for j=1:n/b
                 ... load block C(i,j)
                 for k=1:n/b
                    ... load blocks A(i,k) and B(k,j)
                    C(i,j) = C(i,j) + A(i,k)*B(k,j)   ... b x b matrix multiplication
                 ... store block C(i,j)

           #words_moved = 2*(n/b)^2 * b^2    ... for C (each block loaded and stored once)
                        + 2*(n/b)^3 * b^2    ... for A and B (one block each per inner iteration)
                        = 2*n^2 + 2*n^3/b   

           Much better: dominant term smaller by a factor of b
           q = 2*n^3/(2*n^2 + 2*n^3/b) ~ b

           We want b as large as possible to minimize #words_moved.
           How large can we make it? We need 3*b^2 <= M, i.e. b <= sqrt(M/3), so
             #words_moved = 2*n^2 + 2*n^3/b >= 2*n^2 + 2*sqrt(3)*n^3/sqrt(M) = Omega(n^3/sqrt(M))
             and q = 2*n^3/#words_moved = O(sqrt(M))
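
           The blocked algorithm as runnable Python (a sketch; assumes for
           simplicity that b divides n, and uses @ for the b x b multiplies):

             import numpy as np

             def blocked_matmul(A, B, C, b):
                 # C = C + A*B on b x b blocks; with 3*b^2 <= M this moves
                 # ~2*n^2 + 2*n^3/b words between slow and fast memory
                 n = A.shape[0]
                 for i in range(0, n, b):
                     for j in range(0, n, b):
                         Cij = C[i:i+b, j:j+b]   # a view: updates land in C
                         for k in range(0, n, b):
                             Cij += A[i:i+b, k:k+b] @ B[k:k+b, j:j+b]
                 return C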

        Theorem: no matter how we reorganize matmul, as long as we
                 do the usual 2*n^3 arithmetic operations, but possibly
                 just doing the sums in different orders, then
                 #words_moved >= Omega(n^3/sqrt(M))

        History: Hong, Kung, 1981
                 Irony, Tiskin, Toledo, 2004 (parallel case too)
                 Ballard, D., Holtz, Schwartz, 2009 (rest of linear algebra)

        Proof (ITT): Think of your matmul program as doing some sequence of
           instructions: add, mul, load, store, ...
              arithmetic like add and mul only uses data in fast memory
              load and store move data between fast and slow memory
           Now just look at all the load/store instructions, 
             say there are W (=#words_moved) of them. We want a lower bound on W.

           Let S(1) be the set of all instructions from the start up to
             the M-th load/store instruction
           Let S(2) be the set of all instructions after S(1), up to 
             the 2M-th load/store instruction
           ...
           Let S(i) be the set of all instructions after S(i-1), up to
             the i*M-th load/store instruction
           ...
           Let S(r) be the last set of instructions (which may have
             fewer than M load/store instructions, if M does not divide W evenly)
                 
           Now (r-1)*M <= W, so a lower bound on r gives a lower bound on W.
           
           Suppose we had an upper bound F on the number of adds & muls that
           could be in any S(t). Since we need to do 2*n^3 adds & muls,
           that tells us that r*F >= 2*n^3 or r >= 2*n^3/F, or
             W >= (r-1)*M >= (2*n^3/F -1)*M is a lower bound.

           How do we get an upper bound F? We need to ask what adds and muls
           we can possibly do in one set S(t). We will bound this by using
           the fact that during S(t), only a limited amount of data is 
	   available in fast memory on which to do arithmetic:
	      There are the M words in fast memory when S(t) starts,
                 eg up to M different entries of A, B, partially computed entries of C
              We can do as many as M more loads during S(t) 
                 eg up to M more entries of A, B, partially computed entries of C
              We can do as many as M stores during S(t)
                 eg up to M more partially or fully computed entries of C
           Altogether, there are at most 3*M words of data on which to do adds&muls
           How many can we do?

           Clever idea of ITT: represent possible data by 3 faces of an nxnxn cube
               each A(i,k) represented by 1x1 square on one face
               each B(k,j) represented by 1x1 square on another, not parallel, face
               each C(i,j) represented by 1x1 square on another, not parallel, face
           and each mul/add pair C(i,j) = C(i,j) + A(i,k)*B(k,j) by a 1x1x1 cube
               at position (i,j,k), with corresponding projections on each face
           Counting problem: we know the total number of squares on each face that
               are "occupied" is at most 3*M. How many cubes can these "cover"?
           (Picture)

           Easy case: occupied C(i,j) fill up XxY rectangle
                      occupied A(i,k) fill up XxZ rectangle
                      occupied B(k,j) fill up ZxY rectangle
               "covered cubes" occupy brick of dimension XxYxZ,
                so at most X*Y*Z of them
           General case: what if we don't have rectangles and bricks?
                Note that volume=XYZ can also be written
                     XYZ = sqrt((X*Y)*(X*Z)*(Z*Y))
                         = sqrt((#C entries)*(#A entries)*(#B entries))
           Lemma (Loomis&Whitney, 1949)
                Let S be any bounded 3D set
                Let S_x be projection ("shadow") of S on y,z - plane
                  (i.e. if (x,y,z) in S then (y,z) in S_x)
                Similarly, let S_y and S_z be shadows on other planes
                Then Volume(S) <= sqrt(area(S_x)*area(S_y)*area(S_z))

           Apply to our case:
                S = set of 1x1x1 cubes representing all mul/add pairs
                  = {cubes centered at (i,j,k) representing C(i,j) = C(i,j) + A(i,k)*B(k,j)}
                S_x = set of entries of B needed to perform these
                    = {squares centered at (k,j) projected from C(i,j) = C(i,j) + A(i,k)*B(k,j)}
                S_y = set of entries of A needed to perform them
                    = {squares centered at (i,k) projected from C(i,j) = C(i,j) + A(i,k)*B(k,j)}
                S_z = set of entries of C needed to perform them
                    = {squares centered at (i,j) projected from C(i,j) = C(i,j) + A(i,k)*B(k,j)}

           Lemma => volume(S) = #cubes in S = (#flops in S(t))/2
                              <= sqrt(area(S_x)*area(S_y)*area(S_z))
                              =  sqrt((#C entries)*(#A entries)*(#B entries))
                              <= sqrt((3*M)*(3*M)*(3*M)) = 3*sqrt(3)*M^(1.5)
           so F = #flops in any S(t) <= 6*sqrt(3)*M^(1.5)

           Finally
             #words_moved = W 
                         >= (r-1)*M 
                         >= (2*n^3/F -1)*M 
                         >= (2*n^3/(6*sqrt(3)*M^(1.5)) -1)*M 
                         = (1/sqrt(27))*n^3/sqrt(M) - M
                         = Omega(n^3/sqrt(M))  as desired
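
           The final bound as a one-line computation (a sketch, restating the
           derivation above):

             import math

             def matmul_words_lower_bound(n, M):
                 # W >= (2*n^3/F - 1)*M with F = 6*sqrt(3)*M^1.5
                 F = 6 * math.sqrt(3) * M**1.5
                 return (2 * n**3 / F - 1) * M   # = n^3/sqrt(27*M) - M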