Notes for Math 221, Lecture 8, Sep 21 2010

Goal: Discuss how to solve Ax=b using variants of Gaussian Elimination (GE):
  Basic idea of GE as a matrix factorization (also known as LU decomposition).
  How it might be numerically unstable, or usually be stable if done properly (using pivoting).
  How to get an accurate answer at reasonable cost, even if the condition number is large, i.e.
      norm(x_computed - x_true)/norm(x_true) = O(macheps), not O(cond(A)*macheps),
    by using Newton's method and a little extra precision (just O(n^2) extra work).
  Variations that minimize communication.
  Variations that exploit matrix structure to go (much) faster:
    symmetry, positive definiteness (e.g. Cholesky);
    sparsity, when most entries are zero;
    when the matrix entries depend on few parameters.
      Classical example: polynomial interpolation is the same as solving V*a = y for the
      vector a, where V(i,j) = x(i)^(j-1); then sum_{j=1}^{n} a(j)*x(i)^(j-1) = y(i),
      i.e. a(1:n) are the coefficients of the polynomial interpolating the points
      (x(1),y(1)), ... , (x(n),y(n)). One can solve for a(1:n) in O(n^2) flops using
      Newton interpolation, not O(n^3).

Def: Permutation matrix: identity matrix with permuted rows.

Facts: let P, P1, etc. be permutation matrices. Then
  P*X = row permutation of X
  X*P = column permutation of X
  P1*P2 is also a permutation matrix
  inv(P) = P^T (enough to check that diag(P*P^T) = all ones)
  det(P) = +-1
  P is an orthogonal matrix
  To store and multiply by P, just keep track of the order of the rows (cheap).

Thm (LU decomposition): Given any m x n full rank matrix A, with m >= n, there exist
  an m x m permutation matrix P,
  an m x n unit lower triangular matrix L, and
  an n x n nonsingular upper triangular matrix U
such that A = P*L*U.
Proof: induction / Gaussian elimination (soon).

Cor: If A is n x n and nonsingular, there exist an n x n permutation P, unit lower
triangular L, and nonsingular upper triangular U such that A = P*L*U.

To solve A*x = b:
  (1) Factor A = P*L*U (the expensive part, (2/3)n^3 flops)
  (2) Solve P*L*U*x = b for L*U*x = P^T*b ... cost = O(n)
  (3) Solve L*U*x = P^T*b for U*x = inv(L)*P^T*b using forward substitution ... cost = n^2 flops
  (4) Solve U*x = inv(L)*P^T*b for x = inv(U)*inv(L)*P^T*b = inv(A)*b using back substitution
      ... cost = n^2 flops

Note: We do not compute inv(A) and multiply x = inv(A)*b:
  (1) it is 3x more expensive in the dense case, and can be much worse in the sparse case
      (O(n) times more expensive!);
  (2) it is not as numerically stable.
(A short code sketch of this solve sequence appears after the proof below.)

Proof of Theorem: If A is full rank, its first column is nonzero, so there is a
permutation P such that (P*A)(1,1) is nonzero. Write

  P*A = [ A11  A12 ] = [ 1        0 ] * [ A11   A12               ]
        [ A21  A22 ]   [ A21/A11  I ]   [ 0     A22 - A21*A12/A11 ]

where A11 is 1 x 1, A21 is (m-1) x 1, A12 is 1 x (n-1) and A22 is (m-1) x (n-1).
Now A full (column) rank => P*A full rank => S = A22 - A21*A12/A11 full rank.
(Otherwise, if some linear combination of the columns of S were 0, some linear
combination of the columns of A would be zero, contradicting A being full rank.)
(Square case: det(A) nonzero => 0 != det(first factor)*det(second factor)
= 1*A11*det(S) => det(S) nonzero.)
Notation: S is called the Schur complement.
Now apply induction: S = P'*L'*U', so

  P*A = [ 1        0 ] * [ A11   A12      ]
        [ A21/A11  I ]   [ 0     P'*L'*U' ]

      = [ 1        0     ] * [ A11   A12 ]
        [ A21/A11  P'*L' ]   [ 0     U'  ]

      = [ 1   0  ] * [ 1                0  ] * [ A11   A12 ]
        [ 0   P' ]   [ P'^T*(A21/A11)   L' ]   [ 0     U'  ]

      = P'' * L * U = perm * unit_lower_triangular * nonsingular_upper_triangular,

so A = P^T*P''*L*U = perm * L * U.  QED
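To make steps (1)-(4) above concrete, here is a minimal sketch using NumPy/SciPy (just an
illustration of the solve sequence on a random test matrix; scipy.linalg.lu returns P, L, U
with A = P*L*U):

import numpy as np
from scipy.linalg import lu, solve_triangular

rng = np.random.default_rng(0)
n = 6
A = rng.standard_normal((n, n))
b = rng.standard_normal(n)

P, L, U = lu(A)                      # (1) factor A = P*L*U, ~(2/3)n^3 flops
y = P.T @ b                          # (2) apply P^T to b (conceptually just a row permutation)
z = solve_triangular(L, y, lower=True, unit_diagonal=True)   # (3) forward substitution, n^2 flops
x = solve_triangular(U, z)           # (4) back substitution, n^2 flops

print(np.allclose(x, np.linalg.solve(A, b)))   # same answer as the library solver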
Expressing this as an algorithm, we get the following (ignore permutations at first):

  for i = 1 to n
    L(i,i) = 1, L(i+1:n,i) = A(i+1:n,i)/A(i,i)
    U(i,i:n) = A(i,i:n)
    if (i < n)
      A(i+1:n,i+1:n) = A(i+1:n,i+1:n) - L(i+1:n,i)*U(i,i+1:n)

Add permutations: after "for i = 1 to n", add:
  if A(i,i) zero and A(j,i) nonzero, swap rows i and j of L and A; record the swap

Don't waste space:
  row i of U overwrites row i of A: omit U(i,i:n) = A(i,i:n)
  column i of L (below the diagonal) overwrites the same entries of A, which are zeroed out:
    change the first line to A(i+1:n,i) = A(i+1:n,i)/A(i,i)
  we only need to loop from i = 1 to n-1, and change the last line to
    A(i+1:n,i+1:n) = A(i+1:n,i+1:n) - A(i+1:n,i)*A(i,i+1:n)

Finally, we get:

  for i = 1 to n-1
    if A(i,i) zero and A(j,i) nonzero, swap rows i and j of L and A; record the swap
    A(i+1:n,i) = A(i+1:n,i)/A(i,i)
    A(i+1:n,i+1:n) = A(i+1:n,i+1:n) - A(i+1:n,i)*A(i,i+1:n)

(draw picture of intermediate step)
When done:
  U overwrites the upper triangle and diagonal of A
  L (below the diagonal) overwrites A (below the diagonal)

To see that this is the same Gaussian Elimination you learned long ago, start from:

  for i = 1 to n-1         ... for each column i
    for j = i+1 to n       ... add a multiple of row i to row j to zero out
                           ... entry (j,i) below the diagonal
      m = A(j,i)/A(i,i)
      A(j,i:n) = A(j,i:n) - m*A(i,i:n)

"Optimize" this by
  (1) not bothering to compute the entries below the diagonal that you know are zero:
      change the last line to A(j,i+1:n) = A(j,i+1:n) - m*A(i,i+1:n)
  (2) computing all the multipliers m first, storing them in the zeroed-out locations:

        for i = 1 to n-1
          for j = i+1 to n
            A(j,i) = A(j,i)/A(i,i)
          for j = i+1 to n
            A(j,i+1:n) = A(j,i+1:n) - A(j,i)*A(i,i+1:n)

  (3) combining the inner loops into single expressions to get the same algorithm as before:

        for i = 1 to n-1
          A(i+1:n,i) = A(i+1:n,i)/A(i,i)
          A(i+1:n,i+1:n) = A(i+1:n,i+1:n) - A(i+1:n,i)*A(i,i+1:n)

The cost is sum_{i=1}^{n-1} (n-i)^2 multiplications and as many additions,
i.e. (2/3)n^3 + O(n^2) flops in all.

We need to do pivoting, even if we would never divide by zero, to get an accurate answer.
Suppose we run in single precision, about 7 decimal digits:

  A = [ 1e-8  1 ] ,   inv(A) ~ [ -1   1     ]
      [ 1     1 ]              [  1   -1e-8 ]

cond(A) ~ 2.6, really small, so we should get accurate answers. Without pivoting,

  L = [ 1    0 ]    U = [ 1e-8   1                    ]
      [ 1e8  1 ]        [ 0      fl(1 - 1e8*1) = -1e8 ]

but

  L*U = [ 1e-8  1 ]
        [ 1     0 ]

The (2,2) entry is wrong, and we get the same answer for A(2,2) = -1, 0, +1, ...
The computed L and U are nearly independent of A(2,2): this is "numerical instability";
L and U are not the exact factors of a nearby problem.
Ex: Solve A*x = [1;2]; start by solving L*y = [1;2], so
  y(1) = fl(1/1) = 1
  y(2) = fl(2 - 1e8*1) = -1e8, (nearly) independently of b(2),
so we would get the same solution for many different values of b(2): they can't all have
a small error!
Another way to see why this is bad: cond(L), cond(U) ~ 1e8, much bigger than cond(A).
If we permute the rows so that A(1,1) = 1, we get full accuracy.
Intuition: we want to pick a large entry of A to be the pivot A11.
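Putting the in-place loop together with this pivoting intuition, here is a minimal NumPy
sketch for square matrices (a toy version, not the LAPACK routine; it chooses the largest
available entry in the column as the pivot, i.e. the partial pivoting discussed under
Error Analysis below, whereas the bare loop above only swaps when A(i,i) is zero):

import numpy as np

def lu_inplace(A):
    """Toy LU factorization with row pivoting, following the loop derived above.
    On return, U occupies the upper triangle (including the diagonal) of the
    returned array, L (with an implicit unit diagonal) occupies the strict
    lower triangle, and piv records the row ordering."""
    A = np.asarray(A, dtype=float).copy()   # work on a copy; "in place" refers to the single array
    n = A.shape[0]
    piv = np.arange(n)                      # record swaps so P^T*b can be applied later
    for i in range(n - 1):
        # pick the largest entry in column i on or below the diagonal (partial pivoting)
        j = i + np.argmax(np.abs(A[i:, i]))
        if A[j, i] == 0:
            raise ValueError("matrix is singular")
        if j != i:
            A[[i, j], :] = A[[j, i], :]
            piv[[i, j]] = piv[[j, i]]
        A[i+1:, i] /= A[i, i]                               # column i of L overwrites zeroed-out entries
        A[i+1:, i+1:] -= np.outer(A[i+1:, i], A[i, i+1:])   # rank-1 update of the Schur complement
    return A, piv

# quick check on a random matrix: P*A = L*U, i.e. A with its rows permuted
rng = np.random.default_rng(0)
A = rng.standard_normal((5, 5))
LU, piv = lu_inplace(A)
L = np.tril(LU, -1) + np.eye(5)
U = np.triu(LU)
print(np.allclose(A[piv, :], L @ U))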
Error Analysis

We want backward stability, i.e. A + E = P*L*U with E "small", i.e. norm(E) = O(macheps)*norm(A).
We will see that we need to avoid values of L(j,i)*U(i,k) that are much larger than the
original entries of A.
Looking at the algorithm, each entry of U and L is computed mostly as a dot product,
which we analyzed in Q 1.10. Track where U(j,k) comes from (picture): for j <= k, we compute

  U(j,k) = fl[ A(j,k) - L(j,1)*U(1,k) - L(j,2)*U(2,k) - ... - L(j,j-1)*U(j-1,k) ]

Apply the analysis of the dot product (Q 1.10) to get

  U(j,k) = (1+e)*[ A(j,k) - fl(sum_{i=1:j-1} L(j,i)*U(i,k)) ]         where |e| <= macheps
         = (1+e)*[ A(j,k) - sum_{i=1:j-1} L(j,i)*U(i,k)*(1+d(i)) ]    where |d(i)| <= (j-1)*macheps

Rewrite this as A(j,k) = exact dot product + (small) error:

  A(j,k) = U(j,k)/(1+e) + sum_{i=1:j-1} L(j,i)*U(i,k)*(1+d(i))
         ~ U(j,k)*(1-e) + sum_{i=1:j-1} L(j,i)*U(i,k)*(1+d(i))
         = U(j,k) + sum_{i=1:j-1} L(j,i)*U(i,k) - e*U(j,k) + sum_{i=1:j-1} L(j,i)*U(i,k)*d(i)
         = (L*U)(j,k) + E(j,k)

where

  |E(j,k)| = | -e*U(j,k) + sum_{i=1:j-1} L(j,i)*U(i,k)*d(i) |
          <= [ |U(j,k)| + sum_{i=1:j-1} |L(j,i)*U(i,k)| ] * (j-1)*macheps
          <= (|L|*|U|)(j,k) * n*macheps

A similar analysis applies when j > k to get

  L(j,k) = fl[ ( A(j,k) - L(j,1)*U(1,k) - ... - L(j,k-1)*U(k-1,k) )/U(k,k) ]
         = (1+e1)*(1+e2)*( A(j,k) - fl(sum_{i=1:k-1} L(j,i)*U(i,k)) )/U(k,k)

Solve for A(j,k) to get

  A(j,k) = L(j,k)*U(k,k)/((1+e1)*(1+e2)) + fl(sum_{i=1:k-1} L(j,i)*U(i,k))
         = ... = (L*U)(j,k) + E(j,k)    where |E(j,k)| <= (|L|*|U|)(j,k) * n*macheps again

Altogether: A = L*U + E where |E| <= |L|*|U| * n*macheps componentwise.

Putting together the whole solution (from Q 1.11), since there are also errors from the
triangular solves:

  Solving L*y = b  =>  (L+dL)*yhat = b      where |dL| <= n*macheps*|L|
  Solving U*x = y  =>  (U+dU)*xhat = yhat   where |dU| <= n*macheps*|U|

Combine:

  b = (L+dL)*yhat = (L+dL)*(U+dU)*xhat
    = (L*U + L*dU + dL*U + dL*dU)*xhat
    = (A - E + L*dU + dL*U + dL*dU)*xhat
    = (A + dA)*xhat

where

  |dA| <= |E| + |L*dU| + |dL*U| + |dL*dU|
       <= |E| + |L|*|dU| + |dL|*|U| + |dL|*|dU|
        = (3*n*macheps + O(macheps^2))*|L|*|U|
        ~ 3*n*macheps*|L|*|U|

Is this "backward stable"? I.e., is norm(dA) = O(macheps)*norm(A)?
We need to compare norm(|L|*|U|) with norm(A).

Def: "Pivot growth factor" g = norm(|L|*|U|)/norm(A)
Fact: g >= 1 (for most norms)

Then ||dA|| <= 3*n*macheps*g*||A||, so we can compare xhat from (A+dA)*xhat = b with the
true solution x from A*x = b:

Thm: ||xhat - x||/||x|| <= 3*n*macheps*g*kappa(A) + O(macheps^2)
Proof: from before, we bound ||xhat - x||/||x|| <= kappa(A)*||dA||/||A||.

Whether this is satisfactory depends on g.
Bad example: A = [ 1e-8 1 ; 1 1 ] from above:

  |L|*|U| = [ 1    0 ] * [ 1e-8   1      ] = [ 1e-8   1   ]
            [ 1e8  1 ]   [ 0      |-1e8| ]   [ 1      2e8 ]

and the largest entry 2e8 is much bigger than any entry of A (g ~ 1e8): this is why we
get the wrong answer.

Idea: Pick the pivot A11 "large", so that when we divide by it the entries of L are small.

(1) Simplest, and standard, approach, used in most libraries: "partial pivoting" (GEPP).
    Permute rows only, so that A11 is the largest available entry in its column.
    Then L21 = A21/A11 has |L21| <= 1.
    Theorem (easy): with GEPP, |L| <= 1 and max(|U(:,i)|) <= 2^(n-1)*max(|A(:,i)|).
    Bad news: the worst case is terrible; even for n=24 in single precision, all digits wrong.
    Good news: it hardly ever happens (only a very small family of matrices where this occurs).
    Empirical observation, with some justification: g_PP < n^(2/3).
      If all entries of the matrix were "random", this would be true; as you perform
      pivoting, they seem to get more random.
(2) Complete pivoting: permute rows and columns so that A11 is the largest entry in the
    whole matrix; repeat at every step. Get A = P_r*L*U*P_c.
    Theorem: g_CP < n^(log(n)/4)
    Empirical: g_CP < n^(1/2)
    Long-standing conjecture: g_CP < n (false, but nearly true)
    More expensive, hardly used, not in most libraries.
(3) Tournament pivoting: something new, needed to minimize communication (#messages);
    will present it later.
(4) Threshold pivoting: this and similar schemes try to preserve sparsity while
    maintaining stability.
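To make the pivot growth factor concrete, here is a small NumPy check (a sketch, with the
2x2 factors written out by hand) of g and of the backward error for the bad example above,
in single precision, without and with the row swap:

import numpy as np

eps = 1e-8
A = np.array([[eps, 1.0],
              [1.0, 1.0]], dtype=np.float32)

def growth_and_backward_error(A):
    """Factor the 2x2 matrix A = L*U with no pivoting and return the pivot growth
    factor g = || |L|*|U| || / ||A|| and the backward error ||A - L*U|| / ||A||."""
    L = np.array([[1.0, 0.0],
                  [A[1, 0] / A[0, 0], 1.0]], dtype=np.float32)
    U = np.array([[A[0, 0], A[0, 1]],
                  [0.0, A[1, 1] - L[1, 0] * A[0, 1]]], dtype=np.float32)
    g = np.linalg.norm(np.abs(L) @ np.abs(U), np.inf) / np.linalg.norm(A, np.inf)
    backward = np.linalg.norm(A - L @ U, np.inf) / np.linalg.norm(A, np.inf)
    return g, backward

print(growth_and_backward_error(A))         # no pivoting: g ~ 1e8, backward error ~ 0.5
print(growth_and_backward_error(A[::-1]))   # rows swapped (GEPP): g ~ 1, tiny backward error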
Finally, we note that 3*n*macheps*g*norm(A) is an upper bound, unlikely to be attained;
see figures 2.1 and 2.2 for plots of the actual backward error (which we know is
||E||/||A|| = ||A*xhat - b||/(||A||*||xhat||)) on random matrices; it is nearly always
about macheps, i.e. it does not grow with n.

Final topic in the error analysis of Ax=b: how to get a more accurate answer if the
condition number is large.
Recall that the error bound is roughly

  ||x_true - x_computed|| / ||x_true|| = cond(A) * pivot_growth * O(macheps)

where we have put the dependence on the dimension inside the O(). So:
(1) How do we know if this is too large?
    Answer: we can estimate cond(A) in O(n^2) extra work after doing GEPP, and estimating
    pivot_growth also costs just O(n^2), to compute ||U|| and ||A||.
    (Plot expected error vs condition number on a log scale.)
(2) What do we do about it?
    Answer 1: Do everything in higher precision - expensive!
    Answer 2: Do just O(n^2) extra work in higher precision: iterative refinement
    (much cheaper). Basically just Newton's method:

      Do GEPP to solve A*x = b; call the initial solution x(1)
      i = 1
      repeat
        r = A*x(i) - b             ... in double precision, but just O(n^2)
        solve A*d = r              ... using the existing LU factors, just O(n^2)
        update x(i+1) = x(i) - d
        i = i + 1
      until "convergence"

    Why compute r in double precision? Otherwise the computed residual r is mostly noise
    (though there is still some benefit).
    Testing "convergence" is tricky: how can we avoid being fooled by a very
    ill-conditioned matrix that "accidentally" converges?
    See www.netlib.org/lapack/lawnspdf/lawn165.pdf for details.
    See figures 33-35 of www.cs.berkeley.edu/~demmel/Future_ScaLAPACK_v7.ppt for results.
    Available in LAPACK as sgesvxx, dgesvxx, etc.

Iterative refinement is interesting for a different reason on the Cell processor, GPUs,
and similar platforms: single precision is *much* faster than double precision
(sometimes 10x). So, to get the "usual" double precision accuracy:
  do the LU factorization in single precision;
  run iterative refinement, computing only r in double precision;
  the convergence test is simply "as good as running GEPP in double without refinement",
    i.e. norm(A*x - b) = O(epsilon)*norm(A)*norm(x);
  (a toy sketch of this scheme appears at the end of these notes)
  see www.netlib.org/lapack/lawnspdf/lawn177.pdf for details (on the Cell processor).

Where to find all this?
  Matlab: A\b or [P,L,U] = lu(A); rcond or condest
  LAPACK: xGETRF just for GEPP, where x = S/D/C/Z
          xGESV to solve A*x = b
          xGESVXX for condition estimation and iterative refinement
          xGECON for condition estimation alone
  CLAPACK: similar
  ScaLAPACK: PxGETRF etc.
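Finally, as promised above, here is a toy sketch of the single precision LU / double
precision residual scheme (assuming NumPy/SciPy; the function name refine, the fixed
iteration count, and the random test problem are illustrative only, since the real
routines use a careful convergence test):

import numpy as np
from scipy.linalg import lu_factor, lu_solve

def refine(A, b, iters=5):
    """Sketch of iterative refinement: LU with partial pivoting in single precision,
    residuals in double precision, O(n^2) work per refinement step."""
    A32 = A.astype(np.float32)
    lu_piv = lu_factor(A32)                          # O(n^3), single precision
    x = lu_solve(lu_piv, b.astype(np.float32)).astype(np.float64)
    for _ in range(iters):
        r = A @ x - b                                # residual in double precision, O(n^2)
        d = lu_solve(lu_piv, r.astype(np.float32))   # reuse the single precision LU, O(n^2)
        x = x - d
    return x

rng = np.random.default_rng(1)
n = 200
A = rng.standard_normal((n, n))
x_true = rng.standard_normal(n)
b = A @ x_true
x = refine(A, b)
print(np.linalg.norm(x - x_true) / np.linalg.norm(x_true))
# for this well-conditioned test matrix the error should be near double precision level,
# far smaller than the ~1e-4 error of the single precision solve alone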