Reflections on prysm, part 04: Orthogonal polynomials
This is the fourth in a series of posts that take a look at my Python library, prysm. See the first post for more introduction.
This one will focus on working with polynomials, especially Zernike polynomials, with the library.
As a preface, prysm has the fastest code of any physical optics program I have seen for working with the Zernike polynomials, and the only code for the Qbfs and Qcon polynomials from Greg Forbes. This post is not written for want of speed.
Rather, the way prysm achieved that performance is dramatically overcomplicated.
Before we get into the code, I should discuss briefly how prysm handles these polynomials differently.
Some programs (and historically, prysm too) use explicitly written out equations for the Zernike polynomials. This is uncomplicated and reasonably performant, though the code is verbose. This looks something like:
import numpy as np

def defocus(rho, phi):
    """Zernike defocus."""
    return 2 * rho**2 - 1

def primary_astigmatism_00(rho, phi):
    """Zernike primary astigmatism 0°."""
    return rho**2 * np.cos(2 * phi)

def primary_astigmatism_45(rho, phi):
    """Zernike primary astigmatism 45°."""
    return rho**2 * np.sin(2 * phi)

# etc
I think this way is popular because you do not need to know any advanced mathematics to write it. The downside is that you have to know the closed-form expression for each polynomial a priori, which limits how high an order you can go to. There are also subtle rounding errors later in the series, although those are more academic than anything else: the leading coefficients are not nice scalars for the terms that have large powers of rho, even though textbooks and papers print them as nice scalars, so a small rounding error is baked in.
The use of large powers also means you can only expand to about the 20th radial power before numerical instability sets in. That, too, may be a bit academic.
prysm contains the first few Zernike polynomials written out this way in part as a historical carryover for compatibility, and in part because they are very convenient to use.
Many programs use Rodrigues’ radial polynomial form to calculate the Zernike polynomials:
$$ R_n^m (\rho) = \sum_{k=0}^{\frac{n-m}{2}} \frac{(-1)^k (n-k)!}{k!(\frac{n+m}{2}-k)!(\frac{n-m}{2}-k)!}\rho^{n-2k} $$
A benefit to this form is that it works for arbitrary (n,m). I haven’t written the azimuthal term for brevity. Can you see what is “wrong” with this form? There are two things:
It involves many powers, some of which are large for large n
It involves many factorials, which are slow
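To make those two complaints concrete, here is a minimal sketch of evaluating that sum directly with numpy; the function name zernike_r_direct is just for illustration and is not prysm's API:

from math import factorial
import numpy as np

def zernike_r_direct(n, m, rho):
    """Radial Zernike polynomial R_n^m via the explicit factorial sum."""
    m = abs(m)
    out = np.zeros_like(rho, dtype=float)
    for k in range((n - m) // 2 + 1):
        # each term needs three factorials and a large power of rho
        num = (-1)**k * factorial(n - k)
        den = factorial(k) * factorial((n + m) // 2 - k) * factorial((n - m) // 2 - k)
        out = out + (num / den) * rho**(n - 2 * k)
    return out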
“Large” n is right around n=20. I submit that 20th order spherical aberration, the lowest Noll or Fringe indexed mode to have n=20, is an academic topic, as it represents the balancing of 15th order spherical aberration with 13th order, 11th order, and so on. I do not think even the highest-NA microscope or lithography system contains those aberrations. For this reason, the numerical instability is an academic topic among skilled users, since they would not use such high order polynomials. However, we see many users who do not know better designing lenses with extremely high order aspheres, or even Zernike surfaces of very high order. Perhaps not so academic after all.
The speed is a separate topic. If we motivate ourselves with physical optics problems that boil down to:
- Compute an amplitude (e.g. circle) over the pupil
- Compute a phase (sum of Zernikes) over the pupil
- Convert that amplitude and phase to a complex field
- Propagate it with a Fourier transform
- Take the modulus squared to compute an intensity
These five steps are completely separable, and we will naively assume at the outset that none are particularly faster or slower than others. Often, this naivety – and checking assumptions – leads to discovery.
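To make those five steps concrete, here is a hedged sketch in plain numpy (not prysm's API), using Zernike defocus as the phase and the 256 → 512 zero-padded sizes discussed next:

import numpy as np

x = np.linspace(-1, 1, 256)
xx, yy = np.meshgrid(x, x)
rho, phi = np.hypot(xx, yy), np.arctan2(yy, xx)

amplitude = (rho <= 1).astype(float)                 # 1. amplitude: a circle over the pupil
phase = 2 * rho**2 - 1                               # 2. phase: e.g. Zernike defocus, in waves
field = amplitude * np.exp(1j * 2 * np.pi * phase)   # 3. complex field
field = np.pad(field, 128)                           # zero pad 256x256 -> 512x512
E = np.fft.fftshift(np.fft.fft2(np.fft.ifftshift(field)))  # 4. propagate via FFT
intensity = abs(E)**2                                # 5. modulus squared -> intensity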
Thinking with an array size of 256x256 at the pupil and 512x512 at the PSF – achieved by zero padding – we can time the elements of this, and get something along the lines of:
- step 1 (amplitude): 1.2 ms
- step 2 (phase): (tbd)
- step 3 (complex field): 0.2 ms
- steps 4 and 5 (propagation and modulus squared, merged): 1.5 ms
These are not particularly large lengths of time. Step 2 is (tbd) because there are multiple ways to consider this calculation in a benchmark:
- an evaluation of a single Zernike term
- a sum of several Zernike terms
We know from the Rodrigues radial polynomial way that it will get slower for larger (n,m) due to the sum running for longer. We also know the explicit function way will run slower for larger (n,m) because the expressions get bigger.
So let’s time #2 for two particular cases, a single term of moderate order (say, n=m=4), and a sum of the first 25 Noll Zernike terms, which strikes a middle ground between small sums, which are common, and 36-term sums, which are also common. I have never implemented the Rodrigues way myself, so I’ll use POPPY:
from poppy import zernike

# the single term case
%timeit zernike.zernike(4, 4, 256)
3.48 ms ± 29.9 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)

# the sum case
def zsum(nterm):
    n, m = zernike.noll_indices(1)
    out = zernike.zernike(n, m, 256)
    for i in range(2, nterm+1):
        n, m = zernike.noll_indices(i)
        out += zernike.zernike(n, m, 256)
    return out

%timeit zsum(25)
43 ms ± 941 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
So, if we tally up the two cases we get:
- single, sparse Zernike term, 6.4 ms (54% Zernike computation)
- 25 term sum of Zernikes, 45.9 ms (93.6% Zernike computation)
Accelerating the Zernike calculations therefore offers massive performance gains to physical optics programs. prysm caches everything heavily, to the extent that steps 1 through 3 above take a total of about half a millisecond. So how does it do it?
In a phrase, harder math.
prysm uses a mathematical insight I picked up from the brilliant mind of Greg Forbes; the Zernike polynomials are a special case of the Jacobi polynomials, and can be computed quickly with a recurrence relation.
Without getting into the details of the jacobi polynomials themselves, suppose you have:
def jacobi(n,alpha,beta,x): pass
Then to compute R as in the Rodrigues formula, you simply need to do:
def zernike_r(n, m, rho):
    # the leading rho^|m| factor is applied together with the azimuthal term later
    return jacobi((n - abs(m)) // 2, 0, abs(m), 2 * rho**2 - 1)
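If you want to convince yourself of that relationship, a quick check with scipy (not part of prysm) reproduces the familiar primary spherical aberration polynomial:

import numpy as np
from scipy.special import eval_jacobi

rho = np.linspace(0, 1, 5)
# primary spherical, n=4, m=0: R_4^0(rho) = 6 rho^4 - 6 rho^2 + 1
direct = 6 * rho**4 - 6 * rho**2 + 1
via_jacobi = eval_jacobi(2, 0, 0, 2 * rho**2 - 1)
print(np.allclose(direct, via_jacobi))  # True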
This is much simpler, but it kicks the can of computational complexity down to the jacobi function. The nice thing about the Jacobi polynomials is that their recurrence relation reduces to:
$$ a \cdot P_n^{(\alpha,\beta)} = b \cdot P_{n-1}^{(\alpha,\beta)} - c \cdot P_{n-2}^{(\alpha,\beta)} $$
This looks like two multiplies, a divide, and a subtract. The coefficients a, b, and c take a few more operations to compute, but those are only multiplies and adds. What's more, only the b term involves the argument x, so the work of computing a and c is done once for any length of x, as long as n, alpha, and beta do not change.
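Spelled out (these are the same expressions that appear in the code below), the coefficients are:

$$
\begin{aligned}
a &= 2n(n+\alpha+\beta)(2n+\alpha+\beta-2) \\
b &= (2n+\alpha+\beta-1)\left[(2n+\alpha+\beta)(2n+\alpha+\beta-2)\,x + \alpha^2 - \beta^2\right] \\
c &= 2(n+\alpha-1)(n+\beta-1)(2n+\alpha+\beta)
\end{aligned}
$$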
For the last year or so, prysm has used this link to compute its zernike polynomials. The result is that the timing looks like:
from prysm import NollZernike
from prysm.zernike import zcachemn
import numpy as np

# clear() blows away the cache
%timeit NollZernike(Z14=1); zcachemn.clear()
2.57 ms ± 360 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)

def zsum2(n):
    out = NollZernike(np.ones(n), samples=256)
    zcachemn.clear()
    return out

%timeit zsum2(25)
32.1 ms ± 3.28 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)
A savings of a third isn’t bad, but didn’t I promise more? What’s going on?
To understand why the performance here is not as good as it could be, we need to understand a little bit about how to actually implement that jacobi function and the rest of the work needed to turn a jacobi polynomial evaluation into a Zernike polynomial.
The Jacobi polynomials are computed via the recurrence relation given above. The natural way to implement a recurrence in a program is recursion: a function that calls itself a finite number of times.
The jacobi polynomial implementation in prysm looks like:
def a(n, alpha, beta, x):
    """The "a" coefficient of the recurrence relation."""
    term1 = 2 * n
    term2 = n + alpha + beta
    term3 = 2 * n + alpha + beta - 2
    return term1 * term2 * term3

def b(n, alpha, beta, x):
    """The "b" coefficient; the only one that involves x."""
    term1 = 2 * n + alpha + beta - 1
    iterm1 = 2 * n + alpha + beta
    iterm2 = 2 * n + alpha + beta - 2
    iterm3 = alpha ** 2 - beta ** 2
    temp_product = iterm1 * iterm2 * x + iterm3
    return term1 * temp_product

def c(n, alpha, beta, x):
    """The "c" coefficient of the recurrence relation."""
    term1 = 2 * (n + alpha - 1)
    term2 = (n + beta - 1)
    term3 = (2 * n + alpha + beta)
    return term1 * term2 * term3

def jacobi(n, alpha, beta, x, Pnm1=None, Pnm2=None):
    """Jacobi polynomial of order n with weights alpha, beta at x.

    e is prysm's backend module (numpy by default).
    """
    if n == 0:
        return e.ones_like(x)
    elif n == 1:
        term1 = alpha + 1
        term2 = alpha + beta + 2
        term3 = (x - 1) / 2
        return term1 + term2 * term3
    elif n > 1:
        if Pnm1 is None:
            Pnm1 = jacobi(n-1, alpha, beta, x)
        if Pnm2 is None:
            Pnm2 = jacobi(n-2, alpha, beta, x)
        a_ = a(n, alpha, beta, x)
        b_ = b(n, alpha, beta, x)
        c_ = c(n, alpha, beta, x)
        term1 = b_ * Pnm1
        term2 = c_ * Pnm2
        tmp = term1 - term2
        return tmp / a_
This is pretty straightforward. If you look towards the bottom at jacobi, you see the elif n > 1 block contains two checks to see if it was given the previous terms in the recurrence relation, and if not, it sources them itself.
The result of this is that if you call jacobi with Pnm1 and Pnm2 provided, it has linear time complexity w.r.t. n. If you do not, it is quadratic. I am telling you this, so it's a good guess that I am smart enough to provide them. When I implemented this, I was not smart enough to see there was a way to do this without recursion at all - we'll get to that.
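To see the linear-time usage concretely, here is a small sketch of feeding the previous two terms forward while walking up the orders; jacobi_sequence is just an illustrative name layered on the jacobi function above, not prysm's API:

def jacobi_sequence(ns, alpha, beta, x):
    """Return [P_n for n in ns], reusing the previous two terms at each order."""
    out = []
    Pnm2, Pnm1 = None, None
    for n in range(max(ns) + 1):
        Pn = jacobi(n, alpha, beta, x, Pnm1=Pnm1, Pnm2=Pnm2)
        if n in ns:
            out.append(Pn)
        Pnm2, Pnm1 = Pnm1, Pn
    return out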
I mentioned that prysm <3 caching. The result is that the one-line wrapper around jacobi that makes a Zernike (or at least its radial part) is wrapped in a (checks notes) 325-line front-end class that does the caching. That 325-line abomination is why prysm doesn't eke out more performance. While it does avoid all of the duplicate work, it replaces it with about 4n^2 class attribute and dict lookups. That doesn't sound so bad, you think, but those are not free in Python:
d = {1:2}
p = NollZernike()
%timeit d[1]
31.7 ns ± 0.0994 ns per loop (mean ± std. dev. of 7 runs, 10000000 loops each)
# assume insert time is ~= 1.5x access time
%timeit p.samples_x
201 ns ± 1.56 ns per loop (mean ± std. dev. of 7 runs, 1000000 loops each)
So for n=6, we have 4 * 6^2 = 144 lookups. We also have 3n insertions, which carry their own 6n lookups. All of that churn inside the class eats the majority of the time (more than 1.5 ms for the single-term case).
A further (large) speedup was left on the table by this inferior implementation in prysm.
If we switch gears to Julia for a moment, we can explore how fast this could go if we were smarter. We do this because in Julia we have explicit control over what is computed for each element of rho and what is computed only once.
function abc(n, α, β, x)
    # the parens emphasize the terms, not PEMDAS
    a = (2n) * (n + α + β) * (2n + α + β - 2)
    b1 = (2n + α + β - 1)
    b2 = (2n + α + β)
    b2 = b2 * (b2 - 2)
    b = b1 * (b2 * x + α^2 - β^2)
    c = 2 * (n + α - 1) * (n + β - 1) * (2n + α + β)
    return a, b, c
end
function jacobi(n, α, β, x)
    if n == 0
        return 1.
    elseif n == 1
        return (α + 1) + (α + β + 2) * ((x - 1) / 2)
    end
    # this uses a loop to avoid recursion.
    # we init P0 and P1 directly, then compute P2,
    # the first term that needs the recurrence,
    # and loop from 3 up to n, updating
    # Pnm2, Pnm1, a, b, c, Pn
    Pnm2 = 1.                                     # == jacobi(0, ...)
    Pnm1 = (α + 1) + (α + β + 2) * ((x - 1) / 2)  # == jacobi(1, ...)
    a, b, c = abc(2, α, β, x)
    Pn = ((b * Pnm1) - (c * Pnm2)) / a
    if n == 2
        return Pn
    end
    for i = 3:n
        Pnm2 = Pnm1
        Pnm1 = Pn
        a, b, c = abc(i, α, β, x)
        Pn = ((b * Pnm1) - (c * Pnm2)) / a
    end
    return Pn
end
function zernike(n, m, ρ, θ; norm::Bool=true)
    x = 2 .* ρ.^2 .- 1
    am = abs(m)
    n_j = (n - am) ÷ 2
    # α=0, β=|m|
    # there is a second form with x reversed, 1 - 2ρ^2, in which case you
    # swap α and β and track a (-1)^n_j sign
    out = jacobi.(n_j, 0, am, x)   # broadcast over the elements of x
    if m != 0
        if m < 0
            out .*= ρ.^am .* sin.(m .* θ)
        else
            out .*= ρ.^am .* cos.(m .* θ)
        end
    end
    if norm
        out .*= zernike_norm(n, m)  # RMS normalization coefficient, not shown
    end
    return out
end
Note especially that we use a loop, so there is no recursion. This turns something that would have been O(n^2) into O(n) time complexity. For bonus points, it does not make the interface worse.
This is a simple implementation that is about as clear as the code can be without putting in performance optimizations. Very notably, it computes a and c for each element of x unnecessarily; the code is around 100x faster with changes that avoid that. I will use an optimized version below, although another one that is about three times faster still is possible (I am still working on it).
Out of the gate, we open with a comparison:
using BenchmarkTools  # for @btime

r = collect(range(-1., 1., length=256^2))
@btime zernike(4, 0, $r, 0);
178.974 μs (9 allocations: 2.00 MiB)
This is not exactly an apples-to-apples comparison, but if you'll excuse that, we can see that this runs in 0.18 ms, around 10x faster than the implementation in prysm.
There is more on the table, though. You can see that the loop-based implementation simply writes over old variables as the higher orders are computed. When you compute a sequence of values, you actually compute zernike(2,0) and so on as a part of computing zernike(4,0). In other words, computing a set of polynomials (or a sum of them) should be only about as expensive as computing the highest order term alone. It follows, then, that a set that spans up to Noll Z25 could be computed in, drum roll, about:
@btime zernike(6, -4, $r, 0);
1.970 ms (18 allocations: 3.50 MiB)
(or even about 3x faster than this). This is a 20x performance delta. It moves the needle on a forward model: one that took 46 seconds now takes 6 seconds. Stated a bit differently, a model that took a minute to compute now takes less than 10 seconds.
This is a much faster parameter change-run-evaluate loop, and results in more productive engineering.
Where are the lessons learned?
Perhaps the most striking feature of the Julia code shown here is that it is simple. It does not use any meta machinery to obfuscate the math. There is no 325 line class doing any caching, just a simple for loop.
Not shown in the Julia code is that there is a faster way, in general, to run the jacobi function; it involves using @avx and other tricks. Also not shown is that there should be other signatures - jacobi_sequence, zernike_sequence, and so on - that fill out an array during the loop instead of overwriting, as well as jacobi_sum and zernike_sum, which do not fill out an array in the same sense but instead perform += on a work array internally. Those would be optimally efficient, and deliver on the promised ~2 ms for a sum of 25 Zernikes, still with no caching, and with lower memory usage.
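As a rough Python/numpy sketch of the jacobi_sum idea (hypothetical names, not prysm's current API): run the recurrence once up to the highest order and accumulate a weighted += into a single work array as you go, here for fixed (alpha, beta) and the transformed argument x = 2 rho^2 - 1.

import numpy as np

def jacobi_sum(weights, alpha, beta, x):
    """Sum of weights[n] * P_n^(alpha, beta)(x), accumulated in one work array."""
    work = weights[0] * np.ones_like(x)
    if len(weights) == 1:
        return work
    Pnm2 = np.ones_like(x)
    Pnm1 = (alpha + 1) + (alpha + beta + 2) * ((x - 1) / 2)
    work += weights[1] * Pnm1
    for n in range(2, len(weights)):
        # recurrence coefficients are scalars, computed once per order
        a_ = 2 * n * (n + alpha + beta) * (2 * n + alpha + beta - 2)
        b1 = 2 * n + alpha + beta - 1
        b2 = (2 * n + alpha + beta) * (2 * n + alpha + beta - 2)
        c_ = 2 * (n + alpha - 1) * (n + beta - 1) * (2 * n + alpha + beta)
        Pn = (b1 * (b2 * x + alpha**2 - beta**2) * Pnm1 - c_ * Pnm2) / a_
        work += weights[n] * Pn
        Pnm2, Pnm1 = Pnm1, Pn
    return work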
prysm would be better served by removing zcachemn outright and adopting something more similar to this.
- The code would be massively simpler.
- It would run faster.
- It would be more direct.
- It would be more readable and intelligible.
- It would be more maintainable.