Arithmetic performance in the Chinese Remainder Theorem

There’s a common belief in the analysis of algorithms that arithmetic operations take roughly constant time. This is true when working with types of a fixed size, like 64-bit integers, but wrong for “big” problems.

In “big” problems we encounter objects with unbounded sizes. Arbitrary length integers, arbitrary degree polynomials, recurrence relations of arbitrary orders, and so on. In these problems it matters how long it takes to do operations on objects of different sizes. The Chinese Remainder Theorem is a good example of this.

The Chinese Remainder Theorem (CRT) states that the system of congruences

\[\begin{align*} x &\equiv a_1 \pmod{m_1} \\ x &\equiv a_2 \pmod{m_2} \\ &\cdots \\ x &\equiv a_n \pmod{m_n} \end{align*}\]

has a unique solution $x$ modulo $m_1 m_2 \cdots m_n$, provided the moduli $m_1, \ldots, m_n$ are pairwise coprime. The term “CRT” also refers to the process of finding the smallest such solution, and this is where things get interesting.

In one sense, CRT can be done in “constant time.” If we set $M = m_1 \cdots m_n$, then we can write down an explicit solution:

\[x = \sum_{k = 1}^n a_k \frac{M}{m_k} \left( \left(\frac{M}{m_k}\right)^{-1} \bmod m_k \right).\]

The $k$th term is divisible by $m_i$ for every $i \neq k$ and is congruent to $a_k$ modulo $m_k$, so the whole sum satisfies each congruence and we’re done!
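
As a quick illustration, here is a direct transcription of the formula into Python. This is a minimal sketch (the name crt_formula is mine, and the three-argument pow with a negative exponent needs Python 3.8+); a fuller version, instrumented for profiling, appears later in the post.

from math import prod

def crt_formula(residues, moduli):
    # Direct transcription of the explicit formula.
    # Assumes the moduli are pairwise coprime.
    M = prod(moduli)
    x = 0
    for a, m in zip(residues, moduli):
        c = M // m                    # M / m_k
        x += a * c * pow(c, -1, m)    # pow(c, -1, m) = (M / m_k)^(-1) mod m_k
    return x                          # a valid solution, usually far larger than x % M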

This solution looks “constant time” because it seems to require a fixed number of arithmetic operations, but each of those operations acts on arbitrary-length integers, and we need to account for how long they take.

To get a sense for how things can break down, consider the system

\[\begin{align*} x &\equiv 1 \pmod{2} \\ x &\equiv 2 \pmod{3} \\ x &\equiv 3 \pmod{5} \\ &\cdots \\ x &\equiv n \pmod{p_n}, \end{align*}\]

where $p_n$ is the $n$th prime. For $n = 6$, the formula gives

\[x = 15015 + 40040 + 18018 + 102960 + 81900 + 41580 = 299513.\]

This is the correct answer, but way bigger than the smallest one, which is $299513 \bmod 30030 = 29243$. Each term of the sum is about the same order of magnitude as the modulus $2 \times 3 \times 5 \times 7 \times 11 \times 13 = 30030$, so the total is guaranteed to be much bigger than the smallest answer. In fact, if the modular inverses behave randomly, then the $k$th term of our sum would be, on average, about $k M / 2$ (here $a_k = k$). This gives our formula an approximate magnitude of $M \times (n^2 / 4)$, when we know the real answer can’t be bigger than $M$ itself!
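
Running the earlier crt_formula sketch on this system confirms the arithmetic:

x = crt_formula([1, 2, 3, 4, 5, 6], [2, 3, 5, 7, 11, 13])
print(x)          # 299513
print(x % 30030)  # 29243, the smallest non-negative solution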

The worst part about the formula, performance-wise, is not its too-big answer—we can reduce that with a single division. The worst part is that there are $n$ inverses $(M / m_k)^{-1} \bmod m_k$ to compute, and the numbers $M / m_k$ can be pretty big.

There is a CRT algorithm which slightly sidesteps this issue. The idea is to write our solution in “mixed radix” form

\[x = d_0 + d_1 m_1 + d_2 m_1 m_2 + \cdots + d_{n - 1} m_1 m_2 \cdots m_{n - 1},\]

which is unique if we stipulate $0 \leq d_k < m_{k + 1}$. Then our original system turns into the following triangular one:

\[\begin{align*} d_0 &\equiv a_1 \pmod{m_1} \\ d_0 + d_1 m_1 &\equiv a_2 \pmod{m_2} \\ d_0 + d_1 m_1 + d_2 m_1 m_2 &\equiv a_3 \pmod{m_3} \\ &\cdots \\ d_0 + d_1 m_1 + d_2 m_1 m_2 + \cdots + d_{n - 1} m_1 \cdots m_{n - 1} &\equiv a_n \pmod{m_n}. \end{align*}\]

We can solve this system from top to bottom for the $d_k$. In the first step we need no modular inverses, in the second we need to invert $m_1 \bmod m_2$, in the third we need to invert $m_1 m_2 \bmod m_3$, in the fourth $m_1 m_2 m_3 \bmod m_4$, and so on.
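
For the six-congruence example from before, this procedure produces the digits $1, 2, 3, 1, 7, 12$, i.e.

\[29243 = 1 + 2 \cdot 2 + 3 \cdot (2 \cdot 3) + 1 \cdot (2 \cdot 3 \cdot 5) + 7 \cdot (2 \cdot 3 \cdot 5 \cdot 7) + 12 \cdot (2 \cdot 3 \cdot 5 \cdot 7 \cdot 11),\]

and because every digit satisfies $0 \leq d_k < m_{k + 1}$, the mixed-radix value is automatically the smallest non-negative solution; no final reduction modulo $M$ is needed.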

The difference between the two methods is that the formula computes inverses of numbers that are all around the same size, while the triangular approach starts with small numbers and works up to big ones.

Implementations and profiling

Below are implementations of the formula-based method and the triangular-system method in Python, using sympy for modular inverses and prime generation. The script is ready to be run with line_profiler; I’ve already done that and put the output below.

from math import prod
from sympy import prime, primepi, log, primerange, mod_inverse
from line_profiler import profile
from itertools import islice
from functools import cache

@profile
def naiveCRT(ims, mods):
    M = prod(mods)

    terms = []

    for im, mod in zip(ims, mods):
        a = M // mod
        b = mod_inverse(a, mod)
        term = im * a * b
        terms.append(term)

    soln = sum(terms)
    smallest = soln % M

    return terms, soln, smallest

@profile
def triangularCRT(ims, mods):
    ds = []

    part_sum = 0
    part_prod = 1
    for im, mod in zip(ims, mods):
        inv = mod_inverse(part_prod, mod)
        new_d = (im - part_sum) * inv
        new_d %= mod
        ds.append(new_d)
        part_sum += ds[-1] * part_prod
        part_prod *= mod

    return part_sum

def testSystem(n, naive=True):
    # sympy is not very good at generating the first n primes. this is a faster
    # method.
    k = 2
    while primepi(int(k * n * log(n))) < n:
        k += 1

    ims = list(range(1, n + 1))
    ps = list(islice(primerange(int(k * n * log(n))), n))
    if naive:
        return naiveCRT(ims, ps)

    return triangularCRT(ims, ps)

if __name__ == "__main__":
    testSystem(50000, True)
    testSystem(50000, False)
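
To reproduce the profile, one way (assuming a reasonably recent line_profiler, which provides the explicit profile import used above) is to run LINE_PROFILE=1 python crt.py, or alternatively kernprof -lv crt.py.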

And the profiling results:

Timer unit: 1e-06 s

Total time: 17.2698 s
File: /home/rdb/crt.py
Function: triangularCRT at line 24

Line #      Hits         Time  Per Hit   % Time  Line Contents
==============================================================
    24                                           @profile                                 
    25                                           def triangularCRT(ims, mods):            
    26         1          1.4      1.4      0.0      ds = []                              
    27                                                                                    
    28         1          0.3      0.3      0.0      part_sum = 0                         
    29         1          0.2      0.2      0.0      part_prod = 1                        
    30     50001       8836.1      0.2      0.1      for im, mod in zip(ims, mods):       
    31     50000   12761469.0    255.2     73.9          inv = mod_inverse(part_prod, mod)
    32     50000     928981.4     18.6      5.4          new_d = (im - part_sum) * inv    
    33     50000    2329583.1     46.6     13.5          new_d %= mod                     
    34     50000       6508.7      0.1      0.0          ds.append(new_d)                 
    35     50000     779081.8     15.6      4.5          part_sum += ds[-1] * part_prod   
    36     50000     455358.8      9.1      2.6          part_prod *= mod                 
    37                                                                                    
    38         1          0.5      0.5      0.0      return part_sum                      


Total time: 35.5329 s
File: /home/rdb/crt.py
Function: naiveCRT at line 7

Line #      Hits         Time  Per Hit   % Time  Line Contents
==============================================================
     7                                           @profile                          
     8                                           def naiveCRT(ims, mods):          
     9         1     464805.5 464805.5      1.3      M = prod(mods)                
    10                                                                             
    11         1          1.5      1.5      0.0      terms = []                    
    12                                                                             
    13     50001      11683.0      0.2      0.0      for im, mod in zip(ims, mods):
    14     50000    4864532.1     97.3     13.7          a = M // mod              
    15     50000   27588076.0    551.8     77.6          b = mod_inverse(a, mod)   
    16     50000    1885445.3     37.7      5.3          term = im * a * b         
    17     50000       9369.0      0.2      0.0          terms.append(term)        
    18                                                                             
    19         1     708921.2 708921.2      2.0      soln = sum(terms)             
    20         1         80.7     80.7      0.0      smallest = soln % M           
    21                                                                             
    22         1          1.7      1.7      0.0      return terms, soln, smallest  


 17.27 seconds - /home/rdb/crt.py:24 - triangularCRT
 35.53 seconds - /home/rdb/crt.py:7 - naiveCRT

For $n = 50,000$, the triangular method was around twice as fast overall. The bulk of the runtime in both functions goes to computing modular inverses, and the triangular method spent about half as much time there (roughly 12.8 seconds versus 27.6 seconds).

Ignoring the inverses, the loop body in the formula method took approximately 130 microseconds per iteration versus about 90 microseconds in the triangular method. That comparison is less lopsided, but it makes clear that even precomputing and caching the inverses would not save the formula.

Back-of-the-envelope calculations

The product of the first $n$ primes is something like $n^n$, which has $O(n \log n)$ bits. Thus each $M / m_k$ also has $O(n \log n)$ bits, and computing its modular inverse with the extended Euclidean algorithm (using schoolbook arithmetic) might cost about $O(n^2 \log^2 n)$ bit operations, and we need to do this $n$ times. This puts the total runtime of the formula approach at around $O(n^3 \log^2 n)$.
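
To justify that bit count: the prime number theorem gives $p_n \sim n \log n$, and

\[\log M = \sum_{k = 1}^{n} \log p_k = \vartheta(p_n) \sim p_n \sim n \log n,\]

where $\vartheta$ is Chebyshev’s function, so $M$ (and each $M / m_k$) has on the order of $n \log n$ bits.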

For the triangular method, our cost will be something like

\[\begin{equation*} \sum_{k = 1}^n (k \log k)^2. \end{equation*}\]

There should probably be $O$’s here, and constants are floating around, but the point is that this sum works out to be a constant factor smaller than $n^3 \log^2 n$. Maple can compute the following asymptotic ratio:

\[\begin{equation*} \frac{\sum_{k = 1}^n (k \log k)^2}{n^3 \log^2 n} = \frac{1}{3} + O \left( \frac{1}{\log n} \right). \end{equation*}\]

So, at least for inverses, the triangular method should take around a third of the time for large $n$, but it approaches that point pretty slowly.
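
As a quick numerical sanity check of that ratio, here is a throwaway sketch with ordinary floats (not part of the profiled script):

from math import log

n = 10**6
ratio = sum((k * log(k))**2 for k in range(1, n + 1)) / (n**3 * log(n)**2)
print(ratio)  # roughly 0.32, still noticeably short of 1/3

Even at $n = 10^6$ the ratio sits a few percent below $1/3$, which is the $O(1 / \log n)$ term making itself known.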