This is a post about writing elegant and performant recursive algorithms in Rust. It makes heavy use of a pattern from Haskell called recursion schemes, but you don't need to know anything about that; it's just an implementation detail. Instead, as motivation, I have benchmarks showing a 14-34% improvement over the typical boxed pointer representation of recursive data structures in Rust.

# Performance test results⌗

These test results show a performance improvement of 34% for evaluating a very large expression tree (131072 elements, recursive depth 17). They were run on a 6th generation X1 carbon laptop with an Intel i7-8550U with 8MB CPU cache:

```Evaluate expression tree of depth 17 with standard method
time:   [722.18 µs 733.00 µs 746.43 µs]

Evaluate expression tree of depth 17 with my new collapse_layers method
time:   [477.87 µs 482.54 µs 488.58 µs]
```

The same tests, when run on an AMD Ryzen 9 3900X CPU with more than 64MB total cache (L1/L2/L3), still show a 14% speed improvement over the usual method.

```Evaluate expression tree of depth 17 with standard method
time:   [295.76 µs 295.89 µs 296.03 µs]

Evaluate expression tree of depth 17 with my new collapse_layers method
time:   [250.96 µs 251.12 µs 251.31 µs]
```

# Evaluating an expression language⌗

We're going to start with a simple expression language: addition, subtraction, multiplication – just enough to illustrate some concepts. You've probably seen something like this before, but if not, it's just a way to represent simple arithmetic as a tree of expressions. For example, an expression like `1 * (2 - 3)` would be written as (pseudocode) `Mul(1, Sub(2, 3))`.

``````#[derive(Debug, Clone)]
pub enum ExprBoxed {
a: Box<ExprBoxed>,
b: Box<ExprBoxed>,
},
Sub {
a: Box<ExprBoxed>,
b: Box<ExprBoxed>,
},
Mul {
a: Box<ExprBoxed>,
b: Box<ExprBoxed>,
},
LiteralInt {
literal: i64,
},
}
``````

This is a recursive expression language that uses boxed pointers to handle the recursive case. If you're not familiar with boxed pointers, a `Box<A>` is just the Rust way of storing a pointer to some value of type `A` - think of it as a box with a value of type `A` inside it. (If you're curious, there's more documentation here)

Using this data structure, we can write `Mul(1, Sub(2, 3))` as:

``````ExprBoxed::Mul {
a: Box::new(ExprBoxed::LiteralInt { literal: 1 }),
b: Box::new(ExprBoxed::Sub {
a: Box::new(ExprBoxed::LiteralInt { literal: 2 }),
b: Box::new(ExprBoxed::LiteralInt { literal: 3 }),
}),
}
``````

Evaluating expressions is pretty simple - it's just addition, subtraction, and multiplication. This recursive eval function provides a fairly elegant and readable example of a recursive algorithm:

``````impl ExprBoxed {
pub fn eval(&self) -> i64 {
match &self {
ExprBoxed::Add { a, b } => a.eval() + b.eval(),
ExprBoxed::Sub { a, b } => a.eval() - b.eval(),
ExprBoxed::Mul { a, b } => a.eval() * b.eval(),
ExprBoxed::LiteralInt { literal } => *literal,
}
}
}
``````

This algorithm has some issues:

• If we try to evaluate a sufficiently large expression it will fail with a stack overflow - we're not likely to hit that case here, but this is a real problem when working with larger recursive data structures.
• Each recursive `eval` call requires us to traverse a boxed pointer. This means we can't take advantage of cache locality - there's no guarantee that all these boxed pointers live in the same region of memory. 1

## A more cache-local structure⌗

We can fix that by writing an expression language using a Vec of individual expression nodes (guaranteeing memory locality), with boxed pointers replaced with newtype-wrapped vector indices.

``````#[derive(Debug, Clone, Copy)]
pub enum ExprLayer<A> {
Add { a: A, b: A },
Sub { a: A, b: A },
Mul { a: A, b: A },
LiteralInt { literal: i64 },
}

#[derive(Eq, Hash, PartialEq)]
pub struct ExprIdx(usize);
impl ExprIdx {
ExprIdx(0)
}
}

pub struct ExprTopo {
// nonempty, in topological-sorted order. guaranteed via construction.
elems: Vec<ExprLayer<ExprIdx>>,
}
``````

You might have noticed that we have used a generic parameter `A` rather than simply writing `ExprLayer<ExprIdx>`. Put a pin in that for now, we'll come back to that soon.

All our expressions are now guaranteed to be stored in local memory. Here's a sketch showing what the `Mul(1, Sub(2, 3))` expression would look like using this data structure.

