benchmarking permutations
I was working on a Python implementation of a prisoner problem simulation that revolves around permutations. You can generate a random permutation with `random.sample(range(n_prisoners), n_prisoners)` or with `np.random.permutation(n_prisoners)`, but if you then want to walk the path of a cycle in a permutation, you'll quickly learn that you're dealing with a loop that doesn't easily vectorise. It involves code like this:
```python
def max_cycle_length_python(perm):
    """Find the maximum cycle length in a permutation (pure Python)."""
    n = len(perm)
    visited = [False] * n
    max_len = 0
    for start in range(n):
        if visited[start]:
            continue
        length = 0
        current = start
        while not visited[current]:
            visited[current] = True
            length += 1
            current = perm[current]
        if length > max_len:
            max_len = length
    return max_len
```
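To make sure the walk does what we expect, here's a quick sanity check on a permutation small enough to trace by hand (the function is repeated so the snippet runs standalone):

```python
def max_cycle_length_python(perm):
    """Find the maximum cycle length in a permutation (pure Python)."""
    n = len(perm)
    visited = [False] * n
    max_len = 0
    for start in range(n):
        if visited[start]:
            continue
        length = 0
        current = start
        while not visited[current]:
            visited[current] = True
            length += 1
            current = perm[current]
        if length > max_len:
            max_len = length
    return max_len

# 0 -> 1 -> 2 -> 0 is a cycle of length 3; 3 -> 4 -> 3 has length 2.
print(max_cycle_length_python([1, 2, 0, 4, 3]))  # → 3

# The identity permutation only has fixed points, so every cycle has length 1.
print(max_cycle_length_python([0, 1, 2, 3, 4]))  # → 1
```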
This got me thinking: maybe we can have Claude write different implementations and see if there is a speedup! So I had it try pure Python, NumPy, Numba, Rust and Mojo.
| Implementation | Time (s) | Sims/sec | Speedup |
|---|---|---|---|
| Pure Python | 0.752 | 133,038 | 1.0x |
| NumPy | 1.525 | 65,570 | 0.5x |
| Numba | 0.263 | 380,111 | 2.9x |
| Rust | 0.102 | 980,604 | 7.4x |
| Mojo | 0.161 | 619,273 | 4.7x |
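The NumPy code isn't shown here, but the sub-1.0x number is plausible: the cycle walk is inherently sequential, and pulling elements out of an ndarray one at a time is slower than indexing a Python list. A sketch of what such a version might look like (my reconstruction, not the notebook's actual code):

```python
import numpy as np

def max_cycle_length_numpy(perm: np.ndarray) -> int:
    """Same sequential walk, but on an ndarray.

    Scalar indexing into an ndarray produces a NumPy scalar object on
    every access, which costs more than a plain list lookup -- one
    plausible reason the NumPy row lands below 1.0x.
    """
    n = len(perm)
    visited = np.zeros(n, dtype=bool)
    max_len = 0
    for start in range(n):
        if visited[start]:
            continue
        length = 0
        current = start
        while not visited[current]:
            visited[current] = True
            length += 1
            current = perm[current]
        if length > max_len:
            max_len = length
    return max_len
```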
I got a notebook that spews out results, and while it's impressive that it was able to generate all of this ... I still felt compelled to have a look at some of the code. For starters, I spotted that the "pure Python" benchmark still used NumPy code in it, which wasn't fair. I also spotted that the Numba code added a JIT around the cycle-length function but left the permutation generation outside of it.
before
```python
import numpy as np
from numba import njit

@njit
def max_cycle_length_numba(perm):
    """Find the maximum cycle length in a permutation (Numba JIT)."""
    n = len(perm)
    visited = np.zeros(n, dtype=np.bool_)
    max_len = 0
    for start in range(n):
        if visited[start]:
            continue
        length = 0
        current = start
        while not visited[current]:
            visited[current] = True
            length += 1
            current = perm[current]
        if length > max_len:
            max_len = length
    return max_len

def simulate_numba(n_prisoners: int, n_sims: int) -> list[int]:
    """Run simulations using the Numba JIT-compiled function."""
    results = []
    for _ in range(n_sims):
        # The permutation is generated in plain Python, outside the JIT.
        perm = np.random.permutation(n_prisoners).astype(np.int64)
        results.append(max_cycle_length_numba(perm))
    return results
```
after
```python
import numpy as np
from numba import njit

@njit
def _simulate_numba_jit(n_prisoners: int, n_sims: int) -> np.ndarray:
    """Fully JIT-compiled simulation loop."""
    results = np.empty(n_sims, dtype=np.int64)
    for i in range(n_sims):
        # Permutation generation now happens inside the JIT-compiled loop.
        perm = np.random.permutation(n_prisoners)
        # Inline max_cycle_length logic for performance
        n = len(perm)
        visited = np.zeros(n, dtype=np.bool_)
        max_len = 0
        for start in range(n):
            if visited[start]:
                continue
            length = 0
            current = start
            while not visited[current]:
                visited[current] = True
                length += 1
                current = perm[current]
            if length > max_len:
                max_len = length
        results[i] = max_len
    return results
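One more thing worth keeping in mind with JIT benchmarks: the first call to an `@njit` function triggers compilation, so it should be called once before the timer starts. Here's a minimal harness sketch showing that shape, using a pure-Python simulation stand-in so it runs without Numba; the parameter values are illustrative, not the notebook's actual settings:

```python
import random
import time

def max_cycle_length(perm):
    """Walk each cycle once and track the longest."""
    n = len(perm)
    visited = [False] * n
    max_len = 0
    for start in range(n):
        if visited[start]:
            continue
        length = 0
        current = start
        while not visited[current]:
            visited[current] = True
            length += 1
            current = perm[current]
        max_len = max(max_len, length)
    return max_len

def simulate(n_prisoners, n_sims):
    """Run n_sims random permutations and collect max cycle lengths."""
    return [
        max_cycle_length(random.sample(range(n_prisoners), n_prisoners))
        for _ in range(n_sims)
    ]

# Warm-up call: for a @njit function this is where compilation happens,
# and it must not be counted in the measurement.
simulate(100, 10)

start = time.perf_counter()
n_sims = 10_000
simulate(100, n_sims)
elapsed = time.perf_counter() - start
print(f"{elapsed:.3f}s, {n_sims / elapsed:,.0f} sims/sec")
```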
When I fixed both, the simulation numbers looked quite different.
| Implementation | Time (s) | Sims/sec | Speedup |
|---|---|---|---|
| Pure Python | 0.917 | 109,107 | 1.0x |
| NumPy | 1.503 | 66,534 | 0.6x |
| Numba | 0.100 | 1,002,954 | 9.2x |
| Rust | 0.103 | 971,905 | 8.9x |
| Mojo | 0.160 | 624,544 | 5.7x |
I'm reasonably comfortable with all the Python benchmarks, but then we get to the Rust and Mojo code. This is tricky. I can prompt Claude to dive deeper, but in the end those languages are new to me, so my ability to steer it is limited. The main thing I had it try was asking Claude about sorting algorithms and whether it might be better to use a standard library instead of rolling a custom implementation ... but that's about it in the short term.
The numbers didn't really change much, and even though I am running a lot of simulations here, it all feels within the margin of error.
| Implementation | Time (s) | Sims/sec | Speedup |
|---|---|---|---|
| Pure Python | 0.922 | 108,467 | 1.0x |
| NumPy | 1.505 | 66,465 | 0.6x |
| Numba | 0.108 | 926,469 | 8.5x |
| Rust | 0.101 | 991,451 | 9.1x |
| Mojo | 0.158 | 632,052 | 5.8x |
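"Within the margin of error" can be made a bit more concrete by repeating each benchmark several times and reporting the spread; if two implementations' intervals overlap, the difference is probably noise. A small helper sketch (my addition, not part of the notebook):

```python
import statistics
import time

def time_repeatedly(fn, repeats=5):
    """Run fn several times and return (mean, stdev) of wall-clock time."""
    times = []
    for _ in range(repeats):
        start = time.perf_counter()
        fn()
        times.append(time.perf_counter() - start)
    return statistics.mean(times), statistics.stdev(times)

# Example with a stand-in workload; in practice fn would be one benchmark run.
mean, stdev = time_repeatedly(lambda: sum(i * i for i in range(100_000)))
print(f"{mean:.4f}s ± {stdev:.4f}s")
```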
How to think about this stuff
I can't help but play the devil's advocate here. If the point is to achieve a speedup, then it's clear that Claude can just do that for you. It may not be the best way to run a benchmark, but the Rust implementation is on par with the faster Numba one. If the bad benchmark led me to the Rust implementation, it would still be a win, right? So maybe it doesn't matter whether the comparison is fair.
A good reason to say "no" to this is that there is a difference between building something that's "easy" and building something that's "simple". And that difference tends to be worth something. Does it really make sense to add Rust to a Python project? Or can we keep a simpler Python stack by sticking to Numba? It depends on the project, sure, but that's a design decision, not a mere implementation detail.
It's not just that I want to be able to work on a codebase even if the LLM is down. It's also that it's one thing to let go of the how-part of something. I don't know how to write assembly, after all, so at some level I won't exactly know what's going on with my code. But it's another thing not to be able to explain why you designed something the way you did. And if Claude is going to help me, that's the part I can't let go of.
It's incredibly clear that Claude is doing something impressive here, something I don't want to ignore. But at the same time I am also acutely aware that you lose something if you don't take the time to look at what you're getting back. My mind is really going back and forth on this topic as I expose myself to more and more of these tools and tasks. Depending on how things go, I might think differently in a year's time.