;; Automatic Differentiation code for Common Lisp
;; Richard Fateman, November, 2005
;; This is all provided in the context of a Generic Arithmetic Package.

;; Package based in part on code posted on comp.lang.functional newsgroup by
;; Ingvar Mattsson 09 Oct 2003

;; The shadowed names are given as uninterned symbols (#:name) rather than
;; lowercase strings.  A string is taken literally, so "sin" only names the
;; CL symbol when the reader preserves lowercase; #:sin names the right
;; symbol whether the reader folds case or not.
(defpackage :ga                          ;generic arithmetic
  (:shadow #:+ #:- #:/ #:* #:expt       ;binary arith
           #:= #:/= #:> #:< #:<= #:>=   ;binary comparisons
           #:sin #:cos #:tan             ;... more trig
           #:atan #:asin #:acos          ;... more inverse trig
           #:sinh #:cosh #:tanh #:atanh  ;... more hyperbolic
           #:log #:exp #:sqrt            ;... more exponential, powers
           #:1- #:1+ #:abs)
  (:use :common-lisp))

(in-package :ga)

;; df structure for f,d: f is function value, and d derivative, default 0.
;; The constructor is the positional function (df f &optional d).
(defstruct (df (:constructor df (f &optional (d 0))))
  f
  d)

;; print df structures with < , >
(defmethod print-object ((a df) stream)
  (format stream "<~a, ~a>" (df-f a) (df-d a)))

;; Macro ARITHMETIC-IDENTITY: given an operator and a non-nil argument, it
;; yields the value of the unary application.  What does (+ a) mean?
;; A nil arg means there were NO operands.  What does (+ ) mean?
;; It is used only by defarithmetic, which in turn helps
;; us to write out + * - / of an arbitrary number of args.
;;
;;   op   no operands   one operand a
;;   +    0             a
;;   -    0             (two-arg-* -1 a)
;;   *    1             a
;;   /    error         reciprocal of a (as in CL's unary /)
;;   expt error         a
;;
;; ARG is evaluated exactly once.
(defmacro arithmetic-identity (op arg)
  (let ((a (gensym "ARG")))
    `(let ((,a ,arg))
       (case ,op
         (+ (or ,a 0))
         (- (if ,a (two-arg-* -1 ,a) 0))
         (* (or ,a 1))
         ;; Unary / is the reciprocal.  A df operand is divided into the
         ;; constant <1, 0> so the df/df quotient method does the work.
         (/ (cond ((null ,a) (error "/ given no arguments"))
                  ((df-p ,a) (two-arg-/ (df 1 0) ,a))
                  (t (two-arg-/ 1 ,a))))
         (expt (or ,a (error "expt given no arguments")))
         (otherwise nil)))))            ;binary comparisons?
;; TOCL and TWO-ARG-NAME are called while macros are being expanded, so they
;; are wrapped in EVAL-WHEN to be available to the file compiler as well.
(eval-when (:compile-toplevel :load-toplevel :execute)
  (defun tocl (n) ; get corresponding name in cl-user package
    (find-symbol (symbol-name n) :cl-user))

  ;; Build the symbol two-arg-<op> in :ga.  The prefix is taken from a
  ;; symbol's name (not a literal string) so that it has the same case as
  ;; the ga::two-arg-... symbols written out in the methods below.
  (defun two-arg-name (op)
    (intern (concatenate 'string (symbol-name '#:two-arg-) (symbol-name op))
            :ga)))

;; (defarithmetic op) defines:
;;   - the n-ary function OP, which reduces its arguments left to right with
;;     two-arg-OP and defers to arithmetic-identity for zero or one argument;
;;   - the generic function two-arg-OP, with a method for two numbers that
;;     calls the Common Lisp operator of the same name.
(defmacro defarithmetic (op)
  (let ((two-arg (two-arg-name op))
        (cl-op (tocl op)))
    `(progn
       (defun ,op (&rest args)
         (cond ((null args) (arithmetic-identity ',op nil))
               ((null (cdr args)) (arithmetic-identity ',op (car args)))
               (t (reduce (function ,two-arg) (cdr args)
                          :initial-value (car args)))))
       (defgeneric ,two-arg (arg1 arg2))
       (defmethod ,two-arg ((arg1 number) (arg2 number))
         (,cl-op arg1 arg2))
       (compile ',two-arg)
       (compile ',op)
       ',op)))

(defarithmetic +) ;; defines some of + programs. See below for more
(defarithmetic -)
(defarithmetic *)
(defarithmetic /)
(defarithmetic expt)

;; defcomparison helps us generate numeric comparisons: 2 args
;; and n-arg.  CL requires they be monotonic.
;; That is in Lisp, (> 3 2 1) is true.

;; True when OP holds between A and the first of REST, and between each
;; adjacent pair of REST after that.
;; NOTE(review): only adjacent pairs are compared, so n-ary /= is weaker
;; here than CL's /=, which requires all arguments to be pairwise different.
(defun monotone (op a rest)
  (or (null rest)
      (and (funcall op a (car rest))
           (monotone op (car rest) (cdr rest)))))

;; (defcomparison op) defines the n-ary predicate OP and the generic
;; two-arg-OP with methods for every number/df combination.  A df is
;; compared by its function value only; the derivative is ignored.
(defmacro defcomparison (op)
  (let ((two-arg (two-arg-name op))
        (cl-op (tocl op)))
    `(progn
       (defun ,op (&rest args)
         (cond ((null args) (error "~s wanted at least 1 arg" ',op))
               ((null (cdr args)) t) ;; one arg e.g. (> x) is true
               (t (monotone (function ,two-arg) (car args) (cdr args)))))
       (defgeneric ,two-arg (arg1 arg2))
       (defmethod ,two-arg ((arg1 number) (arg2 number))
         (,cl-op arg1 arg2))
       (defmethod ,two-arg ((arg1 df) (arg2 df))
         (,cl-op (df-f arg1) (df-f arg2)))
       (defmethod ,two-arg ((arg1 number) (arg2 df))
         (,cl-op arg1 (df-f arg2)))
       (defmethod ,two-arg ((arg1 df) (arg2 number))
         (,cl-op (df-f arg1) arg2))
       (compile ',two-arg)
       (compile ',op)
       ',op)))

(defcomparison >) ;;provides ALL the comparison methods
(defcomparison =)
(defcomparison /=)
(defcomparison <)
(defcomparison <=)
(defcomparison >=)
;; that's all

;; extra + methods specific to df
(defmethod ga::two-arg-+ ((a df) (b df))
  (df (cl:+ (df-f a) (df-f b))
      (cl:+ (df-d a) (df-d b))))

(defmethod ga::two-arg-+ ((b df) (a number))
  (df (cl:+ a (df-f b)) (df-d b)))

(defmethod ga::two-arg-+ ((a number) (b df))
  (df (cl:+ a (df-f b)) (df-d b)))

;;extra - methods
(defmethod ga::two-arg-- ((a df) (b df))
  (df (cl:- (df-f a) (df-f b))
      (cl:- (df-d a) (df-d b))))

(defmethod ga::two-arg-- ((b df) (a number))
  (df (cl:- (df-f b) a) (df-d b)))

;; d(a - b) = -db for constant a.  The negation applies to the derivative
;; slot of b, not to the df structure itself.
(defmethod ga::two-arg-- ((a number) (b df))
  (df (cl:- a (df-f b))
      (cl:- (df-d b))))

;;extra * methods
(defmethod ga::two-arg-* ((a df) (b df))
  (df (cl:* (df-f a) (df-f b))
      (cl:+ (cl:* (df-d a) (df-f b))
            (cl:* (df-d b) (df-f a)))))

(defmethod ga::two-arg-* ((b df) (a number))
  (df (cl:* a (df-f b)) (cl:* a (df-d b))))

(defmethod ga::two-arg-* ((a number) (b df))
  (df (cl:* a (df-f b)) (cl:* a (df-d b))))

;; extra divide methods
(defmethod ga::two-arg-/ ((u df) (v df))
  (df (cl:/ (df-f u) (df-f v))
      (cl:/ (cl:+ (cl:* -1 (df-f u) (df-d v))
                  (cl:* (df-f v) (df-d u)))
            (cl:* (df-f v) (df-f v)))))

;; d(u/v) = -u*dv/v^2 for constant u.  u is a plain number here, so it is
;; used directly rather than through df-f.
(defmethod ga::two-arg-/ ((u number) (v df))
  (df (cl:/ u (df-f v))
      (cl:/ (cl:* -1 u (df-d v))
            (cl:* (df-f v) (df-f v)))))

(defmethod ga::two-arg-/ ((u df) (v number))
  (df (cl:/ (df-f u) v)
      (cl:/ (df-d u) v)))

;; extra expt methods
(defmethod ga::two-arg-expt ((u df) (v number))
  (df (cl:expt (df-f u) v)
      (cl:* v (cl:expt (df-f u) (cl:1- v)) (df-d u))))

(defmethod ga::two-arg-expt ((u df) (v df))
  (let* ((z (cl:expt (df-f u) (df-f v))) ;;z=u^v
         (w ;; u(x)^v(x)*(dv*log(u(x))+du*v(x)/u(x))
            ;; = z*(dv*log(u(x))+du*v(x)/u(x))
           (cl:* z (cl:+ (cl:* (cl:log (df-f u))          ;log(u)
                               (df-d v))                  ;dv
                         (cl:/ (cl:* (df-f v) (df-d u))   ;v*du/ u
                               (df-f u))))))
    (df z w)))

(defmethod ga::two-arg-expt ((u number) (v df))
  (let* ((z (cl:expt u (df-f v))) ;;z=u^v
         (w ;; z*(dv*LOG(u(x))
           (cl:* z (cl:* (cl:log u)    ;log(u)
                         (df-d v)))))
    (df z w)))

;; A rule to define rules, a new method for df, the old method for numbers.
;; S is the derivative of OP written in terms of the symbol x; in the df
;; method every x in S is replaced by (df-f a) and the result is multiplied
;; by (df-d a) (chain rule).
(defmacro r (op s)
  `(progn
     (defmethod ,op ((a df))
       (df (,(tocl op) (df-f a))
           (,(tocl '*) (df-d a) ,(subst '(df-f a) 'x s))))
     (defmethod ,op ((a number)) (,(tocl op) a))))

;; Add rules for every built-in numeric program.
;; Must insert the name in the shadow list too.
;; This is just a sampler.
(r sin (cos x)) ;; provides EVERYTHING ADIL needs about sin
(r cos (* -1 (sin x)))
(r tan (expt (cos x) -2)) ;; tan is in the shadow list, so it needs a rule
(r asin (expt (+ 1 (* -1 (expt x 2))) -1/2))
(r acos (* -1 (expt (+ 1 (* -1 (expt x 2))) -1/2)))
(r atan (expt (+ 1 (expt x 2)) -1))
(r sinh (cosh x))
(r cosh (sinh x))
(r tanh (expt (cosh x) -2))
(r atanh (expt (1+ (* -1 (expt x 2))) -1))
(r log (expt x -1))
(r exp (exp x))
(r sqrt (* 1/2 (expt x -1/2)))
(r 1- 1)
(r 1+ 1)
(r abs (signum x)) ;; d|x|/dx is the sign of x (taken as 0 at x = 0)

(defun re-intern (s p) ;; move expression to :ga package
  (cond ((or (null s) (numberp s)) s)
        ((symbolp s) (intern (symbol-name s) p))
        (t (cons (re-intern (car s) p)
                 (re-intern (cdr s) p)))))

;; Recursive factorial whose base case at 1 is the df <1, 0.422784335098d0>.
(defun fact (x)
  (if (= x 1)
      (df 1 0.422784335098d0)
      (* x (fact (1- x)))))

;; in :ga  (setf one (df 1 1))
;; in :ga  (dotimes (i 10)(print (fact (+ i one))))

;; Recursive function built only from generic arithmetic: below 1.0d-5 in
;; magnitude it returns x, otherwise it combines the value z at -x/3
;; as 4z^3 - 3z.
(defun s (x)
  (if (< (abs x) 1.0d-5)
      x
      (let ((z (s (* -1/3 x))))
        (- (* 4 (expt z 3)) (* 3 z)))))

;; newton iteration (ni fun guess)
;; usage: fun is a function of one arg.
;;        guess is an estimate of solution of fun(x)=0
;; output: a new guess.
;; (The value returned by ni is not a df structure, just a number.)
(defun ni (f z) ;one newton step
  (let* ((pt (if (df-p z) z (df z 1)))  ; make sure init point is a df
         (v (funcall f pt)))            ;compute f, f'
    (df-f (- pt (/ (df-f v) (df-d v))))))

;; Like ni, but prints the df computed at the point and returns two values:
;; the next guess and that df.
(defun ni2 (f z) ;one newton step, give more info
  (let* ((pt (if (df-p z) z (df z 1)))
         (v (funcall f pt)))            ;compute f, f' at pt
    (format t "~%v = ~s" v)
    (values (df-f (- pt (/ (df-f v) (df-d v)))) ;the next guess
            v)))

(defun run-newt1 (f guess &optional (count 18))
  ;; Solve f=0
  ;; Look only at the guesses.
  ;; though if the derivative goes to 0, we are stuck
  ;; in the newton step.
  ;; Runs newton steps in batches of 6, COUNT steps at most.  Returns the
  ;; latest guess once it is below abstol in magnitude or its relative
  ;; change from the previous guess is below reltol; otherwise prints a
  ;; failure message and returns the latest guess.
  (let ((guesses (list guess))
        ;; The tolerances are computed at read time.  cl:* is named
        ;; explicitly because * is shadowed in this package and the generic
        ;; version need not be defined yet when this form is read.
        (reltol #.(cl:* 100 double-float-epsilon))
        (abstol #.(cl:* 100 least-positive-double-float)))
    (dotimes (i 6) (push (ni f (car guesses)) guesses)) ;; make 6 iterations.
    (incf count -6)
    (cond ((< (abs (car guesses)) abstol) (car guesses))
          ((< (abs (/ (- (car guesses) (cadr guesses)) (car guesses))) reltol)
           (car guesses))
          ((<= count 0)
           (format t "~%newt1 failed to converge; guess =~s" (car guesses))
           (car guesses))
          (t (run-newt1 f (car guesses) count)))))

(defun run-newt2 (f guess &key (abstol 1.0d-8) (count 18))
  ;; Solve f=0
  ;; Looks only at the residual.
  ;; Returns the new guess as soon as |f| at the current guess is below
  ;; ABSTOL; after COUNT steps without that, prints a message and returns
  ;; nil.
  (dotimes (i count
              (format t "~%Newton iteration not convergent after ~s iterations: ~s"
                      count guess))
    (multiple-value-bind (newguess v) (ni2 f guess)
      (if (< (abs (df-f v)) abstol)
          (return newguess)
          (setf guess newguess)))))

;;try (run-newt2 'sin 3.0d0)

;; dcomp, AD compiler
;; Richard Fateman 11/2005

(defvar *difprog* nil)  ;; the text of the program
(defvar *bindings* nil) ;; bindings used for program

;; Returns a LET form whose variables are the names collected in *bindings*
;; and whose body is the emitted program, ending in (values value derivative)
;; for expression F with respect to X at the point P.
(defun dcomp (f x p) ;; the main program
  (let ((*difprog* nil)
        (*bindings* nil))
    (multiple-value-bind (val dif) (dcomp1 f x p)
      (emit `(values ,val ,dif))
      `(let ,*bindings*
         ,(format nil "~s wrt ~s" f x)
         ,@(nreverse *difprog*)))))

;; Assist in compilation of a function f and its derivative at a point x=p
;; dcomp1 changes *bindings* and *difprog* as it munches on the expression f.
;; Walk expression F, emitting code that computes F and dF/dX at the point
;; P.  Returns two values: the form (or name) holding the value and the form
;; (or name) holding the derivative.
;;   - F eq X:        value P, derivative 1.0d0
;;   - F a number:    value F, derivative 0.0d0
;;   - other atom:    value F, derivative the variable named by
;;                    make-dif-name; the first time that name is seen it is
;;                    added to *bindings* and zeroed.  Later references do
;;                    not zero it again, so a derivative already stored in
;;                    it is kept and the name is bound only once.
;;   - (op . args):   handled by the function stored under OP's dopcomp
;;                    property; it is an error if there is none.
(defun dcomp1 (f x p)
  (declare (special *v*))
  (cond ((atom f)
         (cond ((eq f x) (values p 1.0d0))
               ((numberp f) (values f 0.0d0))
               (t (let ((r (make-dif-name f x)))
                    (unless (member r *bindings*)
                      (push r *bindings*)
                      (emit `(setf ,r 0.0d0)))
                    (values f r)))))
        (t (let* ((op (first f))
                  (program (get op 'dopcomp))) ;otherwise, look up the operator's d/dx
             (if program
                 (funcall program (cdr f) x p)
                 (error "~% dcomp cannot do ~s" f)))))) ;; for now

;; Make a fresh temporary whose name starts with R, record it in *bindings*
;; and return it.  R may be a string or a symbol; it is converted with
;; STRING because gentemp requires a string prefix.
(defun gentempb (r)
  (push (gentemp (string r)) *bindings*)
  (car *bindings*))

;; Add ITEM to the program being built, unless it is a setf whose first
;; place and value are the same object.
(defun emit (item)
  (unless (and (eq (car item) 'setf)
               (eq (cadr item) (caddr item)))
    (push item *difprog*)))

;; simple unary f(x) programs.
;; We define them all with the same template.

;; (dr op s) installs the dcomp rule for unary OP whose derivative is S,
;; written in terms of the symbol x.  The expansion stores a real function
;; under OP's dopcomp property when the form is evaluated, so the rule is
;; funcallable and is still installed after loading a compiled file.
(defmacro dr (op s)
  ;;define a rule for returning v, d for dcomp
  `(setf (get ',op 'dopcomp)
         ,(dr2 op (treefloat s))))

;; DR2 and TREEFLOAT run while DR is being expanded, so they must also be
;; available to the file compiler.
(eval-when (:compile-toplevel :load-toplevel :execute)
  (defun dr2 (op s) ;; a rule to make rules for dcomp
    `(lambda (l x p)
       (cond ((cdr l)
              (error "~&too many arguments to ~s: ~s" ',op l))
             (t (let ((v (gentempb 'f))
                      (d (gentempb 't)))
                  (multiple-value-bind (theval thedif) (dcomp1 (car l) x p)
                    (emit (list 'setf v (list ',op theval)))
                    (emit (list 'setf d
                                (*chk thedif (subst theval 'x ',s)))))
                  (values v d))))))

  (defun treefloat (x) ;; convert all numbers to double-floats
    (cond ((null x) nil)
          ((numberp x) (coerce x 'double-float))
          ((atom x) x)
          (t (mapcar #'treefloat x)))))

;; Here's how the rule definition program works.
;; Set up rules.
;; Unary rules: each derivative is written in terms of x.
(dr tan  (power (cos x) -2))
(dr sin  (cos x))
(dr cos  (* -1 (sin x)))
(dr asin (power (+ 1 (* -1 (power x 2))) -1/2))
(dr acos (* -1 (power (+ 1 (* -1 (power x 2))) -1/2)))
(dr atan (power (+ 1 (power x 2)) -1))
(dr sinh (cosh x))
(dr cosh (sinh x))
(dr log  (power x -1))
(dr exp  (exp x))
(dr sqrt (* 1/2 (power x -1/2)))
;; etc etc

;; Next, functions of several arguments depending on x
;; These include + - * / expt, power
(dolist (entry '((* . *rule)
                 (+ . +rule)
                 (/ . /rule)
                 (- . -rule)
                 (power . power-rule)
                 (expt . expt-rule)))
  (setf (get (car entry) 'dopcomp) (cdr entry)))

;; Sum rule.  The first operand initialises a value temporary and a
;; derivative temporary; every further operand is added into both.
(defun +rule (l x p)
  (let ((sum (gentempb 'f))
        (dsum (gentempb t)))
    (multiple-value-bind (v d) (dcomp1 (first l) x p)
      (emit `(setf ,dsum ,d ,sum ,v)))
    (loop for term in (rest l)
          do (multiple-value-bind (v d) (dcomp1 term x p)
               (emit `(setf ,dsum ,(+chk d dsum)))
               (emit `(setf ,sum ,(+chk v sum)))))
    (values sum dsum)))

;; Difference rule.  With several operands the later ones are subtracted
;; from the first; with a single operand both value and derivative are
;; negated.
(defun -rule (l x p)
  (let ((acc (gentempb 'f))
        (dacc (gentempb t)))
    (if (rest l)
        (progn
          (multiple-value-bind (v d) (dcomp1 (first l) x p)
            (emit `(setf ,dacc ,d ,acc ,v)))
          (loop for term in (rest l)
                do (multiple-value-bind (v d) (dcomp1 term x p)
                     (emit `(setf ,dacc (- ,dacc ,d)))
                     (emit `(setf ,acc (- ,acc ,v))))))
        ;; just one arg
        (multiple-value-bind (v d) (dcomp1 (first l) x p)
          (emit `(setf ,dacc (- ,d) ,acc (- ,v)))))
    (values acc dacc)))

;; (- x y z) means (+ x (* -1 y) (* -1 z)).
;; (- x) means (* -1 x)

;; Product rule, applied one factor at a time.
(defun *rule (l x p)
  (let ((valname (gentempb 'f))
        (difname (gentempb t)))
    (multiple-value-bind (v d) (dcomp1 (car l) x p)
      (emit `(setf ,difname ,d ,valname ,v)))
    (dolist (i (cdr l))
      (multiple-value-bind (v d) (dcomp1 i x p)
        ;; the derivative must be updated before the value it depends on
        (emit `(setf ,difname ,(+chk (*chk v difname) (*chk d valname))))
        (emit `(setf ,valname ,(*chk v valname)))))
    (values valname difname)))

;; Quotient rule.  (/ a b c ...) divides a by each later operand in turn,
;; applying the quotient rule at every step; (/ a) is the reciprocal of a.
;; Signals an error when there are no operands.
(defun /rule (l x p)
  (when (null l)
    (error "~&too few arguments to /: ~s" l))
  (let ((vname (gentempb 'f))
        (dname (gentempb t)))
    (multiple-value-bind (v d) (dcomp1 (car l) x p)
      (if (cdr l)
          (emit `(setf ,dname ,d ,vname ,v))
          ;; one operand: d(1/v) = -dv/v^2
          (emit `(setf ,dname (/ (- ,d) (* ,v ,v))
                       ,vname (/ 1.0d0 ,v)))))
    (dolist (i (cdr l))
      (multiple-value-bind (v d) (dcomp1 i x p)
        ;; setf assigns left to right, so dname is computed from the old
        ;; vname before vname is replaced by the quotient
        (emit `(setf ,dname (/ (- (* ,v ,dname) (* ,vname ,d)) (* ,v ,v))
                     ,vname (/ ,vname ,v)))))
    (values vname dname)))

(defun power-rule (l x p)
  (cond ((cddr l)
         (error "~&too many arguments to power: ~s" l))
        ;; now we assume that everything is a-ok
        ;; i.e. power has only two arguments, the second argument
        ;; which is the power to be raised to is independent of x
        ;; so we can use for sqrt, cube-root etc.
        (t (let ((vname (gentempb 'f))
                 (dname (gentempb t)))
             (multiple-value-bind (v d) (dcomp1 (car l) x p)
               ;;; We could use +chk and *chk in the following emissions
               ;;; but they are really inessential.
               (emit `(setf ,vname (power ,v ,(cadr l))))
               (emit `(setf ,dname (* ,d (power ,v (1- ,(cadr l))) ,(cadr l)))))
             (values vname dname)))))

(defun expt-rule (l x p)
  "this is the general power rule where the exponent could be an arbitrary function of x"
  (cond ((cddr l)
         (error "~&too many arguments to expt : ~s" l))
        ((numberp (cadr l)) (power-rule l x p))
        ;; we had z = (expt f g)
        ;; z' = z*(g*log f)' = z*(g*f'/f + g'*log f)
        (t (let ((vname (gentempb 'f))
                 (dname (gentempb t)))
             ;; first the exponent g: vname = g, dname = g'
             (multiple-value-bind (v d) (dcomp1 (cadr l) x p)
               (emit `(setf ,vname ,v ,dname ,d)))
             ;; then the base f: v = f, d = f'
             (multiple-value-bind (v d) (dcomp1 (car l) x p)
               (emit `(setf ,dname ,(+chk `(/ ,(*chk vname d) ,v)
                                          (*chk dname `(log ,v)))))
               (emit `(setf ,vname (expt ,v ,vname)))
               (emit `(setf ,dname ,(*chk vname dname))))
             (values vname dname)))))

;; the next two programs, "optimize":
;; generated code so as to not add 0 or mult by 0 or 1
(defun +chk (a b)
  (cond ((and (numberp a) (= a 0)) b)
        ((and (numberp b) (= b 0)) a)
        (t `(+ ,a ,b))))

(defun *chk (a b)
  (cond ((and (numberp a) (= a 1)) b)
        ((and (numberp b) (= b 1)) a)
        ((and (numberp a) (= a 0)) a)
        ((and (numberp b) (= b 0)) b)
        (t `(* ,a ,b))))

;;;; this is the main program for compiling
(defun dc (f x &optional (otherargs nil))
  ;; produce a program, p(v) ready to go into the compiler to
  ;; compute f(v), f'(v), returning result as a structure, a df
  (let ((*difprog* nil)
        (*bindings* nil)
        (*v* (gentemp "g")))
    (declare (special *v*))
    (multiple-value-bind (val dif) (dcomp1 f x *v*)
      (emit `(df ,val ,dif))
      `(lambda (,*v* ,@otherargs)
         ,(format nil "~s wrt ~s" f x)
         ;;; comment out these declares if you prefer v's type to be unknown
         (declare (double-float ,*v*))
         (declare (optimize (speed 3) (debug 0) (safety 0)))
         ;; (assert (typep ,*v* 'double-float))
         (let ,(mapcar #'(lambda (r) (list r 0d0)) *bindings*)
           (declare (double-float ,@*bindings*))
           ,@(nreverse *difprog*))))))

;; A plausible way to use dc is as follows:
(defmacro defdiff (name arglist body) ;; put the pieces together
  (progn
    (setf (get name 'defdiff) name)
    (let ((r (dc body (car arglist) (cdr arglist)))) ;; returns a df.
      `(defun ,name ,(cadr r) ,@(cddr r)))))

;; usage (defdiff f (x) (* x (sin x)))
;; Then you can call (f 3.0d0)

;; a few more pieces to allow inside defdiff: if, progn, setf.
;; we could try for a few more.  Oh, > < = etc. work "automagically."

;; (if test then else): each branch is compiled by dc into its own lambda.
;; A dc lambda returns a single df structure, so the emitted code binds that
;; result and unpacks its value and derivative slots into the two
;; temporaries.
(defun if-rule (l x p)
  (let ((valname (gentempb 'f))
        (difname (gentempb t))
        (result (gensym "IF")))
    (emit `(let ((,result
                   ;; assume only variables, not derivatives in condition
                   (if ,(subst p x (car l))
                       (funcall (function ,(dc (cadr l) x)) ,p)
                       (funcall (function ,(dc (caddr l) x)) ,p))))
             (setf ,valname (df-f ,result)
                   ,difname (df-d ,result))))
    (values valname difname)))

;; (progn form ...): every form is compiled for its emitted code; the value
;; and derivative are those of the last form.
(defun progn-rule (l x p)
  (mapc #'(lambda (k) (dcomp1 k x p)) (butlast l))
  (dcomp1 (car (last l)) x p))

;; (setf var expr): emits the assignment of the value to VAR and of the
;; derivative to the variable named by make-dif-name.  That name is added to
;; *bindings* only if it is not already there.
(defun setf-rule (l x p)
  (if (not (symbolp (car l)))
      (error "dcomp cannot do setf ~s" (car l)))
  (multiple-value-bind (v d) (dcomp1 (cadr l) x p)
    (emit `(setf ,(car l) ,v))
    (let ((dname (make-dif-name (car l) x)))
      (pushnew dname *bindings*)
      (emit `(setf ,dname ,d)))
    (values v d)))

(defun make-dif-name (s x) ;s is a symbol. make a new one like s_dif_x
  (intern (concatenate 'string (symbol-name s) "_DIF_" (symbol-name x))))

(setf (get 'if 'dopcomp) 'if-rule)
(setf (get 'progn 'dopcomp) 'progn-rule)
(setf (get 'setf 'dopcomp) 'setf-rule)

;; derivative by complex evaluation..
;;returns function and derivative of f at point x, sortof
(defun fd (f x)
  (let ((c (funcall f (complex x 1.0d-8))))
    (values (realpart c) (* (imagpart c) 1.0d8))))

;; One newton step from x using the complex-step estimate: f is evaluated at
;; x + i*h, the real part is taken as f(x) and imagpart/h as f'(x).
(defun niX (f x &optional (h 1.0d-8))
  (let ((c (funcall f (complex x h))))
    (- x (* h (/ (realpart c) (imagpart c))))))

(in-package :ga)