Making zears Fast

Hola! I enjoy implementing algorithms in Rust, and AEZ came up naturally because I needed it for other things. And while there is a crate that provides AEZ in Rust, it is “only” a wrapper around the C reference implementation, and comes with some limitations. Therefore, I wanted to create a pure-Rust implementation – even if it’s not useful, it still serves as a learning exercise.

While the basic functionality was straightforward to implement, I also wanted to achieve good speed. The reference implementation is quite optimized, so the bar is pretty high.

Benchmarks and profilers

The key to optimizing our algorithm is to not fly blind. This means two things:

  • We need benchmarks. Those allow us to assess how fast our implementation currently is, and whether a change that we make improves or worsens the performance.
  • We need a profiler. This allows us to see where the bottleneck in our code lies.

It is especially useful to add the benchmarks early, so that you can later go back via git and understand how the performance changed over time.

To set things up for zears, we do the following:

  • We use criterion to implement our benchmarks. This also allows us to use critcmp in order to quickly compare different benchmark results (a minimal benchmark sketch follows below).
  • We use samply as a profiler. I’ve recently discovered it, and it seems like a nice UX improvement over plain old perf.
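
As a rough idea, a criterion benchmark for the encryption path can look like the sketch below. Note that the zears API used here (Aez::new, encrypt) is only an assumption for illustration, not necessarily the crate's real interface:

use criterion::{criterion_group, criterion_main, Criterion, Throughput};

fn bench_encrypt(c: &mut Criterion) {
    // Hypothetical API: a 48-byte key and an encrypt(nonce, ad, data) call.
    let aez = zears::Aez::new(&[0u8; 48]);
    let data = vec![0u8; 1 << 20];

    let mut group = c.benchmark_group("encrypt");
    // Report throughput (MiB/s, GiB/s) instead of just time per iteration.
    group.throughput(Throughput::Bytes(data.len() as u64));
    group.bench_function("1 MiB", |b| b.iter(|| aez.encrypt(&[], &[], &data)));
    group.finish();
}

criterion_group!(benches, bench_encrypt);
criterion_main!(benches);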

Finally, it's nice to set up a profiling compilation profile, which enables optimizations but keeps debug information. For that, we append to our Cargo.toml:

[profile.profiling]
inherits = "release"
debug = true
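
With that in place, samply can record any binary built with the new profile. The example name below (encrypt_large) is just a placeholder for whatever you want to profile:

cargo build --profile=profiling --examples
samply record ./target/profiling/examples/encrypt_large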

With this, we can see the impact of changes:

cargo bench --bench=zears -- --save-baseline=base
# <... apply changes ...>
cargo bench --bench=zears -- --save-baseline=unroll-loop
critcmp base unroll-loop

Enable AES inlining

The AEZ algorithm relies heavily on the use of AES, or rather, the AES “round” primitive. Many modern CPUs support this via some sort of AESENC instruction, which provides a huge speed-up compared to software implementations. Luckily, the aes crate not only provides a software implementation, but also makes use of hardware acceleration if it is available.

Yet, in commit d77f4ad, I’ve manually added hardware AES use to zears for a huge speedup (from 187 MiB/s to 522 MiB/s). What gives?

First, I thought that the speedup came from not having to load the keys into a __m128i every iteration. That, however, does not account for the huge difference.

Then, as I had written in a comment in 5c192ad, I thought it was because the optimizer generates code with vectorized AES instructions, e.g. by using _mm256_aesenc_epi128 to handle two blocks of data at once. I was misled by the assembly instruction being called VAESENC.

The actual reason the reimplementation was so much faster is that the compiler does not, by default, inline across crate boundaries. Therefore, every call to aes::hazmat::cipher_round caused an actual function call (and it is called a lot!). With the AES instructions in the same crate, the compiler happily inlines them.
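
For a rough idea of what the in-crate version boils down to on x86-64, here is a sketch of a single AES round using the hardware intrinsic (this is illustrative, not zears’ exact code, and it requires the aes target feature):

#[cfg(target_arch = "x86_64")]
use core::arch::x86_64::{__m128i, _mm_aesenc_si128, _mm_loadu_si128, _mm_storeu_si128};

// One AES round on a 16-byte block: ShiftRows, SubBytes, MixColumns,
// then XOR with the round key, all done by a single AESENC instruction.
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "aes")]
unsafe fn aes_round(block: &mut [u8; 16], round_key: &[u8; 16]) {
    let state = _mm_loadu_si128(block.as_ptr() as *const __m128i);
    let key = _mm_loadu_si128(round_key.as_ptr() as *const __m128i);
    _mm_storeu_si128(block.as_mut_ptr() as *mut __m128i, _mm_aesenc_si128(state, key));
}

Because this function lives in the same crate, the compiler can inline it at every call site.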

We can confirm that by compiling with lto = true, which allows for cross-crate inlining. And indeed, this produces code that is just as fast!
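
For reference, that experiment is a one-line change in Cargo.toml:

[profile.release]
lto = true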

SIMD

SIMD (single instruction, multiple data) describes instructions that operate on multiple values at the same time. For AEZ, this is great, because we often deal with 16 bytes of data (a block) at once. By leveraging SIMD, we can compute the xor of two blocks with a single instruction, instead of needing a xor for each byte.

Unfortunately, SIMD generally requires you to delve into architecture-specific code. Luckily, Rust has an abstraction in the standard library, although it is currently feature gated to nightly compilers.

With std::simd, we can represent a block as a u8x16, and compute the xor between two blocks easily as a ^ b – which compiles to efficient SIMD instructions.
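
A minimal sketch of what this looks like (the real Block type in zears may carry more than what is shown here):

#![feature(portable_simd)]
use std::simd::u8x16;

#[derive(Clone, Copy)]
struct Block(u8x16);

impl std::ops::BitXor for Block {
    type Output = Block;

    fn bitxor(self, rhs: Block) -> Block {
        // A single 128-bit XOR instead of 16 byte-wise XORs.
        Block(self.0 ^ rhs.0)
    }
}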

This gives another huge speedup, going from 1.7 GiB/s (achieved via various other small improvements) to 2.9 GiB/s.

Avoiding duplicated work

This seems like a bit of a no-brainer, but if you avoid doing (unnecessary) work, your algorithm will run faster! The hard part is to identify which work is unnecessary.

In zears, there are multiple instances of this, each one identified manually:

We can compute the value of e(j, i + 1) easily from e(j, i), which saves re-doing the exponentiation of the key from zero. This was implemented early, in commit 60fb0c6, speeding up the encryption from 88 MiB/s to 140 MiB/s.
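
The shape of the fix is an incremental evaluator that keeps the running value instead of recomputing it from scratch. The sketch below is hypothetical (the actual update step depends on how e() is defined in zears), but it conveys the idea:

// Hypothetical sketch: keep e(j, i) around and derive e(j, i + 1) from it
// with a single cheap update, instead of re-doing all i steps each time.
struct IncrementalE {
    current: Block,
}

impl IncrementalE {
    fn advance(&mut self) {
        // One doubling in GF(2^128) stands in here for whatever single
        // step takes e(j, i) to e(j, i + 1).
        self.current = self.current.mul(2);
    }
}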

Similarly, we see that we only ever need 8 multiples of key_l, so we can pre-compute those 8 values once. This was done in 0009a24, going from 594 MiB/s to 723 MiB/s.
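
In code, this is just a small lookup table filled at key-setup time, along these lines (the exact Block::mul signature is an assumption):

// Hypothetical sketch: compute the eight needed multiples of key_l once,
// then index into the table instead of multiplying per block.
let key_l_multiples: [Block; 8] = core::array::from_fn(|i| key_l.mul(i as u32));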

An optimization that was a bit harder to find is 79e1a87. Here, we see that the value used in e() is actually constant, and only the other terms change. However, the subterm ki_p_i is also the same for 8 consecutive values! Therefore, by computing 8 values at once, we can avoid re-computing the pre_xored value 8 times.

The last change actually provides very little speedup on its own, as we trade a (cheap) xor operation for potentially more branch mispredictions and more complex logic for the optimizer to deal with. However, it opens up more follow-up optimizations.

Loop unrolling

Loop unrolling is a technique where multiple iterations of a loop are written out “by hand”, in order to gain efficiency by eliminating loop-control overhead and reducing branch mispredictions.

Together with the previous optimization, unrolling the loop in pass_two provides a speed increase from 3.6 GiB/s to 4.0 GiB/s. Interestingly, unrolling the loop in pass_one does not help, but actually decreases speed.

Unfortunately, doing the loop unrolling requires quite a bit of extra code: We now need an iterator to go over 8 pairs of blocks at once, we need to duplicate the code for the unrolled iterations, and we still need to handle the case where the number of blocks is not evenly divisible by 8.

The second part is nicely achieved with a macro:

macro_rules! unroll_pairs {
    (
        $accessor:expr;
        setup_unrolled => $setup_unrolled:block;
        setup_single => $setup_single:block;
        roll ($left:ident, $right:ident) => $roll:block;
    ) => {
        // Main loop: process 8 (left, right) pairs per iteration.
        for (left, right) in $accessor.pairs_8_mut() {
            $setup_unrolled;

            let [l0, l1, l2, l3, l4, l5, l6, l7] = left;
            let [r0, r1, r2, r3, r4, r5, r6, r7] = right;

            let $left = l0;
            let $right = r0;
            $roll;
            let $left = l1;
            let $right = r1;
            $roll;
            let $left = l2;
            let $right = r2;
            $roll;
            let $left = l3;
            let $right = r3;
            $roll;
            let $left = l4;
            let $right = r4;
            $roll;
            let $left = l5;
            let $right = r5;
            $roll;
            let $left = l6;
            let $right = r6;
            $roll;
            let $left = l7;
            let $right = r7;
            $roll;
        }

        // Remainder loop: handle the pairs left over when the block count is not a multiple of 8.
        for (left, right) in $accessor.suffix_8_mut() {
            let $left = left;
            let $right = right;
            $setup_single;
            $roll;
        }
    };
}

// ... later:

unroll_pairs! { blocks;
    setup_unrolled => {
        evals_for_s.refill();
    };
    setup_single => {
        if evals_for_s.len == 0 {
            evals_for_s.refill();
        }
    };
    roll (raw_wi, raw_xi) => {
        e1_eval.advance();
        let wi = Block::from(*raw_wi);
        let xi = Block::from(*raw_xi);
        let yi = wi ^ evals_for_s.blocks[8 - evals_for_s.len];
        let zi = xi ^ evals_for_s.blocks[8 - evals_for_s.len];
        let ci_ = yi ^ e0_eval.eval(zi);
        let ci = zi ^ e1_eval.eval(ci_);

        ci.write_to(raw_wi);
        ci_.write_to(raw_xi);

        y = y ^ yi;
        evals_for_s.len -= 1;
    };
}

Make common cases fast

When we look at Block::mul, we see that it is implemented for arbitrary factors, yet in practice we almost always multiply by 2. Similarly, Block::shl works for any shift amount, but it is very often called with an amount of 1 – so we can optimize for this case:

fn shl(self, rhs: u32) -> Block {
    // We often use a shift by one, for example in the multiplication. We therefore optimize
    // for this special case.
    #[cfg(feature = "simd")]
    {
        if rhs == 1 {
            return Block((self.0 << 1) | (self.0.shift_elements_left::<1>(0) >> 7));
        }
    }
    #[cfg(not(feature = "simd"))]
    {
        if rhs == 1 {
            return Block([
                (self.0[0] << 1) | (self.0[1] >> 7),
                (self.0[1] << 1) | (self.0[2] >> 7),
                (self.0[2] << 1) | (self.0[3] >> 7),
                (self.0[3] << 1) | (self.0[4] >> 7),
                (self.0[4] << 1) | (self.0[5] >> 7),
                (self.0[5] << 1) | (self.0[6] >> 7),
                (self.0[6] << 1) | (self.0[7] >> 7),
                (self.0[7] << 1) | (self.0[8] >> 7),
                (self.0[8] << 1) | (self.0[9] >> 7),
                (self.0[9] << 1) | (self.0[10] >> 7),
                (self.0[10] << 1) | (self.0[11] >> 7),
                (self.0[11] << 1) | (self.0[12] >> 7),
                (self.0[12] << 1) | (self.0[13] >> 7),
                (self.0[13] << 1) | (self.0[14] >> 7),
                (self.0[14] << 1) | (self.0[15] >> 7),
                (self.0[15] << 1),
            ]);
        }
    }
    Block::from(self.to_int() << rhs)
}

Similarly, we sometimes need to count up an index stored in a block. Instead of converting from an integer every time, we count up the last byte manually; only once every 256 iterations do we need to touch the previous byte:

macro_rules! add_ladder {
    ($ar:expr, $lit:literal) => {
        $ar[$lit] = $ar[$lit].wrapping_add(1);
    };
    ($ar:expr, $lit:literal $($rest:literal) +) => {
        $ar[$lit] = $ar[$lit].wrapping_add(1);
        if $ar[$lit] == 0 {
            add_ladder!($ar, $($rest) +);
        }
    };
}

pub fn count_up(&mut self) {
    add_ladder!(self, 15 14 13 12 11 10 9 8 7 6 5 4 3 2 1 0);
}

Conclusion

The tricks here aren’t new by any means, but it was fun (and sometimes tedious) to chase the performance of the reference C implementation. Seeing changes pay off and watching the performance of zears creep higher and higher is rewarding!

Flags                      Speed
base                       523 MiB/s
+simd                      890 MiB/s
target-cpu=native          1745 MiB/s
+simd,target-cpu=native    3.6 GiB/s
aez                        5.0 GiB/s