(* Computation of branch cuts
 * This file implements some algorithms from my master's thesis, "Branch
 * Cuts in Computer Algebra".
 * Adam Dingle *)

Remove["Branch`*"]
Remove["Branch`Private`*"]

BeginPackage["Branch`", {"Plot`", "Useful`", "Algebra`ReIm`"}]

(* Usage messages for the symbols exported by the Branch` package. *)
SimpComp::usage =
"SimpComp[f, g] is the functional composition of functions f and g,
algebraically simplified (evaluated)."

DecomposeFn::usage =
"DecomposeFn[f] decomposes f, a function of one argument, into
a composition of functions."

Cut::usage =
"A branch cut is represented by {Cut[x0, x1, f], h} where x0 and x1 are
real numbers, Infinity or -Infinity and f is a function from real to complex.
The cut is the range of f[x] as the real variable x ranges from x0 to x1.
h[x] is an alternate branch function for the cut."

SimpCut::usage =
"SimpCut[c] returns a list of cuts which represent a simplification of the
branch cut c."

Cuts::usage =
"Cuts[f] generates a list of the branch cuts of f."

RawCuts::usage =
"RawCuts[f] generates a list of the branch cuts of f, before simplification."

Branch::usage =
"Branch[f, r] generates a plot of the branch cuts of the function f
in the complex plane, where r designates the portion of the complex
plane to be viewed (both real and imaginary axes will range from -r to r).
Branch[f] uses r = 10."

AnonFun::usage =
"AnonFun turns a function of the form Function[v, ...] into a function which
uses # to represent its argument."

Begin["`Private`"]

(* Apply1[fn, arg] applies fn to arg.  The symbol carries the OneIdentity
 * attribute and has Identity as its pattern default value. *)
Attributes[Apply1] = {OneIdentity}
Default[Apply1] = Identity
Apply1[fn_, arg_] := fn[arg]

(* SimpComp[outer, inner]: evaluate outer[inner[#]] once, then wrap the
 * evaluated result as the body of a pure function of #. *)
SimpComp[outer_, inner_] :=
  Let[{composed = outer[inner[#]]},
    Function[composed]]

(* A held slot, used by DecomposeFn as the replacement for a factored-out
 * subexpression. *)
Slot1 = Hold[#]

(* DTemplate[outer, Hold[inner]] builds the function whose body is
 * outer applied to the (previously held) expression inner. *)
DTemplate[outer_, Hold[inner_]] := Fn[outer[inner]]

(* DecomposeFn[f]: try to split the function f into a composition.
 *
 * d holds, for every occurrence of # in f, the position one level above
 * that occurrence; if # is found at position {1}, f itself is thrown to
 * the enclosing Catch and returned unchanged.  e holds the subexpressions
 * found at the positions in d (fetched with HeldPosition, a helper not
 * defined in this file -- presumably it returns them wrapped in Hold, as
 * DTemplate expects; confirm).
 *
 * When the four conditions tested in the If hold, the common subexpression
 * First[e] is factored out: each position in d is rewritten via
 * SMapAtList (also defined elsewhere) using Slot1, the result is
 * decomposed recursively, and DTemplate recombines it with First[e].
 * Otherwise f is returned as is. *)
DecomposeFn[f_] :=
   Catch[
	(* Find all occurrences of # in f, and move up one level from each
         * occurrence. *)
	Let[{d = Map[Fn[p, If[p == {1}, Throw[f], Drop[p, -1]]],
                     Position[f, #]]},
	(* Find the expressions that are one level up from occurrences of #. *)
	  Let[{e = (Fn[p, HeldPosition[f, p]] /@ d)},
      If[d != {} (* # occurs at least once *)
	 && d != {{1}} (* not (one occurrence at top level) *)
         && (Unequal @@ d) (* all occurrences are distinct *)
         && (SameQ @@ e), (* all occurrences match *)
          DTemplate[DecomposeFn[SMapAtList[Fn[Slot1], f, d]], First[e]],
	  f]]]]

SolveV[eqn_, var_] :=
 (* Same as Solve, except that each solution is reported as the value of
  * var instead of as a replacement rule. *)
  MapR[Solve[eqn, var],
       Fn[rule, var /. rule]]

(* Discontinuities[e] finds all points at which the given expression of #,
 * not involving functions with branch cuts, is discontinuous or undefined.
 * Arg[x] contributes the solutions of x == 0; any other compound
 * expression contributes the union over its arguments; everything else
 * contributes nothing. *)
Discontinuities[Arg[inner_]] := SolveV[inner == 0, #]
Discontinuities[head_[args___]] := Apply[Union, Map[Discontinuities, {args}]]
Discontinuities[_] := {}

(* The discontinuities of expr that satisfy IsReal. *)
RealDiscontinuities[expr_] :=
 Select[Discontinuities[expr], IsReal]

PointIn[r1_, r2_] :=
 (* Pick a sample point in the real interval (r1, r2), where r1 < r2:
  * 0 for the whole line, one unit inside a half-infinite interval,
  * and the midpoint of a finite one. *)
 If[r1 === -Infinity,
    If[r2 === Infinity, 0, r2 - 1],
    If[r2 === Infinity, r1 + 1, (r1 + r2) / 2]]

(* AndInterval[bound, intervals]: clip every {low, high} interval in the
 * list to bound, keeping only results whose low end does not exceed the
 * high end. *)
AndInterval[bound_, intervals_] :=
  Select[
    MapR[intervals, Fn[iv,
      {Max[iv[[1]], bound[[1]]], Min[iv[[2]], bound[[2]]]}]],
    Function[clipped, clipped[[1]] <= clipped[[2]]]]

(* AndIntervals[as, bs]: clip the list bs against each interval of as and
 * concatenate the results. *)
AndIntervals[as_, bs_] :=
  Flatten1[MapR[as, Fn[bound, AndInterval[bound, bs]]]]

(* Collapse adjacent intervals in a sorted interval list: two neighbours
 * are merged when the upper end of the first equals the lower end of the
 * second. *)
SimpIntervals[{}] := {}
SimpIntervals[single:{_}] := single
SimpIntervals[{first:{lo1_, hi1_}, second:{lo2_, hi2_}, rest___}] :=
 If[hi1 == lo2,
    SimpIntervals[{{lo1, hi2}, rest}],
    Cons[first, SimpIntervals[{second, rest}]]]

(* FindIntervals1[f]: intervals of the real line on which the condition f
 * (an inequality in #) holds.
 *
 * g is f with every inequality operator replaced by Equal.  The candidate
 * breakpoints are -Infinity, Infinity, the real discontinuities of the
 * parts of f, and the real solutions of g.  The breakpoints are sorted
 * with NLess and paired up; a pair is kept when f evaluates to True at a
 * sample point (PointIn) of the pair, and adjacent kept intervals are
 * merged by SimpIntervals.
 *
 * RealSolve returns one list of values per solution (here {{x1}, {x2},
 * ...} because a single variable is solved for), so its result is
 * flattened before being spliced into the list of scalar breakpoints. *)
FindIntervals1[f_] :=
  Let[{g = f /. Less :> Equal /. Greater :> Equal
	     /. LessEqual :> Equal /. GreaterEqual :> Equal},
   Let[{p = Pairs[Sort[Join[{-Infinity, Infinity},
			    Flatten1[MapR[List @@ f, RealDiscontinuities]],
			    Flatten[RealSolve[g, {#}]]], NLess]]},
    SimpIntervals[Select[p, Fn[p2, N[f /. # -> PointIn[p2[[1]], p2[[2]]]]]]]
]]

(* FindIntervals[cond] returns a list of real intervals, each of the form
 * {low, high}, on which the given inequality of the variable # holds.
 * A conjunction is handled by intersecting, one conjunct at a time, the
 * interval lists of its members, starting from the whole real line. *)
FindIntervals[conj_And] :=
  Fold[AndIntervals,
       {{-Infinity, Infinity}},
       MapR[conj, FindIntervals]]
FindIntervals[cond_] := FindIntervals1[cond]

(* True exactly when x is identical to Infinity or to -Infinity. *)
Infinite[x_] := SameQ[x, Infinity] || SameQ[x, -Infinity]

RealRange[e_, i_] :=
 (* Find the range of the given real expression of # on the given real
  * interval i = {low, high}.  Returns {minimum, maximum}.
  *
  * Candidate extrema are the two interval endpoints plus the real
  * critical points (D[e, #] == 0) that lie inside the interval.  Critical
  * points outside the interval must not contribute: for #^2 on {1, 2}
  * the critical point 0 would otherwise give a minimum of 0 instead of 1.
  * RealSolve returns one value list per solution, hence the Flatten.
  * Min/Max of i are used in the filter so that a reversed interval is
  * treated the same as a sorted one.  At an infinite endpoint the value
  * is taken as a Limit. *)
 Let[{critical = Select[Flatten[RealSolve[D[e, #] == 0, {#}]],
			Fn[x, N[Min[i] <= x <= Max[i]]]]},
  Let[{extrema = Join[i, critical]},
   Let[{values = MapR[extrema, Fn[d,
		  If[Infinite[d], Limit[e, # -> d],
				  e /. # -> d]]]},
    {Min[values], Max[values]}
]]]

(* These rules are stolen from Algebra/ReIm.m.  I don't want to load in
 * their whole package since I've heard it makes the system run tremendously
 * slowly. *)
(*
protected = Unprotect[Re, Im, Abs, Conjugate, Arg]
Re[x_Plus] := Re /@ x
Im[x_Plus] := Im /@ x

Re[x_ y_Plus] := Re[Expand[x y]]
Im[x_ y_Plus] := Im[Expand[x y]]

Re[x_ y_] := Re[x] Re[y] - Im[x] Im[y]
Im[x_ y_] := Re[x] Im[y] + Im[x] Re[y]
Protect[Evaluate[protected]]
*)

RealSolve[e_, v_] :=
 (* Solve the given equation(s), assuming that the variables being
  * solved for are real.
  * Limited, but handles some situations which Solve[] alone cannot.
  *
  * e is a single equation or a list of equations; v is the list of
  * variables.  The result is a list with one entry per real solution,
  * each entry being the list of values of the variables in v.
  *
  * Steps:
  *  - e is first normalised to a flat list of equations, because the
  *    Join below needs a List (callers such as RealRange pass a single
  *    equation);
  *  - Arg[x] == y is rewritten as Im[x] == Re[x] Tan[y], and Im/Re of a
  *    variable being solved for are replaced by 0 and the variable;
  *  - a variable all of whose occurrences are directly inside Sin[var] or
  *    Cos[var] is treated as a "trig variable": Sin[var] and Cos[var]
  *    are solved for as unknowns, subject to Sin^2 + Cos^2 == 1, and
  *    var is recovered with the two-argument ArcTan;
  *  - solutions in which some value fails IsReal are discarded. *)
 Let[{e1 = Flatten[{e}] /. Arg[x_] == y_ :> Im[x] == Re[x] Tan[y]
(*               /. i:Im[__] :> Distribute[i]
                 /. i:Re[__] :> Distribute[i]  *)
               /. Im[s_ /; MemberQ[v, s]] :> 0
               /. Re[s_ /; MemberQ[v, s]] :> s},
 Let[{trigvars = Select[v, Fn[var,
          Let[{trigpos = Position[e1, Sin[var] | Cos[var]]},
          Let[{trigvarpos = MapR[trigpos, Fn[p, Append[p, 1]]]},
	   Let[{pos = Position[e1, var]},
		Sort[trigvarpos] == Sort[pos]]]]]]},
 Let[{e2 = Join[e1, MapR[trigvars, Fn[var,
				Sin[var]^2 + Cos[var]^2 == 1]]]},
 Let[{v1 = Join[Complement[v, trigvars],
		MapL[trigvars, Fn[var, {Sin[var], Cos[var]}]]]},
  Let[{sols = Solve[e2, v1]},
   Let[{sols1 = MapR[sols, Fn[s,
	Let[{trig = MapR[trigvars, Fn[var,
			var -> ArcTan[Cos[var] /. s, Sin[var] /. s]]]},
	 v /. Join[s, trig]]]]},
   Select[sols1, Fn[s,
    And @@ MapR[s, IsReal]]
]]]]]]]

(* LinearTrans takes an expression, not a function, as its first argument,
 * so that the expression will be evaluated (so that, for example, a function
 * which evaluates to a sum or product will be expanded and matched).  This
 * is necessary for Sqrt to be matched as well. *)

(* AddOk[c, v]: v is constant, or v has (numerically) the same argument as
 * c or as -c. *)
AddOk[c_, v_] :=
  Or[IsConstant[v], N[Arg[v] == Arg[c]], N[Arg[v] == Arg[-c]]]

(* LinearTrans[expr, t0, t1] is matched structurally on an expression of #:
 *   # + constant : both t0 and t1 must pass AddOk for that constant;
 *   # * constant : always True;
 *   # ^ real     : True if an endpoint is identically 0 or both endpoints
 *                  have the same argument;
 *   anything else: False. *)
LinearTrans[# + shift_?IsConstant, v_, w_] :=
  And[AddOk[shift, v], AddOk[shift, w]]
LinearTrans[# scale_?IsConstant, _, _] := True
  (* Infinity == 0 does not reduce to false; hence === is used below *)
LinearTrans[#^pow_?IsReal, t0_, t1_] :=
  Or[t0 === 0, t1 === 0, N[Arg[t0] == Arg[t1]]]
LinearTrans[___] := False

(*
Simp1[a_, b_, t0_, t1_, c_] :=
  If[LinearTrans[b, t0, t1], SimpCut[Cut[b&[t0], b&[t1], a]], c]
*)

RealExp[e_] :=
 (* Returns true if e is a real expression of #: it must contain no
  * Complex, no Log, and no power with a non-integer exponent. *)
 And[
   FreeQ[e, Complex],
   FreeQ[e, Log],
   FreeQ[e, Power[_, n_ /; !IntegerQ[n]]]]

Primitive[e_] :=
 (* Returns true if e is not a function invocation, though it may look like
  * one: e is either the bare slot # or a number. *)
 SameQ[e, #] || NumberQ[e]

(* Perform one branch cut simplification step.
 * Simp1[cut] returns a list of cuts that replaces cut, or Null when no
 * rule applies. *)

(* Put the parameter interval in increasing order. *)
Simp1[Cut[r0_, r1_, f_]] := {Cut[r1, r0, f]} /; r0 > r1

(* A conditional parametrisation is restricted to the sub-intervals of
 * {r0, r1} on which its condition holds. *)
Simp1[Cut[r0_, r1_, If[c_, f_, Null]&]] :=
  MapR[AndInterval[{r0, r1}, FindIntervals[c]],
       Fn[i, Cut[i[[1]], i[[2]], f&]]]

(* A real-valued parametrisation is replaced by its range on the real
 * axis. *)
Simp1[Cut[r0_, r1_, e_?RealExp&]] :=
  Let[{r = RealRange[e, {r0, r1}]}, {Cut[r[[1]], r[[2]], Identity]}]

(* lines and powers *)
Simp1[Cut[r0_, r1_, (e_ #)^n_&]] := {Cut[r0, r1, e^n #^n&]}
Simp1[Cut[r0_, r1_, e_ #&]] /; !(N[0 <= Arg[e] < Pi]) :=
  {Cut[-r1, -r0, Evaluate[-e #]&]}

(* Rational powers of the parameter.  The two one-signed cases are tested
 * before the mixed-sign case.  With the mixed-sign rule first, an interval
 * with an endpoint at 0 and a non-positive other endpoint (for example
 * Cut[-Infinity, 0, ...]) satisfies Arg[r0] != Arg[r1] because Arg[0] is 0,
 * is split into Cut[0, r0, ...] and Cut[0, r1, ...], and one of those
 * pieces matches the same rule again, so the simplification never
 * terminates. *)
Simp1[Cut[r0_, r1_, #^p_Rational &]] /; r0 >= 0 && r1 >= 0 :=
  {Cut[r0^p, r1^p, Identity]}
Simp1[Cut[r0_, r1_, #^p_Rational &]] /; r0 <= 0 && r1 <= 0 :=
  {Cut[-r0, -r1, (-1)^p #^p &]}
Simp1[Cut[r0_, r1_, #^p_Rational &]] /; Arg[r0] != Arg[r1] :=
  {Cut[0, r0, #^p &], Cut[0, r1, #^p &]}

Simp1[c:___] := Null (* no simplification *)

(* Simp1b[cut]: apply one Simp1 step.  If no rule applied (Null) the cut
 * is returned as a one-element list; otherwise every resulting cut is
 * simplified further with SimpCut and the results are concatenated. *)
Simp1b[cut_] :=
 Let[{step = Simp1[cut]},
  If[step === Null,
     {cut},
     Flatten1[Map[SimpCut, step]]]]

(* SimpCut[c]: simplify the cut c = Cut[t0, t1, f] to a list of cuts.
 *
 * If the body of f (DeFun[f]; DeFun is defined elsewhere) is Primitive,
 * c is simplified directly with Simp1b.  Otherwise f is decomposed with
 * DecomposeFn; when the result has the form a[b]& with b different from
 * the bare slot, one Simp1 step is tried on the inner cut
 * Cut[t0, t1, b&].  If that step yields nothing, {c} is returned;
 * otherwise a is re-applied to the parametrisation of every resulting cut
 * and each such cut is simplified recursively.  Any other decomposition
 * falls back to Simp1b[c]. *)
SimpCut[c:Cut[t0_, t1_, f_]] :=
  (* given a cut, return a list of cuts (possibly empty) to which it
   * simplifies. *)
  (* allow simplification inside expressions *)
If[Primitive[DeFun[f]], Simp1b[c],
   Replace[DecomposeFn[f],
	      {a_[b_?(Fn[x, x =!= #])]& :>
		Let[{c1 = Simp1[Cut[t0, t1, b&]]},
		  If[c1 === Null, {c},
		    Flatten1[MapR[c1, Fn[d,
		      SimpCut[MapAtR[d, 3, Fn[g, Evaluate[a[DeFun[g]]]&]]]]
			]]]],
	       _ :> Simp1b[c]}]]

SimpCutFun[{cut_, altfun_}] :=
 (* Simplify the cut of a {cut, altfun} pair; every cut produced by the
  * simplification is paired with the same alternate function. *)
 MapR[SimpCut[cut],
      Fn[piece, {piece, altfun}]]

(* NewCut[a, d, b, f, h]: the four-argument cut Cut[a, b, f, h] when
 * a - d and d - b have numerically equal arguments, and {} otherwise. *)
NewCut[a_, d_, b_, f_, h_] :=
 If[N[Arg[a - d] == Arg[d - b]],
    Cut[a, b, f, h],
    {}]

(* Join2[cut1, cut2]: try to join two four-argument cuts Cut[lo, hi, f, h]
 * that have the same f and the same h.
 * p is the list of the four parameter endpoints and u its Union.  Only
 * when exactly three distinct endpoints remain is a join attempted: d is
 * the endpoint occurring more than once in p, and the other two are
 * handed, together with d, to NewCut.  In every other case, including
 * arguments that do not match the first pattern, the result is {}.
 * NOTE(review): this uses the four-argument Cut form, whereas the rest of
 * the file carries the alternate function beside a three-argument Cut. *)
Join2[Cut[t0_, t1_, f_, h_], Cut[t2_, t3_, f_, h_]] :=
	Let[{p = {t0, t1, t2, t3}},
	  Let[{u = Union[p]},
	    If[Length[u] == 3,
	      Let[{d = First[Select[u, Count[p, #] > 1 &, 1]]},
               Replace[Complement[u, {d}], {{a_, b_} :>
		NewCut[a, d, b, f, h]}]],
	      {}]]]
Join2[_, _] := {}

(* JoinCuts[cutList]: order the cuts by their third part using canonical
 * ordering, then combine them with Join2 via Squish (a helper defined
 * outside this file). *)
JoinCuts[cutList_] :=
  Squish[Join2,
         Sort[cutList, Function[{x, y}, OrderedQ[{x[[3]], y[[3]]}]]]]

Pairs[l_] :=
 (* Given: list of elements
  * Return: list of adjacent pairs of elements in the list, i.e.
  * {{l1, l2}, {l2, l3}, ...}; empty when l has fewer than two elements. *)
 Table[{l[[k]], l[[k + 1]]}, {k, Length[l] - 1}]

(* Overlap[cuts]: split overlapping cuts into common pieces.
 *
 * cuts is a list of lists of {cut, altfun} pairs, one inner list per
 * source.  Each pair is first tagged with the index n of the inner list
 * it came from.  The tagged pairs are then grouped (CollectSame, defined
 * elsewhere) by the parametrising function of the cut.  Within a group
 * the distinct parameter endpoints are sorted and paired into consecutive
 * sub-intervals; for each sub-interval a new Cut over that sub-interval
 * is emitted together with the tagged pairs whose own interval contains
 * it.  Sub-intervals covered by no pair are dropped. *)
Overlap[cuts_ (* (cut, altfun) list list *) ] :=
(* returns: (cut, (source cut, altfun in source, source fun #) list) list *)

(* annotate each cut with the # it is from *)
 Let[{tag = MapR[Range[Length[cuts]], Fn[n,
             MapR[cuts[[n]], Fn[c, {c[[1]], c[[2]], n}]]]]},
  Let[{fcuts = Flatten1[tag]},  (* (cut, altfun, source #) list *)
   Let[{c = CollectSame[fcuts, #[[1]][[3]]&]},
    Let[{m = MapR[c, Fn[l,
              Let[{endpoints = Flatten1[MapR[l,
					 Fn[n,{n[[1]][[1]], n[[1]][[2]]}]]]},
               Let[{s = Sort[Union[endpoints], Less]},
		Let[{p = Pairs[s]},
                 Let[{r = MapR[p, Fn[p1,
			   {Cut[p1[[1]], p1[[2]], l[[1]][[1]][[3]]],
			    Select[l, #[[1]][[1]] <= p1[[1]] &&
				       #[[1]][[2]] >= p1[[2]] &]}]]},
                  Select[r, #[[2]] =!= {}&]]]]]]]},
     Flatten1[m]
]]]]

Trim[f_, cutfuns_] :=
(* receives: the function f and a list of {cut, altfun} pairs
 * returns: the pairs whose altfun is not identical to f *)
 Select[cutfuns,
        Fn[pair, UnsameQ[pair[[2]], f]]]
		  
(* Here is the code that calculates parametric branch cuts.
 * We use the following two rules.
 * If Cut[g] = (a0, a1, m) with altfun b and f has no branch cuts,
 * then Cut[f[g]] = (a0, a1, m) with altfun (f o b).
 * If Cut[f] = (a0, a1, m) with altfun b and g has no branch cuts,
 * then Cut[f[g]] = (a0, a1, Inv[g] o m) with altfun (b o g). *)

(* RawCuts[sourcefun]: branch cuts of a function whose body has the form
 * f[g, ...], returned as a list of {cut, altfun} pairs.
 *
 * Outline of the code below:
 *  - subFun turns every argument of f into a function of #, and subCuts
 *    holds the cuts of each of them (computed with Cuts);
 *  - the cuts of f itself (Cuts[f]) are pulled back through each inverse
 *    of the first argument (Inverses is defined elsewhere) by composing
 *    the inverse with the cut's parametrisation, then simplified with
 *    SimpCutFun;
 *  - Overlap splits the cuts from all sources (f first, then the
 *    arguments) into common pieces;
 *  - altCut builds the alternate function of a piece: for every source
 *    it takes that source's altfun if the piece derives from it, and the
 *    original function otherwise, then applies the first to the bodies of
 *    the rest;
 *  - Trim drops pieces whose alternate function is identical to
 *    sourcefun. *)
RawCuts[sourcefun:Fn[f_[g___]]] :=
(* receives: anonymous function *)
(* returns: (cut, altfun) list *)

(* First, compute the branch cuts of the subexpressions *)
  Let[{subFun = MapR[{g}, Fn[g1, g1&]]},
   Let[{subCuts = MapR[subFun, Cuts]},  (* (cut, altfun) list list *)

(* Now compute the cuts which are mapped onto f's branch cuts.
 * We are assuming that primitive functions have branch cuts only in the
 * first argument. *)
    Let[{mapcut = Fn[{cutfun, inv},
	 	      {MapAtR[cutfun[[1]], 3, SimpComp[inv, #]&],
		       cutfun[[2]]}]},
     Let[{ginv = Inverses[subFun[[1]]]},
      Let[{fcuts1 = Flatten1[MapR[Cuts[f], Fn[cutfun,  (* (cut, altfun) list *)
                             MapR[ginv, Fn[inv,
                              mapcut[cutfun, inv]]]]]]},
       Let[{fcuts = Flatten1[MapR[fcuts1, SimpCutFun]]},

(* Join cuts which overlap *)
       Let[{ovCuts = Overlap[Prepend[subCuts, fcuts]]},
        Let[{altCut = Fn[sources,
             (* given sources of cut: (source cut, altfun, # of source) list *)
             (* return altfun for the cut *)
                    Let[{h = MapR[Range[Length[subFun] + 1], Fn[i,
                      Let[{s = Select1[sources, #[[3]] == i&]},
			If[s === Null,  (* doesn't derive from this source *)
                            If[i == 1, f, subFun[[i - 1]]],
                            s[[2]]]]]]},
		     Fn[Evaluate[Apply[h[[1]], MapR[Rest[h], DeFun]]]]]]},
         Let[{altCuts = MapR[ovCuts, Fn[c,
		         MapAtR[c, 2, altCut]]]},
(* Eliminate removable branch cuts *)
          Trim[sourcefun, altCuts]
]]]]]]]]]

(* BranchCuts[fn]: the cuts of fn alone, i.e. the first component of every
 * {cut, altfun} pair returned by Cuts. *)
BranchCuts[fn_] :=
  MapR[Cuts[fn], First]

Cuts[fn_] :=
 (* Given: function
  * Returns: (cut, altfun) list
  * The function is first rewritten with AnonFun and passed through
  * EvalFun, which converts it to canonical form so that function
  * comparisons will work, and is then dispatched to Cuts1. *)
 Cuts1[
   EvalFun[
     AnonFun[fn]]]

(* Cuts1[f]: (cut, altfun) list for a function already in canonical form.
 *  - a function whose body is Primitive has no cuts;
 *  - Log has a cut parametrised by # over (-Infinity, 0), with alternate
 *    branch Log[#] - 2 Pi I;
 *  - Power has a cut over (-Infinity, 0) in its first argument, with
 *    alternate branch #1^#2 Exp[-2 Pi I #2];
 *  - a function whose body is a compound expression f[g, ...] is handled
 *    by RawCuts;
 *  - everything else is assumed to have no branch cuts. *)
Cuts1[e_?Primitive&] := {}
Cuts1[Log] := {{Cut[-Infinity, 0, #&], Log[#] - 2 Pi I &}}
(* there is some redundancy here - the power case should really cover the
 * square root, but sqrt is not always expanded to a power *)
(* actually, should no longer be necessary - will always be expanded *)
(* Cuts1[Sqrt] := {{Cut[-Infinity, 0, #&], -Sqrt[#]&}} *)
(* assumption: two-variable primitives only have branch cut in first argument*)
Cuts1[Power] := {{Cut[-Infinity, 0, Identity], #1^#2 Exp[- 2 Pi I #2]&}}
(* Cuts1[c:Fn[f_[g_]]] := JoinCuts[SimpCut /@ RawCuts[c]] *)
Cuts1[c:Fn[f_[g___]]] := RawCuts[c]
(* ... = MapR[RawCuts[c], Fn[cf, MapAtR[cf, 1, SimpCut]]] *)
(* Any other functions are assumed to have no branch cuts. *)
Cuts1[___] := {}

Regions[fn_] :=
 (* One-argument form: given a function, return the list of regions into
  * which its branch cuts divide the complex plane, obtained by
  * partitioning the plane with the function's branch cuts. *)
 ComplexPartition[
   BranchCuts[fn]]

(* Endpoints[cut]: the two points obtained by applying the cut's
 * parametrising function to its two parameter bounds. *)
Endpoints[Cut[r0_, r1_, f_]] := Map[f, {r0, r1}]

(* The endpoints of cut that satisfy IsConstant. *)
FiniteEndpoints[cut_] :=
  Select[Endpoints[cut], IsConstant]

(* CircleCut[r]: a cut parametrised over (-Pi, Pi) by r Exp[I #]. *)
CircleCut[r_] :=
  Cut[-Pi, Pi, r Exp[I #] &]

FindEnclosing[cuts_] :=
(* Produce a shape that is meant to contain all the branch cut endpoints
 * and to intersect each branch cut as many times as that cut has
 * endpoints at infinity.
 * Current implementation: a circle of fixed radius 1000; the argument is
 * not inspected, so the validity of the circle as an enclosing shape is
 * not checked. *)
  Let[{bigRadius = 1000},
    CircleCut[bigRadius]]

(* TraverseRegion[e, c, s, endpoints]: walk around a region.
 * The result starts with c and s[[2]]; unless s[[1]] is identical to the
 * ending point e, the segment list l of s[[1]] is looked up, the entry
 * whose first part is c is located at position p, and the walk continues
 * from the entry following it (wrapping from the last entry to the
 * first).
 * Lookup and Select1 are helpers defined outside this file.
 * NOTE(review): nothing in the visible part of this file calls
 * TraverseRegion; Regions uses Traverse instead. *)
TraverseRegion[e_, c_, s_, endpoints_] :=
 (* Given an ending point, a current point, a current segment,
  * and a list of segments for each endpoint, traverse a region and return it.
  * A region is a list of pairs {e, c}, where e is an endpoint and c is
  * a branch cut. *)
 Join[{c, s[[2]]},
      If[s[[1]] === e, {},
	Let[{l = Lookup[endpoints, s[[1]]]},
	 Let[{se = Select1[l, Fn[l1, l1[[1]] === c]]},
	  Let[{p = Position[l, se, 1][[1]][[1]]},
           Let[{p1 = If[p == Length[l], 1, p + 1]},
	    TraverseRegion[e, s[[1]], l[[p1]], endpoints]
]]]]]]

(* DirectAway[cut, point]: cut itself when point is identical to its first
 * endpoint, and the reversed cut otherwise. *)
DirectAway[cut_, point_] :=
 If[SameQ[point, Endpoint[cut, 1]],
    cut,
    ReverseCut[cut]]

(* Traverse[end, current, endpoints]: follow directed cuts around a region.
 * The result is a list that starts with the directed cut current.  If the
 * second endpoint of current equals end, the walk stops.  Otherwise the
 * list l of cuts stored for that endpoint is looked up in endpoints, the
 * position p of CanonicalCut[current] in l is found with Index, the cut
 * after it is taken (wrapping from the last entry to the first), directed
 * away from the endpoint with DirectAway, and the walk continues from
 * there.
 * Lookup, Index and CanonicalCut are not defined in the visible part of
 * this file. *)
Traverse[end_, current_, endpoints_] :=
 Cons[current,
  If[Endpoint[current, 2] == end, {},
   Let[{e = Endpoint[current, 2]},
   Let[{l = Lookup[endpoints, Endpoint[current, 2]]},
    Let[{p = Index[l, CanonicalCut[current]]},
     Let[{p1 = If[p == Length[l], 1, p + 1]},
      Traverse[end, DirectAway[l[[p1]], e], endpoints]
]]]]]]

Regions[t_, endpoints_] :=
 (* Given t, a list of segments, and endpoints, the table of cuts stored
  * for each endpoint, find the set of regions that are surrounded by the
  * segments in t.
  * Returns a list of regions, each being the list of segments produced
  * by Traverse.  One region is traced starting from the first segment of
  * t; its segments are then removed from t and the remainder is
  * processed recursively.
  * Traverse takes exactly three arguments (end, current, endpoints), so
  * the segment list t itself is not passed to it. *)
 If[t === {}, {},
  Let[{start = First[t]},
   Let[{region = Traverse[Endpoint[start, 1], start, endpoints]},
    Cons[region, Regions[Complement[t, region], endpoints]]]]]

(* EliminateIntersection[cuts]: for every cut c, collect the parameter
 * values at which c has to be broken -- its own two parameter bounds
 * plus, for every other cut d, the first coordinate of each point
 * returned by Intersect[c, d] (the parameter of the intersection on c).
 * The values are de-duplicated with Union, ordered with SortR/NLess, and
 * every adjacent pair becomes a piece Cut[low, high, c[[3]]].
 * MapL, SortR and NLess are defined outside this file. *)
EliminateIntersection[cuts_] :=
(* Given a set of branch cuts, eliminate all intersections by breaking
 * branch cuts into smaller pieces.
 * Returns a list of pairs {original cut, list of pieces}. *)
 MapR[cuts, Fn[c,
  Let[{s = SortR[NLess, Union[MapL[cuts, Fn[d,
   	If[c === d, {c[[1]], c[[2]]},
    	 MapR[Intersect[c, d], Fn[k, k[[1]]]]]]]]]},
  {c, MapR[Pairs[s], Fn[p, Cut[p[[1]], p[[2]], c[[3]]]]]}]]]

(* Endpoint[cut, n]: apply the cut's third part (its parametrising
 * function) to its n-th part. *)
Endpoint[cut_, n_] := Part[cut, 3][Part[cut, n]]

(* FiniteCut[cut]: both endpoints of the cut satisfy IsNumber. *)
FiniteCut[cut_] :=
  And[IsNumber[Endpoint[cut, 1]],
      IsNumber[Endpoint[cut, 2]]]

(* ReverseCut[cut]: the same cut with its two parameter bounds swapped. *)
ReverseCut[cut_] := Cut @@ cut[[{2, 1, 3}]]

(* WhichEndpoint[cut, pt]: the result of Index (defined elsewhere) for pt
 * in the list of the cut's endpoints. *)
WhichEndpoint[cut_, pt_] :=
 Index[Endpoints[cut], pt]

Angle[cut_, endpoint_] :=
(* Return the angle at which the cut leaves the given endpoint: the
 * argument of the derivative of the parametrisation at that endpoint's
 * parameter value, multiplied by the sign of the step from that parameter
 * value towards the other one. *)
 Let[{idx = WhichEndpoint[cut, endpoint]},
  Arg[Sign[cut[[3 - idx]] - cut[[idx]]] *
      (D[DeFun[cut[[3]]], #] /. # :> cut[[idx]])]
]

(* ComplexPartition[cuts]: divide the complex plane into regions bounded by
 * the given branch cuts.
 *
 * Steps:
 *  - add an enclosing shape (FindEnclosing) and break all cuts at their
 *    mutual intersections (EliminateIntersection); icuts is the list of
 *    resulting pieces;
 *  - build the directed segments: pieces of the enclosing shape are used
 *    in one direction only, finite pieces of ordinary cuts in both
 *    directions, all other pieces are dropped;
 *  - build, for every endpoint, the list of pieces meeting there, sorted
 *    by the angle at which they leave the endpoint (InvertAssoc is a
 *    helper defined outside this file);
 *  - trace the regions (Regions);
 *  - drop the segments of the enclosing shape from every region, and in
 *    the remaining segments replace endpoints lying on the enclosing
 *    shape with endpoints at infinity (FixEndpoint).
 *
 * Notes on this version:
 *  - cutmap is built with MapR[icuts, Fn[c, ...]]; the earlier
 *    Map[icuts, Fn[c], {...}] had the function body outside Fn;
 *  - FixEndpoint is given ep, the list of endpoints on the enclosing
 *    shape, since it tests an endpoint (a point) for membership; the list
 *    e of cut pieces can never contain a point;
 *  - the kept segment is wrapped in a list, like the other branch and
 *    like the branches of the If used for segments, so that Flatten1
 *    receives a list of lists. *)
ComplexPartition[cuts_] :=
  Let[{enclose = FindEnclosing[cuts]},
   Let[{cuts1 = EliminateIntersection[Cons[enclose, cuts]]},
    Let[{icuts = Flatten1[MapR[cuts1, Fn[c, c[[2]]]]]},
(* Construct a set of segments which form boundaries of regions.
 * A segment is a directed cut. *)
     Let[{segments = Flatten1[MapR[icuts, Fn[c,
      If[LookupRlist[cuts1, c] === enclose, {c},
          If[FiniteCut[c], {c, ReverseCut[c]}, {}]]]]]},
(* Construct an endpoint table used for looking up planar orderings. *)
      Let[{cutmap = MapR[icuts, Fn[c,
	    {c, If[FiniteCut[c], {Endpoint[c, 1], Endpoint[c, 2]}, {}]}]]},
       Let[{endpointmap = InvertAssoc[cutmap]},
       Let[{endpoints = MapR[endpointmap, Fn[e,
	{e[[1]], Sort[e[[2]], Fn[{a, b},
			Angle[a, e[[1]]] < Angle[b, e[[1]]]]]}]]},
       Let[{regions = Regions[segments, endpoints]},
(* Delete the segments of the enclosing shape, and
 * replace endpoints on the enclosing shape with endpoints at infinity. *)
       Let[{e = Lookup[cuts1, enclose]},
       Let[{ep = Union[Flatten1[MapR[e, Endpoints]]]},
        Let[{trimmed = MapR[regions, Fn[r,
	  Flatten1[MapR[r, Fn[s,
	   If[MemberQ[e, s], {},
		{FixEndpoint[FixEndpoint[s, 1, ep], 2, ep]}]]]]]]},
	 trimmed
]]]]]]]]]]]

(* FixEndpoint[cut, endpoint, infinite]: push one end of a cut to infinity.
 *   cut      - a Cut[low, high, f]
 *   endpoint - 1 or 2, selecting which parameter bound to inspect
 *   infinite - list of points; an endpoint found in this list is to be
 *              replaced by an endpoint at infinity
 * If the selected endpoint (the point f[bound]) is a member of infinite,
 * the corresponding parameter bound is replaced by Infinity or -Infinity,
 * the sign being that of the step from the other bound towards this one.
 * Otherwise the cut is returned unchanged.
 *
 * This corrects three defects of the earlier definition: MemberQ was
 * called with its arguments swapped (MemberQ takes the list first),
 * ReplacePart was not told which part to replace, and a cut whose
 * endpoint needed no change evaluated to Null because the If had no
 * else branch. *)
FixEndpoint[cut_, endpoint_, infinite_] :=
 If[MemberQ[infinite, Endpoint[cut, endpoint]],
  Module[{fixed = cut},
   fixed[[endpoint]] =
     Infinity * Sign[cut[[endpoint]] - cut[[3 - endpoint]]];
   fixed],
  cut]

(* Branch[f, r]: plot the branch cuts of f.
 *   f - the function whose branch cuts are drawn
 *   r - passed to TrimComplex as the viewing range; defaults to 10
 * Each cut Cut[t0, t1, g] of f is drawn by TrimComplex (defined outside
 * this file) as the curve g[t] for t from t0 to t1, with display
 * suppressed; the individual plots are then combined and displayed by a
 * single Show.
 * The cuts are taken from BranchCuts[f].  The earlier code mapped over
 * the inert expression Cut[f], so the plotting function was applied to f
 * itself rather than to its cuts. *)
Branch[f_, r_:10] :=
  Show @@ Append[
   Function[b, Module[{t},
    TrimComplex[b[[3]][t], {t, b[[1]], b[[2]]}, r, DisplayFunction->Identity]]]
      /@ BranchCuts[f], DisplayFunction->$DisplayFunction]

Intersect[Cut[r0_, r1_, f_], Cut[s0_, s1_, g_]] :=
(* Given two branch cuts, return a list of points at which they intersect.
 * Each point has the form {t1, t2}, where t1 and t2 are the parametric
 * coordinates of the intersection point in cut1 and cut2, respectively.
 * The real and imaginary parts of the two parametrisations are equated
 * and solved with RealSolve; only solutions whose coordinates lie within
 * both parameter intervals are kept. *)
 Module[{r, s},
  Let[{hits = RealSolve[{Re[f[r]] == Re[g[s]], Im[f[r]] == Im[g[s]]}, {r, s}]},
   Select[hits, Function[pt,
     N[r0 <= pt[[1]] <= r1] && N[s0 <= pt[[2]] <= s1]]]]]

(* AnonFun[fn]: rewrite a function with named parameters as a pure
 * function of slots.
 *   Function[v, body]               -> body with v replaced by #
 *   Function[{v1, v2, ...}, body]   -> body with v1, v2, ... replaced by
 *                                      #1, #2, ...
 *   anything else                   -> returned unchanged
 * The body is evaluated after the replacement (Evaluate inside Function).
 * The list form is handled by its own rule because the single-symbol rule
 * would try to replace the literal list {v} and leave the parameter
 * itself in the body.
 * Function::flpar is switched off while the rules are defined, since the
 * patterns on the left-hand sides are not valid parameter
 * specifications. *)
Off[Function::flpar]  (* bug workaround *)
AnonFun[Function[vars_List, body_]] :=
  Function[Evaluate[body /. Thread[vars -> Array[Slot, Length[vars]]]]]
AnonFun[Function[v_, body_]] := Function[Evaluate[body /. v -> #]]
AnonFun[x_] := x
On[Function::flpar]

End[]
EndPackage[]