After my previous post on the fundamental theorem of arithmetic, I had intended to move on to modular arithmetic, but I got pulled down a primality-testing rabbit hole instead. That detour involved a return to polynomial long division, except this time in Clojure.
The AKS primality test
In the quest to discover a better method for testing the primality of very large numbers, I came across the AKS primality test, which is relatively new (published in 2002) and has won many awards.
This “unconditional deterministic polynomial-time algorithm that determines whether an input number is prime or composite” is of great theoretical significance and the algorithm has “only” 5 steps, so I thought it would be worth trying to implement. Plus, when I saw the phrase multiplicative order, I thought, “Hey, I know what that is” and got excited.
The original paper has a pretty bold, simple title: PRIMES is in P, but good luck understanding the contents.
The essence of the algorithm—which I have absolutely zero intuition for—is as follows. Given an integer \(n > 1\):
- If \(n\) is a perfect power (i.e., \(n = a^b\) for some integers \(a, b > 1\)), return composite.
- Find the smallest \(r\) coprime to \(n\) such that \(\textrm{ord}_r(n) > (\log_2 n)^2\).
- If \(n\) is divisible by \(a\) for any \(2 \leq a \leq \min (r, n - 1)\), return composite.
- If \(n \leq r\), return prime.
- If \((X + a)^n \not= X^n + a \mod (X^r - 1, n)\) for any \(1 \leq a \leq \lfloor \sqrt{\varphi(r)} \log_2(n) \rfloor\), return composite. Else return prime.
Perfect powers
I knew I was in for a bit of a slog when I tried to implement the first step of the algorithm.
What is the best way to determine if a number \(n = a^b\) for some \(a, b \in ℤ\)? You could brute-force it, sure. But primality tests are usually concerned with really large numbers, so trying every possible combination of \(a, b\) might take a lifetime or two.
I decided to try a technique I learned from doing many, many Project Euler problems: flip the problem on its side and solve that problem instead.
Since we don’t need to know the value of \(a\) and are only concerned with whether or not \(a \in ℤ\) (i.e., it is a whole number), we can test the \(b\)th roots of \(n\) up to some upper bound.
Highest possible power?
First problem: How do you determine the upper bound? The answer is surprisingly simple.
\(a \in ℤ\), but \(a\) cannot be \(1\) because \(1^k = 1\) for any \(k\). So \(2\) is the lower bound for \(a\), and it necessarily corresponds to the highest possible \(b\), because if \(2^b = n\), then \(3^b > n\).
We can thus find the highest possible \(b\) by solving \(2^b = n\) for \(b\), or in other words, \(\log_2(n)\).
Since any composite exponent can be decomposed into prime exponents (if \(n = a^{cd}\), then \(n = (a^c)^d\)), we only need to test prime exponents \(b\) up to and including \(\log_2(n)\). Of course, since this is in the context of a primality-testing algorithm, it would be silly to use a prime sieve here. Testing \(2\) and then all odd \(b\) up to \(\log_2(n)\) is sufficient.
Floating-point fun
Second problem: When dealing with large numbers, floating-point errors can trick the computer into thinking a number is an integer when it shouldn’t be. For example, despite the obvious fact that \(\sqrt{n}^2 = n\), the following returns false:
(let [n 923849028343489238498324908928349028490829058N]
  (= n (Math/pow (Math/sqrt n) 2)))
;; false
This is because Java’s built-in math functions return Doubles. What we need here is a way to determine the \(n\)th root of a number to an arbitrary degree of precision.
I wrote about the Babylonian square root algorithm previously, and it can be adapted to compute \(n\)th roots without too much trouble. The square root algorithm \(t := \frac{t + \frac{x}{t}}{2}\) is actually: \(t := \frac{(2 - 1)\,t + \frac{x}{t^{2 - 1}}}{2}\)
So the \(n\)th root algorithm is: \(t := \frac{(n - 1)\,t + \frac{x}{t^{n - 1}}}{n}\)
To maintain a little more control over the intermediate values, we can use reduced multiplication ((reduce * (repeat exp base)) instead of (Math/pow base exp)), and reduced addition instead of multiplication.
The result of that looks like this:
(defn naive-bab-nth-root
  [n root]
  (loop [t 1.0 iters 0]
    (if (>= iters 100) t
        (let [ts (repeat (- root 1) t)]
          (recur (-> (/ n (reduce * ts))
                     (+ (reduce + ts))
                     (/ root))
                 (inc iters))))))
However, the sticking point is the type, which is set by the initial approximation t. Roots of integers are generally either integers or irrational numbers, so you can’t just slap an M on it (coerce it to BigDecimal) and call it a day:
Execution error (ArithmeticException) at java.math.BigDecimal/divide (BigDecimal.java:1690).
Non-terminating decimal expansion; no exact representable decimal result.
You have to specify some level of precision (significant digits).
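In Clojure, the with-precision macro does exactly that, by binding a MathContext for BigDecimal operations. For example:
(with-precision 10 (/ 1M 3))
;; => 0.3333333333M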
Let’s think about this a bit more. This algorithm returns increasingly accurate values with each successive iteration. But due to floating-point rounding errors, the algorithm might never converge on a true whole number.
This means that replacing the current stopping condition (based on iterations, (>= iters 100)) with one based on a simple heuristic like (= (Math/floor t) t) or (= (bigint t) t) will produce lots of false negatives (for small inputs) and false positives (for large inputs). No bueno.
Epsilon to the rescue
My first encounter with floating-point rounding errors happened in the context of training a machine learning algorithm (detailed in the link above). Can we apply the same logic here?
If we define an error threshold \(\epsilon\), then we can use that in our stopping condition. Let’s start with using a ridiculously small BigDecimal value for our threshold. This should surely give us a whole number when the answer \(\sqrt[b]{n} = a\) is actually a whole number, or at least enough zeroes after the decimal point that we can declare with confidence that it is a whole number.
1. Fast but bad
(defn more-precise-bab-nth-root
  [n root]
  (let [eps 0.000000000000000000000000000000000000000001M]
    (loop [t 1.0M iters 0]
      (let [ts (repeat (- root 1) t)
            nxt (with-precision 100 (-> (/ n (reduce *' ts))
                                        (+ (reduce +' ts))
                                        (/ root)))]
        (if (or (>= iters 100)
                (< (.abs (- nxt t)) eps)) t
            (recur nxt (inc iters)))))))
2. Better but slow
(defn naive-bigdec-nth-root
  [n root]
  (let [eps 0.000000000000000000000000000000000000000001M]
    (loop [t 1.0M]
      (let [ts (repeat (- root 1) t)
            nxt (with-precision 100 (-> (/ n (reduce *' ts))
                                        (+ (reduce +' ts))
                                        (/ root)))]
        (if (< (.abs (- nxt t)) eps) t
            (recur nxt))))))
(require '[clojure.math.numeric-tower :as tower])
3. Best
(defn babylonian-root
  "High-precision BigDecimal nth-root using the Babylonian algorithm,
  with a close initial approximation for ridiculously fast convergence."
  [n root]
  (let [eps 0.000000000000000000000000000000000000000001M]
    (loop [t (bigdec (tower/expt n (/ root)))] ;; rough initial approx
      (let [ts (repeat (- root 1) t)
            nxt (with-precision 100 (-> (/ n (reduce *' ts))
                                        (+ (reduce +' ts))
                                        (/ root)))]
        (if (< (.abs (- nxt t)) eps) t
            (recur nxt))))))
(babylonian-root 92709463147897837085761925410587 67)
;; 3.000000000000000000000000000000000000000000000000000000000034073599999999794044017777778109595749154M
(Note: the apostrophes in *' and +' allow values to silently overflow from long to BigInt if necessary.)
In 1. Fast but bad, I used both iterations and epsilon as stopping conditions, thinking that a perfect power would stop the loop early and speed up computation. It did, but at the expense of accuracy. Even-numbered roots returned really strange outputs (unrealistically large numbers). This is similar to gradient descent overshooting the minimum of the cost function in training a neural network because the learning rate is too high.
That meant that in some cases, more than 100 iterations were required for the algorithm to converge. So in my next attempt, 2. Better but slow, I removed the iteration cap as a stopping condition and let the algorithm run until the difference fell under the error threshold.
That worked, but it was slow. Try using it to calculate \(\sqrt[67]{92709463147897837085761925410587}\)—it’s not very fast. Since this is just part of step one of a very complex algorithm, it should ideally happen in an instant.
Then it occurred to me that the algorithm should not take so many iterations to converge. To be exact, the number of iterations required is a direct consequence of how accurate the initial approximation is. This matters less for small inputs, but makes a colossal difference with large inputs.
There is probably a formula to determine a good initial guess without resorting to floating-point arithmetic, but floating point turned out to be a very fast and good-enough solution nonetheless: simply express \(\sqrt[b]{n}\) as \(n^{\frac{1}{b}}\).
In Clojure, that looks like (tower/expt n (/ root)). (I prefer using tower/expt, but you could certainly use built-in Java interop to do this with Math/pow instead.) And voilà! We have arrived at 3. Best. It converges instantly and returns a highly accurate result.
Testing wholeness
As we intended, the values returned by babylonian-root have a ridiculous number of decimal places, and we can safely assume that a result like (scroll right ======>>>>>)
3.000000000000000000000000000000000000000000000000000000000034073599999999794044017777778109595749154M
is, with 99.9% certainty, a whole number. But there are still non-zero digits somewhere after the decimal point, which means simply using (= x (Math/floor x)) won’t work here either.
The best way to be sure our \(n\)th root is an integer is to floor it (truncate everything after the decimal point), then raise it to the \(n\)th power and see if it equals our input value. Both values must be coerced to the same type in order to test equality, and coercing to bigint ensures that small and large values are accommodated equally well. (Using *' will only cause values to overflow to bigint if they are large enough.)
(defn nth-root-is-integer?
  "Tests if the nth root of x is an integer in the mathematical
  (not programming) sense, i.e., if it is a whole number."
  [x n]
  (let [floor (bigint (babylonian-root x n)) ;; truncates the BigDecimal directly;
                                             ;; Math/floor would round-trip through a lossy double
        exp (bigint (reduce *' (repeat n floor)))]
    (= x exp)))
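A couple of quick checks (my own examples):
(nth-root-is-integer? 4096 3)
;; => true  (16^3 = 4096)
(nth-root-is-integer? 10 2)
;; => false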
Testing perfect power-ness
Finally, we’re ready to bring it all together. To test for a perfect power, we iterate through all possible \(b\) up to the upper bound we established earlier and return true, terminating early, if any \(\sqrt[b]{n} = \lfloor \sqrt[b]{n} \rfloor\). Otherwise, return false.
(defn perfect-power? [n]
  (let [max-power (/ (Math/log n) (Math/log 2))
        powers (cons 2 (filter odd? (range 3 (inc max-power))))]
    (some (partial nth-root-is-integer? n) powers)))
In Clojure, it is more idiomatic to use some, although this technically returns the first truthy value, and nil rather than false when there is none.
Phew! That was a lot of work just for a minor part of the algorithm. Let’s move on.
\(r\) u ready?
The next few steps of the algorithm revolve around some number \(r\) that satisfies certain criteria.
The first step draws on a concept called multiplicative order, which comes from modular arithmetic. I don’t want to cover it in full detail here, as I’m planning to write future posts on modular arithmetic in this series, so I will present just the functions and a simple explanation here.
After we find \(r\), the next steps are pretty straightforward.
Multiplicative order (in brief)
It’s similar to the concept of multiplicative inverse, which I have covered before, although instead of looking for the number that is congruent to \(1\) after multiplication, we are looking for the exponent that makes a number congruent to \(1\) after exponentiation.
Implementing a multiplicative order function in code takes a few steps. If I were to explain them in full here, this blog post would become overwhelming, so I am just going to leave the code here and revisit modular arithmetic more thoroughly in future posts. I’ll provide a link here when those posts are ready.
(defn mod-pow
  ;; Adapted from https://en.wikipedia.org/wiki/Modular_exponentiation
  "Quickly calculates a ^ b % m. Useful when a and b are very large integers."
  [base exp m]
  (if (or (= m 1) (= base m)) 0
      (loop [base (mod base m) e exp res 1]
        (if (zero? e) res
            (recur (-> (*' base base) (mod m))
                   (bit-shift-right e 1)
                   (if (odd? e)
                     (-> (*' res base) (mod m))
                     res))))))
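For example:
(mod-pow 5 3 13)
;; => 8  (5^3 = 125, and 125 mod 13 = 8)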
(defn powers-of-a-mod-n
  "a^k (mod n) for all 0 ≤ k < n, where k ∈ ℤ."
  [a n]
  (map #(mod-pow a % n) (range n)))
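(powers-of-a-mod-n 2 7)
;; => (1 2 4 1 2 4 1)
One caveat: multiplicative-order below relies on a coprime? predicate that doesn’t appear in this post. A minimal sketch, using gcd from numeric-tower, would be:
(defn coprime?
  "True if a and b share no common divisors other than 1."
  [a b]
  (= 1 (tower/gcd a b)))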
(defn multiplicative-order
  "ord_n(a), the smallest positive integer k such that a^k ≡ 1 (mod n) where
  a is coprime to n. a^0 ≡ 1 (mod n) for any n, so the quick way to find k
  is to count the distinct powers of a (mod n)."
  [a n]
  (when (coprime? a n)
    (count (distinct (powers-of-a-mod-n a n)))))
\(\textrm{ord}_r(n)\) is read “the multiplicative order of \(n \pmod r\)”. In the math notation, the modulus comes first, but when you read it aloud, the modulus comes last, so I have written the function with the arguments in the latter order.
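For example:
(multiplicative-order 2 7)
;; => 3  (2^3 = 8 ≡ 1 mod 7)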
Euler’s totient (\(\varphi\), phi) function
The (very complicated) last step of the algorithm requires first calculating \(\varphi(r)\), which is the number of integers less than \(r\) that are coprime to it, that is, whose greatest common divisor with \(r\) is \(1\). (\(4\) and \(9\) are coprime to each other, for example.)
For any prime number \(p\), \(\varphi(p)\) is \(p - 1\). Of course, since we don’t know if our target number is prime, that doesn’t help us much in this case.
There is another way to optimize the totient function using Euler’s product formula, but because that requires factorizing the number first, it is also of little help here.
So, let’s proceed with the most naïve, simplistic version of the function.
(defn naive-phi
  "Naive version of Euler's totient function that only uses gcd, since
  the optimized version requires factoring n first."
  [n]
  (count (filter #(= 1 (tower/gcd n %)) (range 1 n))))
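For example:
(naive-phi 9)
;; => 6  (1, 2, 4, 5, 7, and 8 are coprime to 9)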
The crazy polynomial part
Now we come to the most formidable part of the algorithm. It took a while to even understand the notation at first. We have to check whether \((X + a)^n \equiv X^n + a \mod (X^r - 1, n)\) for every \(1 \leq a \leq \lfloor \sqrt{\varphi(r)} \log_2(n) \rfloor\).
The capital letter \(X\) is apparently a convention from abstract algebra. This way of notating polynomials reflects the fact that we are not concerned with actually filling in the value of \(X\) in a given polynomial, as if it were a function; it is just a symbol.
When I first saw this, the “double modulus” was the part I found most confusing. I’m still not sure of the correct terminology for what is going on here, but essentially, this is what we need to compare:
- Left-hand side: Take the remainder of \(\frac{\textcolor{#1f77b4}{(X + a)^n}}{\textcolor{mediumpurple}{X^r - 1}}\), then reduce all coefficients \(\mod \textcolor{orange}{n}\).
- Right-hand side: Take the remainder of \(\frac{\textcolor{#e377c2}{X^n + a}}{\textcolor{mediumpurple}{X^r - 1}}\), then reduce all coefficients \(\mod \textcolor{orange}{n}\).
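To make this concrete, here is a toy case (my own) with \(n = 3\), \(r = 2\), \(a = 1\):
- Left-hand side: \((X + 1)^3 = X^3 + 3X^2 + 3X + 1 \equiv 4X + 4 \pmod{X^2 - 1}\), and reducing the coefficients \(\mod 3\) gives \(X + 1\).
- Right-hand side: \(X^3 + 1 \equiv X + 1 \pmod{X^2 - 1}\), already reduced \(\mod 3\).
The two sides agree, consistent with \(3\) being prime.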
This requires polynomial long division. Brace yourself.
Polynomials in Clojure
Thanks to A Programmer’s Introduction to Mathematics, I was already familiar with the convention of representing polynomials in code as an array of coefficients, where the index of a coefficient represents the power of its associated term. For example, \(x^3 - 2x^2 + 17\) would be [17, 0, -2, 1].
Trying to port this exact representation from Python (or Java, or some other C-type language) to Clojure got a bit messy as soon as I tried to implement an addition function, because Clojure does not have a built-in map-longest function, or a clean way of writing one. Some helpful folks on the awesome Clojurians Slack group advised me to try representing polynomials as maps instead, with the powers as keys: {0 17, 2 -2, 3 1}.
In addition to being easier to manipulate in Clojure, this way of organizing the data has the added bonus of being order-agnostic and not being sensitive to omitted zero coefficients.
Well, in order to ensure that the zero coefficients don’t matter one way or the other, we should write a “trim” function to save us from possibly pulling out our hair:
(defn poly-trim-
  "Removes terms with zero coefficients from a polynomial."
  [pnml]
  (->> (filter #(zero? (get pnml %)) (keys pnml))
       (reduce #(dissoc %1 %2) pnml)))
It’s an intuitively simple idea that requires a slightly roundabout functional implementation. Get the keys of the map (powers of the polynomial) whose coefficient is zero, then successively dissoc (remove) those key-val pairs from the map.
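For example:
(poly-trim- {3 1, 2 0, 0 5})
;; => {3 1, 0 5}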
Low-hanging fruit
Let’s nail the low-hanging fruit first. Reducing the coefficients of a polynomial \(\mod n\) is pretty easy:
(defn poly-mod
  "Reduces the coefficients of a given polynomial mod m."
  [pnml m]
  (zipmap (keys pnml) (map #(mod % m) (vals pnml))))
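For example:
(poly-mod {2 7, 1 3, 0 10} 5)
;; => {2 2, 1 3, 0 0}
Note that coefficients reduced to zero stay in the map; poly-trim- takes care of those.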
Next, the right-hand-side term: \(\textcolor{#e377c2}{X^n + a}\). Since n is the initial input to the algorithm and a will be taken from a range, we can treat them as constants. We can thus write this as {n 1, 0 a}.
Then, the polynomial modulus: \(\textcolor{mediumpurple}{X^r - 1}\). r will have been defined in a previous step, so we can treat it as a constant. We can thus write this as {r 1, 0 -1}.
That leaves two (rather tedious) things: Expanding \(\textcolor{#1f77b4}{(X + a)^n}\), which should be done using modular exponentiation to prevent the coefficients from exploding; and implementing polynomial long division.
Modular exponentiation of polynomials
Exponentiation requires multiplication, which requires addition.
Representing polynomials as hash-maps allows us to add polynomials extremely concisely using merge-with:
(defn add
  "Adds two polynomials."
  [p1 p2]
  (poly-trim- (merge-with +' p1 p2)))
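For example, \((x + 2) + (3x^2 - 2) = 3x^2 + x\):
(add {1 1, 0 2} {2 3, 0 -2})
;; => {1 1, 2 3}  (key order may vary)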
To multiply polynomials, we take every pair of terms with non-zero coefficients, add their powers, and multiply their coefficients to obtain the new terms. If there are multiple terms with the same power, add those coefficients.
(defn mul
  "Multiplies two polynomials."
  [p1 p2]
  (->> (for [powers1 (keys p1) powers2 (keys p2)]
         {(+ powers1 powers2)
          (*' (get p1 powers1) (get p2 powers2))})
       (reduce #(merge-with +' %1 %2) {})
       poly-trim-))
To exponentiate polynomials, just reduce!
(defn exp
  "Exponentiation of a polynomial, [p(x)]^e."
  [pnml e]
  (reduce mul (repeat e pnml)))
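For example, \((x + 1)(x - 1) = x^2 - 1\), and \((x + 1)^2 = x^2 + 2x + 1\):
(mul {1 1, 0 1} {1 1, 0 -1})
;; => {2 1, 0 -1}
(exp {1 1, 0 1} 2)
;; => {2 1, 1 2, 0 1}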
However, in this case, we need to exponentiate our polynomial \(\mod n\), so in order to keep the coefficients from exploding (because the values will quickly become astronomical, especially when testing very large numbers), let’s reduce the coefficients \(\mod n\) after every multiplication.
(defn mod-exp
  "Slightly faster version of [p(x)]^e (mod m), where p(x) is a polynomial.
  Reduces the result of each multiplication mod m with every iteration, rather
  than only once at the end, in order to keep the intermediate coefficients
  from exploding."
  [pnml e m]
  (reduce #(poly-mod (mul %1 %2) m) (repeat e pnml)))
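As a sanity check (my own example): for a prime exponent and modulus, the “freshman’s dream” gives \((X + 1)^5 \equiv X^5 + 1 \pmod 5\). Since poly-mod leaves zeroed-out terms in place, we trim to see it clearly:
(poly-trim- (mod-exp {1 1, 0 1} 5 5))
;; => {5 1, 0 1}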
Polynomial remainder (shortcut)
Interestingly, the particular polynomial divisor \(\textcolor{mediumpurple}{X^r - 1}\) used in this algorithm seems to have some kind of special property related to cyclotomic polynomials. I have no idea what those are, but someone who implemented the algorithm in JavaScript has found an interesting shortcut to finding the remainder after dividing any polynomial by that divisor.
I don’t know why it works, but I was able to implement it pretty quickly in Clojure. Check out the Medium post linked above for an explanation of how it works.
(defn quick-poly-rem
  "Shortcut to finding the remainder of p(x) / (x^r - 1)."
  [pnml r]
  (poly-trim- (reduce-kv (fn [res power coeff]
                           (merge-with + res {(mod power r) coeff}))
                         {} pnml)))
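For example (my own), since \(x^5 = x^3 \cdot x^2 \equiv x^2 \pmod{x^3 - 1}\):
(quick-poly-rem {5 1, 0 1} 3)
;; => {2 1, 0 1}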
Polynomial long division
However, I tend not to be satisfied with these hand-wavy magic tricks (or at least, it seems hand-wavy because the author of that article didn’t provide the source of the “trick”). So I decided to roll up my sleeves and implement polynomial long division from scratch (ugh).
Having done it once in Python, I was not looking forward to doing it again.
First, we need to implement subtraction, because otherwise the intermediate steps of the division will get messy (because of the LISP syntax). We have multiplication already, so this is easy. Just multiply the polynomial by {0 -1}, the constant \(-1\).
(defn sub
  "Subtracts polynomial p2(x) from polynomial p1(x)."
  [p1 p2]
  (let [neg-p2 (mul p2 {0 -1})]
    (poly-trim- (add p1 neg-p2))))
From there, it’s just a matter of translating the algorithm into Clojure. I found it easier to translate the long-hand procedure from English prose than to adapt pre-cooked implementations in other programming languages.
(defn degree
  "Finds the degree (power of highest-power term) of a polynomial."
  [pnml]
  (apply max (keys pnml)))
(defn lc
  "Leading coefficient (coefficient of highest-power term) of a polynomial."
  [pnml]
  (get pnml (degree pnml)))
(defn poly-quot-
  "The quotient of a polynomial p1(x) divided by another p2(x).
  Returns nil if p2 is of a higher degree than p1."
  [p1 p2]
  (let [d1 (degree p1) d2 (degree p2)]
    (when (>= d1 d2)
      {(- d1 d2)               ;; power = difference in degree
       (/ (lc p1) (lc p2))}))) ;; coeff = quotient of lc's
(defn div
  "Polynomial long division of p1(x) / p2(x).
  Returns nil if p2 is of a higher degree than p1."
  ;; http://www.math.ucla.edu/~radko/circles/lib/data/Handout-358-436.pdf
  [p1 p2]
  (when-let [q (poly-quot- p1 p2)] ;; sanity check
    (loop [qs q r (->> q (mul p2) (sub p1))]
      (if (empty? r) {:quotient qs :remainder r} ;; divides evenly
          (if-let [new-q (poly-quot- r p2)]
            (recur (conj qs new-q) ;; divides with remainder
                   (->> new-q (mul p2) (sub r)))
            {:quotient qs :remainder r}))))) ;; can't divide anymore
(defn poly-rem
  "Remainder after dividing two polynomials, p1(x) / p2(x)."
  [p1 p2]
  (:remainder (div p1 p2)))
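A quick test of the pipeline (my own examples): \((x^2 - 1) \div (x - 1) = x + 1\) with no remainder, and the long division agrees with the earlier shortcut:
(div {2 1, 0 -1} {1 1, 0 -1})
;; => {:quotient {1 1, 0 1}, :remainder {}}
(= (poly-rem {5 1, 0 1} {3 1, 0 -1})
   (quick-poly-rem {5 1, 0 1} 3))
;; => true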
Prime time
Finally, it’s time to combine all of the above into an implementation of the AKS algorithm. To reiterate, the algorithm is as follows:
- If \(n\) is a perfect power (i.e., \(n = a^b\) for some integers \(a, b > 1\)), return composite.
- Find the smallest \(r\) coprime to \(n\) such that \(\textrm{ord}_r(n) > (\log_2 n)^2\).
- If \(n\) is divisible by \(a\) for any \(2 \leq a \leq \min (r, n - 1)\), return composite.
- If \(n \leq r\), return prime.
- If \((X + a)^n \not= X^n + a \mod (X^r - 1, n)\) for any \(1 \leq a \leq \lfloor \sqrt{\varphi(r)} \log_2(n) \rfloor\), return composite. Else return prime.
Taking advantage of Clojure’s when to keep things concise, here is my implementation, which returns true if the input is prime, nil if the input is proven composite before the final step, and false if the input is proven composite in the last step.
On my local version of this code, I’ve split the modular arithmetic and polynomial functions into different namespaces to keep things tidy. tower, ma, h, and p denote the namespaces numeric-tower, modular-arithmetic, helpers, and polynomial respectively.
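One more helper that isn’t shown in this post: h/divisible?, used in step 3. A minimal sketch of what it presumably looks like:
(defn divisible?
  "True if n is evenly divisible by a."
  [n a]
  (zero? (mod n a)))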
(defn aks-prime?
  "Uses the Agrawal–Kayal–Saxena primality test to determine if an integer n
  is prime. Returns true if prime, nil or false otherwise."
  [n]
  ;; 1. Check if n is a perfect power
  (when-not (perfect-power? n)
    ;; 2. Find the smallest r such that ord_r(n) > (log_2 n)^2.
    (let [log (-> (Math/log n) (/ (Math/log 2)) (tower/expt 2))
          r (first (keep (fn [r] (when-let [ord (ma/multiplicative-order n r)]
                                   (when (> ord log) r)))
                         (range)))
          lim (min r (- n 1))]
      ;; 3. For all 2 ≤ a ≤ min(r, n−1), check that a does not divide n
      ;;    (composite if so)
      (when (not-any? (partial h/divisible? n) (range 2 (inc lim)))
        ;; 4. If n ≤ r, output prime.
        (if (<= n r) true
            (let [log2n (/ (Math/log n) (Math/log 2))
                  lim (->> (tower/sqrt (naive-phi r)) (* log2n) bigint)
                  ;; Taking the polynomial remainder can fold coefficients back
                  ;; above n, so reduce mod n (and trim the resulting zeros)
                  ;; after dividing, on both sides.
                  lhs (fn [a] (-> (p/mod-exp {1 1, 0 a} n n)
                                  (p/poly-rem {r 1, 0 -1})
                                  (p/poly-mod n)
                                  p/poly-trim-))
                  rhs (fn [a] (-> (p/poly-rem {n 1, 0 a} {r 1, 0 -1})
                                  (p/poly-mod n)
                                  p/poly-trim-))]
              ;; 5. If (X+a)^n != (X^n)+a (mod X^r − 1, n) for ANY a from 1 to
              ;;    lim (inclusive), n is composite. In other words, prime? =
              ;;    true iff (X+a)^n = (X^n)+a (mod X^r − 1, n) for ALL a.
              (every? (fn [a] (= (lhs a) (rhs a))) (range 1 (inc lim)))))))))
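A couple of smoke tests (my own; outputs as I understand the code):
(aks-prime? 31)
;; => true
(aks-prime? 15)
;; => nil  (3 divides 15, so step 3 catches it)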
Gotcha!
Congrats! You made it this far.
If you test this function, though, you might be a bit disappointed with the results. Namely, it’s very slow, even for small inputs. The lag becomes apparent with prime inputs as small as \(n = 37\), and the runtime skyrockets as \(n\) grows.
Indeed, despite all the groundbreaking features of the algorithm, speed and practicality are not among them. Ordinarily, I would be inclined to call coding all of this a waste of time and curse myself for not having bothered to research this earlier, but since I learned a fair bit of Clojure and math in the process, I can’t quite call it a waste.
However, it does mean I need to study up on elliptic curves!
References
- AKS primality test, Wikipedia
- AKS Primality Test (Primes is in P), Sibaprasad Maiti
- Long division of polynomials, Olga Radko, UCLA
- When is the AKS primality test actually faster than other tests?, CS Stack Exchange