Notes for Math 221, Lecture 5, Sep 10 2009

Goal: continue with Gaussian Elimination

Overall plan: How to implement matmul and then GE efficiently
   Important variations exploiting structure (spd, banded, sparse,...)
   Iterative refinement to improve an approximate answer

   Simple alpha/beta/gamma performance model, Computational intensity
   BLAS history, with computational intensities
   MatMul
      naive, blocked, analysis
      recursive, analysis
      lower bound proof
      parallel case
   GE - same progression

Recall: cost of an algorithm is both arithmetic and moving data
   Picture of 2 level memory hierarchy
   Simplest cost model:
      time = #flops*time_per_flop + #words_moved*time_per_word
   Def: Bandwidth = BW = words/second, so time_per_word = 1/BW
   time_per_word >> time_per_flop,
      Ex: easily O(100x) for moving data between main memory and on-chip cache
          and orders of magnitude larger for disk, etc
   So #words_moved*time_per_word may be >> #flops*time_per_flop
      unless we're clever
   Both time_per_word and time_per_flop improving exponentially over time,
      but time_per_flop at ~60%/year, and
      time_per_word for main memory at ~23%/year,
      so gap growing exponentially too,
   So minimizing #words_moved only getting more important

   Picture of simplest parallel processor
      same simple model

   In both cases, for simplicity we will assume we (the algorithm designer)
   decide when to move which data; this is the case when writing
   parallel code with MPI, but not exactly the case with a cache
   (the hardware decides).
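A minimal sketch of the simplest cost model above, in Python. The function
names and the sample numbers are illustrative assumptions, not measurements,
and "improving at 60%/year" is read here as "speed multiplies by 1.60 each
year", which is also an assumption.

```python
def simple_time(flops, words_moved, time_per_flop, time_per_word):
    """time = #flops*time_per_flop + #words_moved*time_per_word"""
    return flops * time_per_flop + words_moved * time_per_word


def gap_growth(years, flop_rate=0.60, bw_rate=0.23):
    """Factor by which time_per_word/time_per_flop grows after `years` years,
    assuming flop speed improves by flop_rate and memory bandwidth by bw_rate
    per year (the ~60% and ~23% figures quoted in the notes)."""
    return ((1.0 + flop_rate) / (1.0 + bw_rate)) ** years


if __name__ == "__main__":
    # Hypothetical machine: moving a word costs 100x one flop.
    t = simple_time(flops=10**6, words_moved=10**6,
                    time_per_flop=1e-9, time_per_word=1e-7)
    print("modeled time (s):", t)
    print("gap growth over 10 years:", gap_growth(10))
```

With equal counts of flops and words moved, the data movement term dominates
by the full ratio time_per_word/time_per_flop, and under the assumed rates
that ratio grows by roughly 30% per year.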

More accurate model we will sometimes use:
      time = #flops*time_per_flop + #words_moved*time_per_word
             + #messages*time_per_message
   another word for time_per_message is "latency"
   Ex: Disk: suppose you ask for 1KB of data from consecutive disk locations:
      Need to move disk head to right track,
         wait for disk to spin to start of data: latency
      Start reading data off disk, time_per_word depends on
         rotational speed and bit density: BW
      "message" is all data fetched from consecutive locations
   Ex: Latency of highway from City A to City B is distance/speed limit
       Bandwidth of highway is #cars/hour that can get from A to B,
          depends on how far apart cars drive, #lanes, not distance
   Similar ideas for Main memory:
      flops improving at 60%/year
      memory BW improving at 23%/year
      memory latency improving at 7%/year
   Similar ideas for messages sent between processors on a network
   Notation:
      time = #flops * gamma + #words_moved * beta + #messages * alpha

Real machines even more complicated
   Picture of multi-level mem hierarchy
   Heterogeneous processors (running at very different speeds)
   Cost model applied to moving data between every level.
   Sometimes we know how to design algorithms to minimize data movement
   at every level, but many open questions.

Back to simple model: time = #flops*gamma + #words_moved*beta
   Simple metric of how much an algorithm communicates,
   vs how much it computes:
      Computational Intensity = q = #flops / #words_moved
      time = #flops*gamma + #words_moved*beta
           = #flops * gamma * ( 1 + (beta/gamma) /q )
   So time >= #flops*gamma, the lower bound set by arithmetic alone.
   The actual time exceeds it by the factor 1 + (beta/gamma)/q,
   where beta/gamma is big and growing, so our goal must be to maximize q.
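
The alpha/beta/gamma model and its rewriting in terms of q can be sketched
as follows. This is only an illustration: the function names and the
values gamma = 1e-9 and beta = 1e-7 are assumptions chosen to give
beta/gamma = 100, not measurements of any machine.

```python
def model_time(flops, words_moved, gamma, beta, messages=0, alpha=0.0):
    """time = #flops*gamma + #words_moved*beta + #messages*alpha"""
    return flops * gamma + words_moved * beta + messages * alpha


def model_time_via_q(flops, q, gamma, beta):
    """Same simple model (no alpha term), written with the computational
    intensity q = #flops/#words_moved:
        time = #flops*gamma*(1 + (beta/gamma)/q)"""
    return flops * gamma * (1.0 + (beta / gamma) / q)


if __name__ == "__main__":
    gamma, beta = 1e-9, 1e-7      # assumed: beta/gamma = 100
    flops = 2 * 1000**3           # e.g. a 1000 x 1000 matmul
    for q in (0.5, 2.0, 50.0, 500.0):
        slowdown = model_time_via_q(flops, q, gamma, beta) / (flops * gamma)
        print("q =", q, " time / (#flops*gamma) =", slowdown)
```

The printed ratio is 1 + (beta/gamma)/q, so it only approaches 1 (the
arithmetic lower bound) once q is large compared with beta/gamma.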

   Worst case: for i = 1 to n, read a word, do a flop, write a word
      Ex: for i=1 to n, A(i) = 2*A(i)
      q = 1/2
      time = #flops*gamma * (1 + (beta/gamma)*2), dominated by beta term

Brief history of linear algebra software, measured using q
   First libraries (EISPACK, for eigenvalues of dense matrices, mid 1960s)
      didn't worry about data motion, arithmetic was more expensive,
      mostly just tried to get answer reliably

   BLAS-1 library (mid 1970s) - Basic Linear Algebra Subroutines
      Standard library of 15 operations (mostly) on vectors:
         (1) y = alpha*x + y, x and y vectors, alpha scalar ("AXPY" for short)
             Ex: innermost loop in GE:
                 for k=i+1:n, A(j,k) = A(j,k) - A(j,i)*A(i,k)
         (2) dot product
         (3) 2-norm(x)
         (4) find largest entry in absolute value in a vector
             (for pivoting in GE)
      Motivations at the time:
         Commonly used functions to ease programming, readability
         robustness (eg avoid over/underflow in 2-norm(x))
         portable and efficient (if optimized on each architecture)
      But what about q? Not large
         (1) q = (2*n)/(3*n+1) ~ 2/3
         (2) q = (2*n)/(2*n+1) ~ 1
      So if we implement GE by loops calling AXPY, it will have q ~ 2/3

   BLAS-2 library (mid 1980s)
      Standard library of 25 operations (mostly) on matrix-vector pairs
         (1) y = alpha*y + beta*A*x ("GEMV" for short),
             lots of variations for different matrix structures
             (symmetric, banded, triangular etc)
             allow A^T, A^* to be used instead of A,
             obvious optimizations when alpha and/or beta is +1, -1 or 0
         (2) A = A + alpha*x*y^T (matrix + rank-1 update) ("GER" for short)
             Ex: 2 innermost loops in GE:
                 A(i+1:n,i+1:n) = A(i+1:n,i+1:n) - A(i+1:n,i)*A(i,i+1:n)
         (3) Solve T*x=b where T triangular (used by GE to solve A*x=b)
             ("TRSV" for short)
      Similar motivation as for BLAS-1, plus more opportunities to optimize
      But what about q?
      A little better, but not big enough (assume A n x n)
         (1) q = (2*n^2 + 3*n)/(n^2+3*n) ~ 2
         (2) q = (3*n^2)/(2*n^2 + 3*n + 1) ~ 3/2
         (3) q = (n^2)/(n^2/2 + 2*n) ~ 2

   BLAS-3 library (late 1980s)
      Standard library of 9 operations on matrix-matrix pairs
         (1) C = beta*C + alpha*A*B ("GEMM" for short)
         (2) C = beta*C + alpha*A*A^T, C symmetric ("SYRK" for short)
         (3) Solve T*X=B where T triangular, X and B matrices
             ("TRSM" for short)
      What about q? (suppose matrices are n x n)
         (1) q = (2*n^3 + 3*n^2)/(4*n^2) ~ n/2   big!
         (2) q = (n^3 + 1.5*n^2)/(2*n^2) ~ n/2   big!
         (3) q = (n^3 + 1.5*n^2)/(2*n^2) ~ n/2   big!
      So with BLAS-3, we can (for large n) hope to move less data
         (if we are clever)
      So if we use BLAS-3 to build GE (and other routines),
         we can also hope to move less data
         Not obvious how to do this: we will discuss this

   Note: BLAS-k does O(n^k) operations, for k=1,2,3
   Note: These are supplied as part of the optimized math libraries
         on essentially all computers: use them as your building blocks!
         Reference implementations and documentation at www.netlib.org/blas
   (show slide of different GEMM speeds)

How to implement Matmul to minimize data movement:
   First look at naive implementation, what's wrong with it, then good one:

      ... multiply C = C + A*B
      for i = 1:n
         for j = 1:n
            for k = 1:n
               C(i,j) = C(i,j) + A(i,k)*B(k,j)

   Need to "decorate" code with explicit descriptions of data movement
      (recall: we assume we control it!)
   Answer will depend on size of fast memory, call it M

   Easy case: A, B and C all fit in fast memory (i.e. 3*n^2 <= M),
      ... load all of A, B, C into fast memory
      multiply them using above algorithm
      ... store answers
      #words_moved = 4*n^2 (= lower bound = #inputs + #outputs)

   Interesting case: when matrices don't all fit (3*n^2 > M)
      Intuition: innermost loop is dot-product, so q ~ 1
                 innermost 2 loops are GEMV:
                    C(i,:) = C(i,:) + A(i,:)*B, so q ~ 2

      Simplest way to do it: load data just before you need it

         for i = 1:n
            for j = 1:n
               ... load C(i,j)
               for k = 1:n
                  ... load A(i,k), B(k,j)
                  C(i,j) = C(i,j) + A(i,k)*B(k,j)
               ... store C(i,j)

         #words_moved = 2*n^3 + 2*n^2
            (2*n^3 for A and B, n^2 each to load and store C),
            q ~ 1 (like dot product, not good)

      a little better, but not much: assume row of A and of C
      both fit in fast memory (2*n <= M)

         for i = 1:n
            ... load whole rows A(i,:), C(i,:)
            for j = 1:n
               for k = 1:n
                  ... load B(k,j)
                  C(i,j) = C(i,j) + A(i,k)*B(k,j)
            ... store row C(i,:)

         #words_moved = n^3 + 3*n^2, q ~ 2 (like GEMV)

   Blocked matmul:
      Notation: think of A, B, C as collection of b x b subblocks,
      so that C(i,j) is b x b, numbered from C(1,1) to C(n/b,n/b)

         for i=1:n/b
            for j=1:n/b
               for k=1:n/b
                  C(i,j) = C(i,j) + A(i,k)*B(k,j)
                     ... b x b matrix multiplication

      Suppose that 3 bxb blocks fit in fast memory (3*b^2 <= M),
      then we can move data as follows:

         for i=1:n/b
            for j=1:n/b
               ... load block C(i,j)
               for k=1:n/b
                  ... load blocks A(i,k) and B(k,j)
                  C(i,j) = C(i,j) + A(i,k)*B(k,j)
                     ... b x b matrix multiplication
               ... store block C(i,j)

         #words_moved = 2*(n/b)^2 * b^2   ... for C
                      + 2*(n/b)^3 * b^2   ... for A and B
                      = 2*n^2 + 2*n^3/b
         Much better: dominant term smaller by a factor of b
         q = 2*n^3/(2*n^2 + 2*n^3/b) ~ b

      We want b as large as possible to minimize #words_moved.
      How large can we make it?
         3*b^2 <= M so
         #words_moved >= 2*n^2 + sqrt(3)*2*n^3/sqrt(M) = Omega(n^3/sqrt(M))
         q = O(sqrt(M))

Theorem: no matter how we reorganize matmul, as long as we
   do the usual 2*n^3 arithmetic operations, but possibly
   just doing the sums in different orders, then
   #words_moved = Omega(n^3/sqrt(M))

   History: Hong, Kung, 1981
            Irony, Tiskin, Toledo, 2004 (parallel case too)
            Ballard, D., Holtz, Schwartz, 2009 (rest of linear algebra)

Proof (ITT):
   Think of your matmul program as doing some sequence of instructions:
      add, mul, load, store, ...
      arithmetic like add and mul only uses data in fast memory
      load and store move data between fast and slow memory
   Now just look at all the load/store instructions,
   say there are W (=#words_moved) of them.
   We want a lower bound on W.
   Let S(1) be the set of all instructions from the start up to
      the M-th load/store instruction
   Let S(2) be the set of all instructions after S(1),
      up to the 2M-th load/store instruction
   ...
   Let S(i) be the set of all instructions after S(i-1),
      up to the i*M-th load/store instruction
   ...
   Let S(r) be the last set of instructions (which may have fewer than
      M load/store instructions, if M does not divide W evenly)
   Now (r-1)*M <= W, so a lower bound on r gives a lower bound on W.

   Suppose we had an upper bound F on the number of adds & muls
   that could be in any S(t). Since we need to do 2*n^3 adds & muls,
   that tells us that r*F >= 2*n^3 or r >= 2*n^3/F, or
      W >= (r-1)*M >= (2*n^3/F -1)*M
   is a lower bound.

   How do we get an upper bound F?
   We need to ask what adds and muls we can possibly do in one set S(t).
   We will bound this by using the fact that during S(t),
   only a limited amount of data is available in fast memory
   on which to do arithmetic:
      There are up to M words already in fast memory when S(t) starts,
         eg up to M different entries of A, B,
         partially computed entries of C
      We can do as many as M more loads during S(t)
         eg up to M more entries of A, B, partially computed entries of C
      We can do as many as M stores during S(t)
         eg up to M more partially or fully computed entries of C
   Altogether, there are at most 3*M words of data
   on which to do adds & muls
   How many can we do?

   Clever idea of ITT: represent possible data by 3 faces of an nxnxn cube
      each A(i,k) represented by 1x1 square on one face
      each B(k,j) represented by 1x1 square on another, not parallel, face
      each C(i,j) represented by 1x1 square on the third face
      and pair of operations C(i,j) = C(i,j) + A(i,k)*B(k,j)
         by 1x1x1 cube with corresponding projections on each face
   Counting problem: we know the total number of squares on each face
      that are "occupied" is at most 3*M. How many cubes can these "cover"?
   (Picture)
   Easy case: occupied C(i,j) fill up XxY rectangle
              occupied A(i,k) fill up XxZ rectangle
              occupied B(k,j) fill up ZxY rectangle
              "covered cubes" occupy brick of dimension XxYxZ,
              so at most X*Y*Z of them
   General case: what if we don't have rectangles and bricks?
      Note that volume=XYZ can also be written
         XYZ = sqrt((X*Y)*(X*Z)*(Z*Y))
             = sqrt((#C entries)*(#A entries)*(#B entries))

   Lemma (Loomis&Whitney, 1949)
      Let S be any bounded 3D set
      Let S_x be projection ("shadow") of S on y,z - plane
         (i.e. if (x,y,z) in S then (y,z) in S_x)
      Similarly, let S_y and S_z be shadows on other planes
      Then Volume(S) <= sqrt(area(S_x)*area(S_y)*area(S_z))

   Apply to our case:
      S = set of 1x1x1 cubes representing all mul/add pairs
        = {cubes centered at (i,j,k) representing
           C(i,j) = C(i,j) + A(i,k)*B(k,j)}
      S_x = set of entries of B needed to perform these
          = {squares centered at (k,j) projected from
             C(i,j) = C(i,j) + A(i,k)*B(k,j)}
      S_y = set of entries of A needed to perform them
          = {squares centered at (i,k) projected from
             C(i,j) = C(i,j) + A(i,k)*B(k,j)}
      S_z = set of entries of C needed to perform them
          = {squares centered at (i,j) projected from
             C(i,j) = C(i,j) + A(i,k)*B(k,j)}

      Lemma => volume(S) = #cubes in S = #flops/2
                  <= sqrt(area(S_x)*area(S_y)*area(S_z))
                   = sqrt((#B entries)*(#A entries)*(#C entries))
                  <= sqrt((3*M)*(3*M)*(3*M))
                   = 3*sqrt(3)*M^(1.5)
      so #flops in any S(t) <= F = 6*sqrt(3)*M^(1.5)

   Finally
      #words_moved = W >= (r-1)*M
                      >= (2*n^3/F -1)*M
                      >= (2*n^3/(6*sqrt(3)*M^(1.5)) -1)*M
                       = (1/sqrt(27))*n^3/sqrt(M) - M
                       = Omega(n^3/sqrt(M))
   as desired
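
To connect the bound just derived with the blocked algorithm from earlier
in these notes, here is a small Python sketch that runs blocked matmul on
lists of lists while counting the words it would move under the notes'
assumption that we control all loads and stores. The names
blocked_matmul_count and lower_bound are hypothetical helpers, the counter
only models data movement (Python itself does no explicit loads), and the
code assumes b divides n.

```python
import math


def blocked_matmul_count(A, B, C, b):
    """Compute C = C + A*B in place using b x b blocks.
    Returns the modeled #words_moved, assuming three b x b blocks
    fit in fast memory (3*b^2 <= M)."""
    n = len(A)
    assert n % b == 0, "sketch assumes b divides n"
    words = 0
    for i in range(0, n, b):
        for j in range(0, n, b):
            words += b * b              # load block C(i,j)
            for k in range(0, n, b):
                words += 2 * b * b      # load blocks A(i,k) and B(k,j)
                # b x b matrix multiplication on data in fast memory
                for ii in range(i, i + b):
                    for jj in range(j, j + b):
                        s = C[ii][jj]
                        for kk in range(k, k + b):
                            s += A[ii][kk] * B[kk][jj]
                        C[ii][jj] = s
            words += b * b              # store block C(i,j)
    return words


def lower_bound(n, M):
    """The bound from the proof: W >= n^3/sqrt(27*M) - M."""
    return n**3 / math.sqrt(27 * M) - M


if __name__ == "__main__":
    n, b = 12, 3
    M = 3 * b * b                       # smallest fast memory that fits 3 blocks
    A = [[i + j for j in range(n)] for i in range(n)]
    B = [[i - j for j in range(n)] for i in range(n)]
    C = [[0] * n for _ in range(n)]
    w = blocked_matmul_count(A, B, C, b)
    print("words moved:", w, " formula 2*n^2 + 2*n^3/b:", 2*n*n + 2*n**3 // b)
    print("lower bound:", lower_bound(n, M))
```

The count matches 2*n^2 + 2*n^3/b, and it sits above n^3/sqrt(27*M) - M by
a constant factor in the leading term, consistent with the blocked
algorithm attaining the Omega(n^3/sqrt(M)) bound up to a constant.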