Optimizing Diffusion Decoders with Depth Pruning (Using ODE Regression)

TL;DR
We delete up to half of the layers in our diffusion decoder’s transformer, then regress the pruned model onto the original (unpruned) model’s outputs, with nearly no loss in quality!
Recap On Diffusion Decoders
In a previous blog post, we discussed that the runtime of a diffusion model is, in general, governed by O(N*D*T), where N = number of tokens in a latent, D = depth of the model, and T = number of diffusion steps.
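Written out with constant factors dropped, the cost model and the speedup from changing any of the three factors are as follows. Because the factors multiply, halving D buys exactly as much as halving N or T.

```latex
\text{runtime} \propto N \cdot D \cdot T,
\qquad
\text{speedup} = \frac{N_{\text{old}}\, D_{\text{old}}\, T_{\text{old}}}{N_{\text{new}}\, D_{\text{new}}\, T_{\text{new}}}
```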
We’ve talked about optimizing N (better encoders) and T (few-step distillation) before, but what about D (shallower models)? Previously, this seemed largely untouchable, as shallower models are strictly worse, right? Well, we’ve had some results that show otherwise!
For some context, before attempting any diffusion speedups on our world model, we like to try them on our decoder first (which is also a diffusion model). We won’t formally introduce our diffusion decoder here, and will instead focus on our new pruning experiments with it. Still, it’s worth quickly going over the setup to show how barebones it is compared to a typical diffusion world model (DWM) setup:

We train our encoder to produce 8x8 latents and use them as the conditioning for a conditional diffusion transformer (DiT). The DiT takes in a 16x16 tensor of Gaussian noise shaped like a DCAE latent and diffusion-decodes it, conditioned on the 8x8 latent, into DCAE’s original 16x16 latent. We then pass this denoised latent through DCAE’s decoder to produce our RGB frame. No CFG, and kind of hierarchical! This can get a bit confusing, so to recap:
- 1: We have a teacher autoencoder, composed of a teacher encoder (to 8x8 latents) and a teacher decoder (to pixels), which we call the baseline decoder.
- 2: We have a diffusion decoder (to DCAE latents), which converts the 8x8 latents into 16x16 latents that are compatible with the DCAE decoder (to pixels). When we refer to the diffusion decoder, we mean the whole path: translating from 8x8 to 16x16 and then decoding to pixels.
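To make the roles of T and D concrete, here is a minimal sketch of the sampling loop. It assumes a flow-matching convention where x(1) is pure noise, x(0) is the clean latent, and the model predicts the velocity dx/dt; `euler_decode` and `velocity_fn` are hypothetical names, with `velocity_fn` standing in for the conditional DiT (each call would be one pass through all D layers). Latents are flattened into plain Python lists and the final DCAE decode to pixels is left out.

```python
import random
from typing import Callable, List

Vec = List[float]
# Hypothetical stand-in for the conditional DiT: (x_t, t, cond) -> predicted velocity.
VelocityFn = Callable[[Vec, float, Vec], Vec]


def euler_decode(cond: Vec, velocity_fn: VelocityFn, dim: int,
                 steps: int = 20, seed: int = 0) -> Vec:
    """Integrate from pure noise (t=1) to a clean latent (t=0) with `steps` Euler steps."""
    rng = random.Random(seed)
    x = [rng.gauss(0.0, 1.0) for _ in range(dim)]   # x(1): Gaussian noise, DCAE-latent shaped
    dt = 1.0 / steps
    for k in range(steps):
        t = 1.0 - k * dt
        v = velocity_fn(x, t, cond)                 # one full D-layer DiT forward pass
        x = [xi - dt * vi for xi, vi in zip(x, v)]  # step toward t = 0
    return x                                        # would then go to the DCAE decoder


# Toy check: for a known target c, the exact velocity on the straight path is
# (x_t - c) / t, and Euler integration then lands on c.
target = [0.5, -1.0, 2.0]
out = euler_decode(target, lambda x, t, c: [(xi - ci) / t for xi, ci in zip(x, c)], dim=3)
```

In the real decoder the conditioning is the 8x8 latent and the output is the 16x16 DCAE latent; the toy reuses one vector for both only to keep the check simple.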
So far this diffusion decoder setup has given us the best samples we’ve seen overall, though this comes with the caveat that the decoding (involving T=20 diffusion steps) is quite heavy. At present, the best speed we can decode at is just about 1 FPS, making it a no-go for our realtime world model. Here are some samples with T=20:

The samples are quite crisp given only the 8x8 latent conditioning! It’s even more staggering when you realize that this is what the reconstructions with the teacher autoencoder look like:

We are most likely going to use a further developed version of this when we create our 720p and 1080p decoders, as these hierarchical approaches seem to be the best way forward.
Speeding up Decoding - Optimizing T(imesteps):
Because there is no CFG (each step is a single forward pass, with no conditional/unconditional pair to reproduce), this diffusion decoder setup lends itself very well to simplified diffusion distillation. As an aside, the samples above come from our latest and largest diffusion decoder run, using teacher latents from our new autoencoder with 128 channels, optical flow, and depth maps. From here on, we focus on experiments done with an older teacher autoencoder that had 64 channels and only depth maps, whose baseline reconstructions were far worse:

With this in mind, take a look at the reconstructions from the diffusion decoder. Note that the exact same 8x8 latents are used and the encoder is frozen, so the diffusion decoder is a modular component that can simply be appended to the world model.

Amazingly, these samples are nearly lossless, and much better than the teacher’s decoder! The issues people have with one-step diffusion approaches generally boil down to them being far less diverse and “creative” than the base many-step model. We don’t really care about this limitation: for any given latent, there is only one “true” reconstruction, so this mode-seeking behavior is exactly what we want!
Speeding up Decoding - Optimizing D(epth)
In a previous blog post, we spoke about Distribution Matching Distillation (DMD), a technique used in the video generation paper CausVid to distill from Wan2.1 14B down to 1.3B using a combination of a teacher, a critic, and a student. We won’t go into specifics here (refer to the past blog post), but an extra detail in CausVid is that they perform ODE regression before DMD in order to, in their words, align the causal student with the teacher. In this sense, the “distillation” here is matching a causal model’s outputs to a bidirectional model’s outputs. Although CausVid applies ODE regression and DMD to video generation, both are general diffusion techniques, so we apply them to our diffusion decoder too.
ODE Regression:
In plain flow matching, you train a model on the instantaneous velocity along the flow between x(0) (the image) and x(1) (pure noise), one sample at a time. Intuitively, this learns slowly because denoising an image (or video, but let’s simplify for now) is really several subtasks rolled into one, so a single image can serve as several training examples. For example, denoising at timestep 0.9 is distinct from denoising at timestep 0.1: in the former, the image is mostly noise and nothing is really clear, so you have to generate the broad, general structure; in the latter, it’s already known what’s in the image, and all that’s left is fine details (think whiskers on a cat). As a result, a lot of computation is wasted during training, because a particular sample must be seen 10+ times for the model to reasonably learn every “subtask” for that one image.
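A minimal sketch of how one plain flow-matching training example is built, assuming the common linear (rectified-flow) path x(t) = (1 - t)·x(0) + t·x(1), whose velocity is x(1) - x(0). The exact path isn’t specified here, and `flow_matching_example` is a hypothetical helper; the point is that each draw supervises only one timestep of one image, so covering all of that image’s “subtasks” takes many draws.

```python
import random
from typing import List, Tuple

Vec = List[float]


def flow_matching_example(x0: Vec, rng: random.Random) -> Tuple[float, Vec, Vec]:
    """Return (t, x_t, target_velocity) for one image x0 at one random timestep."""
    t = rng.random()                                      # a single "subtask" per draw
    x1 = [rng.gauss(0.0, 1.0) for _ in x0]                # pure-noise endpoint
    xt = [(1.0 - t) * a + t * b for a, b in zip(x0, x1)]  # point on the straight path
    v = [b - a for a, b in zip(x0, x1)]                   # constant velocity dx/dt
    return t, xt, v


rng = random.Random(0)
image = [0.2, 0.8, -0.4]
# The same image must be drawn many times to cover both high- and low-noise timesteps.
examples = [flow_matching_example(image, rng) for _ in range(16)]
```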
Previously, we said that “image-like” data seems to lend itself best to diffusion training. We speculate that the actual technical explanation is that diffusion prefers data which neatly decomposes into distinct subtasks; that is, denoising at 0.9 should be a reasonably separable task from denoising at 0.1. One easy way to achieve this is to ensure each data instance’s features are spread across a variety of frequencies, so that different noise levels mask different frequency bands. Keeping the latent visibly “image-like” accomplishes this by ensuring all features of the image (low and high frequency) remain visible in the latent. Regressing a student onto the teacher’s intermediate scores supervises each of these subtasks directly, and we find it is a very effective approach for pruning our models as well. We detail ODE regression below:
You start from:
- 1: Pure Gaussian noise
- 2: A dataset of only prompts (in our case, latents from the teacher autoencoder)
- 3: A diffusion-based teacher decoder
You run ODE regression by:
- 1: Taking a batch of B latents paired with sampled noise, then running the teacher to denoise them completely over T steps, recording every intermediate step.
a) This gives you BT unique input-output pairs. Larger batch sizes work better, but you can subsample from these BT pairs to make training more feasible.
b) We keep 25% of the pairs and set B such that 0.25BT = 512, which is trainable for a small model without memory concerns.
- 2: Each pair consists of an input (a denoising timestep t and the partially denoised sample at that step) and an output (the teacher’s score at that step). We train our student to simply minimize the mean-squared error between its output and the teacher’s output at timestep t. In this sense, you are regressing to the teacher’s ODE.
We find that ODE regression provides a far stronger training signal than the standard flow matching objective.
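The procedure above can be sketched in a few lines. This is a toy, pure-Python illustration under stated assumptions, not the actual training code: `teacher_trajectory`, `build_ode_dataset`, `ode_regression_loss`, and `toy_teacher` are hypothetical names, models are callables mapping (x_t, t) to a velocity/score with the 8x8 conditioning assumed to be closed over, and the teacher is integrated with a plain Euler sampler.

```python
import random
from typing import Callable, List, Tuple

Vec = List[float]
Model = Callable[[Vec, float], Vec]  # hypothetical: (x_t, t) -> predicted velocity/score
Pair = Tuple[float, Vec, Vec]        # (timestep, partially denoised sample, teacher output)


def teacher_trajectory(teacher: Model, noise: Vec, steps: int) -> List[Pair]:
    """Step 1: denoise from pure noise over `steps` Euler steps, recording every step."""
    pairs: List[Pair] = []
    x, dt = list(noise), 1.0 / steps
    for k in range(steps):
        t = 1.0 - k * dt
        out = teacher(x, t)
        pairs.append((t, list(x), out))
        x = [xi - dt * oi for xi, oi in zip(x, out)]
    return pairs


def build_ode_dataset(teacher: Model, noises: List[Vec], steps: int,
                      keep_frac: float = 0.25, seed: int = 0) -> List[Pair]:
    """Collect B*T pairs, then subsample a fraction of them to keep training cheap."""
    all_pairs = [p for n in noises for p in teacher_trajectory(teacher, n, steps)]
    return random.Random(seed).sample(all_pairs, int(keep_frac * len(all_pairs)))


def ode_regression_loss(student: Model, batch: List[Pair]) -> float:
    """Step 2: mean-squared error between student and teacher outputs at each timestep."""
    sq_err, count = 0.0, 0
    for t, xt, target in batch:
        pred = student(xt, t)
        sq_err += sum((p - q) ** 2 for p, q in zip(pred, target))
        count += len(target)
    return sq_err / count


def toy_teacher(x: Vec, t: float) -> Vec:
    """Hypothetical teacher whose trajectories all end at the value 0.5."""
    return [(xi - 0.5) / t for xi in x]


# With 25% kept, 0.25 * B * T = 512 means B = 2048 / T.
rng = random.Random(1)
noises = [[rng.gauss(0.0, 1.0) for _ in range(4)] for _ in range(8)]  # B = 8
dataset = build_ode_dataset(toy_teacher, noises, steps=16)            # 8 * 16 * 0.25 = 32 pairs
```

A student that reproduces the teacher exactly gets zero loss; in practice the student would be the pruned DiT and this loss would be minimized by gradient descent.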
ODE Regression to Reduce D(epth):
Our first experiment involved randomly initializing a 16-layer DiT decoder as a student, then doing ODE regression onto a 28-layer teacher while keeping the width constant. Below are samples from both after 1k training steps. In each image, the student is on the left and the teacher is on the right.


Now for the bombshell! While the diffusion loss is finicky, ODE regression still provides a strong and learnable signal for how well a student is matching its teacher’s outputs:

In the training chart above, you can see the loss curves of a diffusion decoder with only 8 layers matching a 28-layer teacher (purple), the same setup with 16 layers (brown), and a randomly initialized 16-layer student (orange). In all the ode_tune experiments, the student’s layer weights are copied from the teacher.
We did not train any further as we wanted to move on to the next experiment, but as you can see, even at 1,000 steps the student was nearly matching the teacher’s quality.
Now you must be wondering, “How on Earth were the shallower models better?” Well, it appears that if you initialize the students as slices of the teacher’s layers (e.g. the 14-layer model is the teacher with every other layer deleted, and the 8-layer model has all but every 3rd layer deleted), they can match the teacher’s trajectories *much* more easily. But surely an 8-layer diffusion decoder must produce garbage samples? Here are some comparisons, with the student on the left and the teacher on the right.


And here are some overall samples from the student:

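The layer slicing used to initialize these students can be sketched as follows. `evenly_spaced_layers` and `init_student_from_teacher` are hypothetical helpers and layers are treated as opaque objects; in a PyTorch DiT you would apply the same indices to the teacher’s block list and wrap the copies in a new `nn.ModuleList`. With integer striding, 28 → 14 keeps every other layer, while 28 → 8 has a fractional stride of 3.5, so this sketch keeps layers at slightly uneven gaps (one reasonable choice among several).

```python
import copy
from typing import Any, List


def evenly_spaced_layers(n_teacher: int, n_student: int) -> List[int]:
    """Indices of teacher layers to keep, spread as evenly as integer division allows."""
    return [(i * n_teacher) // n_student for i in range(n_student)]


def init_student_from_teacher(teacher_layers: List[Any], n_student: int) -> List[Any]:
    """Copy the selected teacher layers (weights included) into a shallower stack."""
    keep = evenly_spaced_layers(len(teacher_layers), n_student)
    return [copy.deepcopy(teacher_layers[i]) for i in keep]


# Toy "layers": dicts standing in for transformer blocks and their weights.
teacher_layers = [{"name": f"block_{i}", "weight": float(i)} for i in range(28)]
student_14 = init_student_from_teacher(teacher_layers, 14)
print(evenly_spaced_layers(28, 14))  # every other layer: 0, 2, ..., 26
print(evenly_spaced_layers(28, 8))   # [0, 3, 7, 10, 14, 17, 21, 24]
```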
Recapping ODE Regression Improvements:
An 8-layer model is just straight up 3.5x faster than a 28-layer one, since runtime scales linearly with depth (28 / 8 = 3.5).
A caveat is that running ODE regression makes the downstream model harder to distill with DMD.
We found the 8-layer model to have middling performance with one-step generation, though in fairness it *was* better than our baseline VAE:

Settling on 14 layers, we were able to get beautiful results while achieving an overall 40x speedup compared to the original teacher diffusion decoder we started with:

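The speedup figures in this section can be checked against the O(N*D*T) cost model. A quick sketch, assuming the token count N is unchanged, runtime is exactly proportional to D·T, and the final 14-layer decoder samples in a single step (an assumption that reproduces the 40x figure; real wall-clock speedups also depend on per-layer and per-step overheads). `relative_cost` is a hypothetical helper.

```python
def relative_cost(depth: int, steps: int, tokens: int = 1) -> int:
    """Cost model from the recap: runtime ~ N * D * T (constant factors dropped)."""
    return tokens * depth * steps


teacher_cost = relative_cost(depth=28, steps=20)            # original decoder, T = 20
pruned_8 = teacher_cost / relative_cost(depth=8, steps=20)  # depth pruning alone
pruned_14_one_step = teacher_cost / relative_cost(depth=14, steps=1)
print(pruned_8, pruned_14_one_step)  # 3.5 40.0
```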
What’s next for Wayfarer Labs:
We are in the process of scaling up our data pipelines and are looking for experienced folks that can help us!
We are also excited to share more about some of our 3D Consistency experiments, and our upcoming latent-action model findings :) Stay tuned!
Moving ODE regression from the diffusion decoder to the diffusion world model is not so straightforward. We will be exploring more in that direction soon as well!
If you'd like to help with any of these projects or join in on the lively discussions that actively take place on our Discord server, please join here! While you can join as a volunteer whenever you'd like, we are also taking applications for full-time positions here.