``````[
idx_0:    Mul(idx_1, idx_2)
idx_1:    LiteralInt(1)
idx_2:    Sub(idx_3, idx_4)
idx_3:    LiteralInt(2)
idx_4:    LiteralInt(3)
]
``````

The nodes are stored in topological order, which means that for each node, all of its child nodes are stored at larger indices. To evaluate an `ExprTopo`, we can perform bottom up recursion: collapse leaf values into their parents, one `ExprLayer` at a time, until the entire `ExprTopo` structure has been collpased into a single value. Since it's topologically sorted, we can do this by iterating over the element vector in reverse order.

Let's see what evaluating this structure looks like in practice. It's not elegant. There's a bunch of `unsafe` code, but it does have better performance in benchmarks. Feel free to skim; in the next section we'll introduce an elegant API that removes the need to write `unsafe` code.

``````impl ExprTopo {
fn eval(self) -> i64 {
use std::mem::MaybeUninit;

let mut results = std::iter::repeat_with(|| MaybeUninit::<i64>::uninit())
.take(self.elems.len())
.collect::<Vec<_>>();

fn get_result_unsafe(results: &mut Vec<MaybeUninit<i64>>, idx: ExprIdx) -> i64 {
unsafe {
let maybe_uninit =
std::mem::replace(results.get_unchecked_mut(idx.0), MaybeUninit::uninit());
maybe_uninit.assume_init()
}
}

for (idx, node) in self.elems.into_iter().enumerate().rev() {
let result = {
// each node is only referenced once so just remove it, also we know it's there so unsafe is fine
match node {
ExprLayer::Add { a, b } => {
let a = get_result_unsafe(&mut results, a);
let b = get_result_unsafe(&mut results, b);
a + b
}
ExprLayer::Sub { a, b } => {
let a = get_result_unsafe(&mut results, a);
let b = get_result_unsafe(&mut results, b);
a - b
}
ExprLayer::Mul { a, b } => {
let a = get_result_unsafe(&mut results, a);
let b = get_result_unsafe(&mut results, b);
a * b
}
ExprLayer::LiteralInt { literal } => literal,
}
};
results[idx].write(result);
}

unsafe {
let maybe_uninit =
std::mem::replace(results.get_unchecked_mut(0), MaybeUninit::uninit());
maybe_uninit.assume_init()
}
}
}
``````

The problem here is that this is very difficult to read and write. Imagine having to write all of this by hand, for each recursive function. It would be tedious at best and error prone at worst.

## Factoring out duplicated code⌗

Every arm of the above match statement (except for `LiteralInt`) calls `get_result_unsafe` in pretty much the same way. We can start by factoring that out.

Now you can see why we made `ExprLayer<A>` parameterized over some `A`. Since it is parameterized over some `A`, we can apply a function to each `A` inside it, turning it into an `ExprLayer<B>`. We're going to write some code that's very similar to `Option::map` in the standard library.

``````impl<A> ExprLayer<A> {
#[inline(always)]
fn map<B, F: FnMut(A) -> B>(self, mut f: F) -> ExprLayer<B> {
match self {
ExprLayer::Add { a, b } => ExprLayer::Add { a: f(a), b: f(b) },
ExprLayer::Sub { a, b } => ExprLayer::Sub { a: f(a), b: f(b) },
ExprLayer::Mul { a, b } => ExprLayer::Mul { a: f(a), b: f(b) },
ExprLayer::LiteralInt { literal } => ExprLayer::LiteralInt { literal },
}
}
}
``````

If you're familiar with functional languages, this is basically just `fmap`.3

Now, we can write something like this:

``````impl ExprTopo {
fn eval(self) -> i64 {
use std::mem::MaybeUninit;

let mut results = std::iter::repeat_with(|| MaybeUninit::<i64>::uninit())
.take(self.elems.len())
.collect::<Vec<_>>();

for (idx, layer) in self.elems.into_iter().enumerate().rev() {
let layer: ExprLayer<i64> = layer.map(|idx| unsafe {
let maybe_uninit =
std::mem::replace(results.get_unchecked_mut(idx.0), MaybeUninit::uninit());
maybe_uninit.assume_init()
});

let result = match layer {
ExprLayer::Add { a, b } => a + b,
ExprLayer::Sub { a, b } => a - b,
ExprLayer::Mul { a, b } => a * b,
ExprLayer::LiteralInt { literal } => literal,
};
results[idx].write(result);
}

unsafe {
let maybe_uninit =
maybe_uninit.assume_init()
}
}
}
``````

## Making it generic⌗

Ok, that's a start. Unfortunately, we still have to write all this boilerplate for every recursive function, even though the only part that really matters is this block:

