How to improve the acceptance rate in speculative decoding?

By Mirai team

Decoding with an LLM requires reading large amounts of data from memory (the model weights and the KV cache) for every generated token, while doing relatively little compute with that data. As a result, the compute units sit significantly underutilized.
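A back-of-the-envelope sketch of that imbalance: multiplying a weight matrix by one activation vector uses each weight for a single multiply-add, i.e. 2 FLOPs per weight. The 4096 x 4096 shape and 2-byte (fp16) weights below are illustrative assumptions, and activation and KV-cache traffic are ignored.

```python
def weight_arithmetic_intensity(d_in: int, d_out: int, n_tokens: int,
                                bytes_per_weight: int = 2) -> float:
    """FLOPs performed per byte of weights read when a d_out x d_in matrix
    is applied to n_tokens activation vectors in one pass."""
    flops = 2 * d_in * d_out * n_tokens             # one multiply + one add per weight, per token
    weight_bytes = d_in * d_out * bytes_per_weight  # weights are read once per pass
    return flops / weight_bytes


# Hypothetical 4096 x 4096 fp16 projection
for n in (1, 2, 4):
    print(f"{n} token(s): {weight_arithmetic_intensity(4096, 4096, n):.1f} FLOPs/byte")
# 1 token(s): 1.0 FLOPs/byte
# 2 token(s): 2.0 FLOPs/byte
# 4 token(s): 4.0 FLOPs/byte
```

Each extra token processed in the same pass adds compute without adding weight traffic, which is the slack speculative decoding puts to use.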

uzu (https://github.com/trymirai/uzu) can turn that idle compute into higher decoding throughput. It cheaply guesses the next token, then, in a single forward pass, computes the next token for both the current sequence and the current sequence plus the guessed token. If the model's next token matches the guess, both tokens are accepted at once, doubling throughput for that step. This is the simplest form of speculative decoding.
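A minimal sketch of one such step, assuming greedy (deterministic) models represented as plain Python functions. `speculative_step`, `target`, and `draft` are hypothetical names for illustration, not part of uzu's API, and a real engine would evaluate both target positions in one batched forward pass rather than in two calls.

```python
from typing import Callable, List

Model = Callable[[List[int]], int]  # maps a token sequence to its next token (greedy)


def speculative_step(seq: List[int], target: Model, draft: Model) -> List[int]:
    """One step of single-token speculative decoding with greedy verification.

    Returns the tokens appended this step: two if the guess was right, else one.
    """
    guess = draft(seq)  # cheap guess for the next token
    # Hypothetical stand-in for one batched forward pass over
    # seq and seq + [guess]; done as two calls here for clarity.
    next_tok = target(seq)
    bonus = target(seq + [guess])
    if next_tok == guess:
        return [next_tok, bonus]  # guess confirmed: keep the bonus token too
    return [next_tok]  # guess rejected: discard the bonus token


def count_up(seq: List[int]) -> int:
    """Toy "LLM": the next token is the previous token plus one."""
    return seq[-1] + 1


print(speculative_step([1, 2], target=count_up, draft=count_up))          # [3, 4]
print(speculative_step([1, 2], target=count_up, draft=lambda s: 0))       # [3]
```

With sampling instead of greedy decoding, the "does the guess match" check becomes a question of how the two random draws are related, which is what the rest of this post is about.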

The simplest way to improve on this starts from the observation that the LLM is often unsure what the next token should be.

Let's take an example:

"The random number from 1 to 10 inclusive is"

An idealized LLM assigns a uniform probability of 1/10 to each number from 1 to 10, and a good speculator would output the same distribution. Yet the naive scheme draws one random number from the speculator and, independently, another from the LLM, so it can accept the bonus token only in the 1/10 of cases where the two happen to match.
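A quick Monte Carlo sketch of that 1-in-10 figure. In general, two independent samples from distributions p and q agree with probability sum_i p_i q_i, which for two uniform distributions over 10 tokens is 10 x (1/10)^2 = 1/10. The function name and trial count are illustrative choices.

```python
import random


def naive_acceptance_rate(draft_probs, target_probs, trials=100_000, seed=0):
    """Estimate how often independent samples from the speculator and the LLM agree."""
    rng = random.Random(seed)
    tokens = range(len(target_probs))
    hits = 0
    for _ in range(trials):
        guess = rng.choices(tokens, weights=draft_probs)[0]    # speculator's sample
        actual = rng.choices(tokens, weights=target_probs)[0]  # LLM's independent sample
        hits += guess == actual
    return hits / trials


uniform = [0.1] * 10  # numbers 1..10, each with probability 1/10
print(naive_acceptance_rate(uniform, uniform))  # roughly 0.1
```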

But here the speculator exactly predicted the LLM's distribution.

We should be able to accept its guess every time, without any loss of generation quality.

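We achieve this by sampling from both the speculator and the LLM with the Gumbel-max trick, using the same seed for both. The Gumbel-max trick draws a token from a distribution by adding independent Gumbel(0, 1) noise to each token's log-probability and taking the argmax. When both models share the same noise, identical distributions always pick the same token, and similar distributions usually do. The LLM's token is still an exact sample from the LLM's own distribution, so accepting the guess whenever it matches costs nothing in quality.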
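A minimal Python sketch of this coupling, assuming that "sharing the same seed" amounts to both models seeing the same Gumbel noise vector at a given position. How uzu actually derives noise from its seed, and its GPU kernel, are not shown here; all function names are hypothetical.

```python
import math
import random


def gumbel_noise(n, rng):
    """n i.i.d. Gumbel(0, 1) samples via the inverse CDF: -log(-log(U))."""
    noise = []
    while len(noise) < n:
        u = rng.random()  # uniform on [0, 1)
        if u > 0.0:       # skip the (vanishingly rare) u == 0, where log is undefined
            noise.append(-math.log(-math.log(u)))
    return noise


def gumbel_max_sample(probs, noise):
    """Gumbel-max trick: argmax_i (log p_i + g_i) is a sample from probs."""
    scores = [math.log(p) + g if p > 0 else -math.inf for p, g in zip(probs, noise)]
    return max(range(len(scores)), key=scores.__getitem__)


def coupled_acceptance_rate(draft_probs, target_probs, trials=20_000, seed=0):
    """Fraction of steps where speculator and LLM pick the same token
    when both sample with the same Gumbel noise."""
    rng = random.Random(seed)
    hits = 0
    for _ in range(trials):
        noise = gumbel_noise(len(target_probs), rng)     # shared "seed" for this step
        guess = gumbel_max_sample(draft_probs, noise)    # speculator's token
        actual = gumbel_max_sample(target_probs, noise)  # LLM's token: exact sample from target_probs
        hits += guess == actual
    return hits / trials


uniform = [0.1] * 10
print(coupled_acceptance_rate(uniform, uniform))  # 1.0
```

Sharing one noise vector per position is what lets the two models agree without communicating; generating that vector from (seed, position) is one plausible way to do it, assumed here rather than taken from uzu.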

This achieves a 100% acceptance rate in the toy example above and a near-optimal acceptance rate in more complicated real-world cases, while being much simpler than the rejection-sampling-style algorithm that most other inference engines use.
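To put "near-optimal" into numbers, the sketch below computes exact single-token acceptance probabilities for three cases: independent sampling (sum_i p_i q_i), Gumbel-max sampling with shared noise, and the ceiling sum_i min(p_i, q_i) = 1 - TV(p, q), which is the highest probability with which any coupling of two samplers can make them agree. The closed form for the Gumbel case follows from the Gumbel-max trick itself; the two distributions are made-up examples, not measurements from uzu.

```python
def naive_acceptance(p, q):
    """Independent sampling: the tokens agree with probability sum_i p_i q_i."""
    return sum(pi * qi for pi, qi in zip(p, q))


def best_possible_acceptance(p, q):
    """Ceiling for any coupling: sum_i min(p_i, q_i) = 1 - TV(p, q)."""
    return sum(min(pi, qi) for pi, qi in zip(p, q))


def gumbel_acceptance(p, q):
    """Exact agreement probability of Gumbel-max sampling with shared noise.

    Both argmaxes land on token i iff i wins a single Gumbel-max race in which
    every rival j has relative weight max(p_j / p_i, q_j / q_i), so that event
    has probability 1 / sum_j max(p_j / p_i, q_j / q_i).
    """
    total = 0.0
    for pi, qi in zip(p, q):
        if pi == 0 or qi == 0:
            continue  # token i can never be chosen by both models
        weights = sum(max(pj / pi, qj / qi) for pj, qj in zip(p, q))
        total += 1.0 / weights
    return total


speculator = [0.5, 0.3, 0.2]  # hypothetical draft distribution
llm = [0.4, 0.4, 0.2]         # hypothetical target distribution
for name, fn in [("naive", naive_acceptance),
                 ("gumbel", gumbel_acceptance),
                 ("best", best_possible_acceptance)]:
    print(name, round(fn(speculator, llm), 4))
# naive 0.36
# gumbel 0.8818
# best 0.9
```

With these inputs the shared-noise scheme accepts about 88% of guesses against a ceiling of 90%, while independent sampling manages 36%.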

Check out our Gumbel kernel implementation.

Deploy and run models of any architecture directly on Apple devices.

On-device layer for AI model makers & products.