``````let result = match layer {
ExprLayer::Add { a, b } => a + b
ExprLayer::Sub { a, b } => a - b
ExprLayer::Mul { a, b } => a * b
ExprLayer::LiteralInt { literal } => literal,
}
``````

This code takes `layer`, a value of type `ExprLayer<i64>`, and consumes it to create `result`, a value of type `i64`. What if, instead of `ExprLayer<i64> -> i64`, we use a function of type `ExprLayer<A> -> A`?

This function lets us provide an arbitrary function of type `ExprLayer<A> -> A` and uses it to collapse all the layers in an `ExprTopo` structure into a single value:

``````impl ExprTopo {
fn collapse_layers<F: FnMut(ExprLayer<A>) -> A>(self, mut collapse_layer: F) -> A {
use std::mem::MaybeUninit;

let mut results = std::iter::repeat_with(|| MaybeUninit::<A>::uninit())
.take(self.elems.len())
.collect::<Vec<_>>();

for (idx, layer) in self.elems.into_iter().enumerate().rev() {
let result = {
let layer = layer.map(|x| unsafe {
let maybe_uninit =
std::mem::replace(results.get_unchecked_mut(x.0), MaybeUninit::uninit());
maybe_uninit.assume_init()
});
collapse_layer(layer)
};
results[idx].write(result);
}

unsafe {
let maybe_uninit =
maybe_uninit.assume_init()
}
}
}
``````

Nice. Now we can write:

``````impl ExprTopo {
pub fn eval(self) -> i64 {
self.collapse_layers(|expr| match expr {
ExprLayer::Add { a, b } => a + b,
ExprLayer::Sub { a, b } => a - b,
ExprLayer::Mul { a, b } => a * b,
ExprLayer::LiteralInt { literal } => literal,
})
}
}
``````

It's pretty much the same logic as the original `eval` functions, without any of the boilerplate. Since there's less boilerplate, it's easier to review and there's less room for bugs. Also, it retains all the performance benefits of the previous `eval` implementation - it's both more elegant and more performant than the traditional representation of recursive expression trees in rust.

# Constructing Exprs⌗

Let's write a function to build an `ExprTopo` value from the `ExprBoxed` representation. Just as before, `map` helps us keep it concise. Feel free to skim this one too, we'll be abstracting over the specifics just like we did with `collapse_layers`:

``````impl ExprTopo {
fn from_boxed(seed: &ExprBoxed) -> Self {
let mut frontier: VecDeque<&ExprBoxed> = VecDeque::from([seed]);
let mut elems = vec![];

// expand layers to build a vec of elems while preserving topo order
while let Some(seed) = { frontier.pop_front() } {
let layer = match seed {
ExprBoxed::Sub { a, b } => ExprLayer::Sub { a, b },
ExprBoxed::Mul { a, b } => ExprLayer::Mul { a, b },
ExprBoxed::LiteralInt { literal } => ExprLayer::LiteralInt { literal: *literal },
};
let layer = layer.map(|seed| {
frontier.push_back(seed);
// idx of pointed-to element determined from frontier + elems size
ExprIdx(elems.len() + frontier.len())
});

elems.push(layer);
}

Self { elems }
}
}
``````

## Making it generic⌗

Just as with `collapse_layers`, we only really care about the `match` expression here:

``````let layer = match seed {
ExprBoxed::Sub { a, b } => ExprLayer::Sub { a, b },
ExprBoxed::Mul { a, b } => ExprLayer::Mul { a, b },
ExprBoxed::LiteralInt { literal } => ExprLayer::LiteralInt { literal: *literal },
};
``````

This matches on `seed`, a value of type `&ExprBoxed`, and consumes it to create `layer`, a value of type `ExprLayer<i64ExprBoxed>`. What if, instead of `i64ExprBoxed -> ExprLayer<i64ExprBoxed>`, we use a function of type `A -> ExprLayer<A>`?

Fortunately, just as with `collapse_layers`, we can separate the machinery of recursion from the actual recursive (or, in this case, co-recursive) logic.

``````impl ExprTopo {
fn expand_layers<A, F: Fn(A) -> ExprLayer<A>>(seed: A, expand_layer: F) -> Self {
let mut frontier = VecDeque::from([seed]);
let mut elems = vec![];

// repeatedly expand layers to build a vec of elems while preserving topo order
while let Some(seed) = frontier.pop_front() {
let layer = expand_layer(seed);

let layer = layer.map(|seed| {
frontier.push_back(seed);
// idx of pointed-to element determined from frontier + elems size
ExprIdx(elems.len() + frontier.len())
});

elems.push(layer);
}

Self { elems }
}
}
``````

This lets us write `from_boxed` as:

``````impl ExprTopo {
pub fn from_boxed(ast: &ExprBoxed) -> Self {
Self::expand_layers(ast, |seed| match seed {
ExprBoxed::Sub { a, b } => ExprLayer::Sub { a, b },
ExprBoxed::Mul { a, b } => ExprLayer::Mul { a, b },
ExprBoxed::LiteralInt { literal } => ExprLayer::LiteralInt { literal: *literal },
})
}
}
``````

Nice and, as promised, elegant.

# Testing for Correctness⌗

I used proptest to test this code for correctness. It generates many expression trees, each of which is evaluated via both `eval` methods. I then assert that they have the same result. 4

This actually helped me find a bug! In my first implementation of `expand`, I used a stack instead of a queue for the frontier, which ended up mangling the order of the expression tree. Since proptest is awesome, it not only found this bug but reduced the failing test case to `Add (0, Sub(0, 1))`.

``````// generate a bunch of expression trees and evaluate them via each method
#[cfg(test)]
proptest! {
#[test]
fn expr_eval(boxed_expr in arb_expr()) {
let eval_boxed = boxed_expr.eval();
let eval_via_collapse = ExprTopo::from_boxed(&boxed_expr).eval();

assert_eq!(eval_boxed, eval_via_collapse);
}
}

#[cfg(test)]
pub fn arb_expr() -> impl Strategy<Value = ExprBoxed> {
let leaf = any::<i8>().prop_map(|x| ExprBoxed::LiteralInt { literal: x as i64 });
leaf.prop_recursive(
8,   // 8 levels deep
256, // Shoot for maximum size of 256 nodes
10,  // We put up to 10 items per collection
|inner| {
prop_oneof![
a: Box::new(a),
b: Box::new(b)
}),
(inner.clone(), inner.clone()).prop_map(|(a, b)| ExprBoxed::Sub {
a: Box::new(a),
b: Box::new(b)
}),
(inner.clone(), inner).prop_map(|(a, b)| ExprBoxed::Mul {
a: Box::new(a),
b: Box::new(b)
}),
]
},
)
}
``````

# Testing for performance⌗

For performance testing, we used criterion to benchmark the simple `ExprBoxed::eval` vs `ExprTopo::eval`. This code basically just builds up a really big (as in, 131072 nodes) recursive structure (using `expand`/`collapse`, because they're honestly really convenient) and evaluates it a bunch of times. I also ran this test on recursive structures of other sizes, because graphs are cool. You can find the benchmarks defined here.

```Evaluate expression tree of depth 17 with standard boxed method
time:   [722.18 µs 733.00 µs 746.43 µs]

Evaluate expression tree of depth 17 with my collapse_layers method
time:   [477.87 µs 482.54 µs 488.58 µs]
```

Evaluating a boxed expression of depth 17 takes an average 733 µs. Evaluating an expression stored in our `ExprTopo` takes an average of 482 µs. That's a 34% improvement. Running these tests with expression trees of different depths generated via the above method yields similar results. The standard boxed method is slightly faster for expression trees of size 256 or less. That said, this test provides pretty much optimal conditions with regard to pointer locality, because there are no other heap allocations to fragment things and force the boxed pointers to use different regions of memory.

# To be continued⌗

We started with a simplified non-generic version of this algorithm to build understanding. In future blog posts, I plan on showing how I made it generic, going into more detail on how I optimized it for performance (`MaybeUninit` absolutely slaps, as do stack machines), and how I used it to implement an async file tree search tool using `tokio::fs`.

# Thank you⌗

Thank you to Fiona, Rain, Eliza and Gankra, among others, for reviewing drafts of this post.

# Change notes⌗

• 07/24/2022: renamed fold/generate to collapse/expand

1. If you're not sure what I mean by cache locality, or you want much more information on it than I can provide, there's a great rust performance optimization resource here. ↩︎

2. If you're not familiar with functional languages and are now wondering what `fmap` is, it's a method provided by a trait called `Functor`. It represents the ability to map a function `A -> B` over some arbitrary structure - if we have a `Functor` instance for `F`, then we can map a function over `F<A>`, for any `A`. `F` could be an option, or a list, or a tree - any structure parameterized over some value. `map` provides an implementation of `fmap` (as in _f_unction map) that's specialized to `ExprLayer`. If you're curious, read more here. ↩︎

3. If you're really familiar with functional languages, you might point out that it's not quite `fmap`, but that's fine for our limited use case. 2 ↩︎

4. I learned this technique from my partner Rain ↩︎