Autoencoders for Diffusion: A Deep Dive

Another Week Another Blog
As we close out our third week, we've made significant progress on Autoencoder and Video Diffusion training. We are currently post-training the autoencoder while concurrently training the video model on the latents it produces. In this blog post we'll summarize some of our findings with autoencoders, as well as some null results from a slightly unconventional approach we tried.
TL;DR
We compressed 720x1280 images down to 10x16 latents while keeping them easy for a diffusion model to learn, and we're going to push further to 4x4 soon. We resolved the numerical issues that other methods run into, keeping training fast and stable. We found that compressing to 1D with a transformer, i.e. just 16 flat vectors, gives good reconstructions that nonetheless cannot be diffused. Want to train your own VAEs with our code? See here.

What's An Autoencoder?
Video game frames are really big. Ideally we want to do 720p. At 720x1280 resolution, assuming a reasonable patch size of 8x8, our World Model would be seeing 14,400 tokens per frame. That is: nearly 1M tokens for 1 second at 60fps. Yikes. To resolve this, we use an autoencoder. If you're reading this far we're going to assume you know what an autoencoder is, but as a TL;DR it's a model that encodes big images into small images, then decodes those small images back into big images. The blockbuster paper on modern image generation brought autoencoders into the spotlight with the finding that, so long as you compress an image down to a small latent, diffusion models are really good at making nice images from it. But as research has progressed, several questions have arisen. How small can you make this latent? How many channels should it have? What makes one autoencoder good for diffusion models and another bad? We're going to do our best to answer these questions in this blog post. But we need to do this step by step. First let's talk about GANs.
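To make the motivation concrete, here's the back-of-the-envelope arithmetic (the 8x8 patch size and 60fps are the numbers from above; treating each latent pixel as one token is an assumption for illustration):

```python
# Back-of-the-envelope token counts for a 720x1280 world model.
H, W = 720, 1280
patch = 8                       # assumed world-model patch size
fps = 60

tokens_per_frame = (H // patch) * (W // patch)        # 90 * 160 = 14,400
tokens_per_second = tokens_per_frame * fps            # 864,000 -> nearly 1M

# With an autoencoder mapping 720x1280 -> 10x16 latents (1 latent pixel = 1 token):
latent_tokens_per_frame = 10 * 16                     # 160
latent_tokens_per_second = latent_tokens_per_frame * fps  # 9,600

print(tokens_per_frame, tokens_per_second)                  # 14400 864000
print(latent_tokens_per_frame, latent_tokens_per_second)    # 160 9600
```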
GANs And The Modern Autoencoder
Don't we use diffusion now? Why are GANs relevant? Check out these samples from our 720x1280 -> 10x16 autoencoder:

Observe how the reconstructions have a very obvious "blurring" going on. This effect becomes more and more pronounced as you push the latents smaller and smaller. MSE tends to focus too much on the precise positions of things in images rather than their content, so autoencoders naturally struggle with high-frequency 'fine' details. Perceptual losses (we use a ConvNext LPIPS) help a little, but as you can see the blurriness remains. To address this, all modern autoencoders (starting with VQGAN) include an adversarial training phase. The idea is that the GAN loss encourages the autoencoder to plausibly generate details it can't afford to store in its latent. Consider the example of encoding an image of a tree. All trees have leaves and bark, which generally share the same texture/structure, so encoding fine-grained details about the bark or leaves wastes valuable bandwidth in the latent space. The DCAE paper shows a very good example of this:

Given that the GAN loss is basically telling the model to generate specific details rather than encode them, and we want to be frugal with latent bandwidth, we freeze the encoder during adversarial training in all our GAN training runs. To stay up to date with the literature we follow a scheme similar to R3GAN, but we will give more details on this shortly.
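To illustrate the idea (this is a minimal sketch, not our exact training code; `encoder`, `decoder`, `discriminator`, and `lambda_adv` are stand-ins for whatever modules and weights you use), the adversarial phase looks roughly like this:

```python
import torch
import torch.nn.functional as F

def generator_step(encoder, decoder, discriminator, x, lambda_adv=0.1):
    """One decoder update in the adversarial phase. The encoder is frozen so
    GAN-specific detail cannot leak back into the latent space."""
    with torch.no_grad():                           # frozen encoder
        z = encoder(x)
    x_hat = decoder(z)
    recon = F.mse_loss(x_hat, x)                    # plus a perceptual term in practice
    adv = F.softplus(-discriminator(x_hat)).mean()  # non-saturating generator loss
    return recon + lambda_adv * adv

def discriminator_step(encoder, decoder, discriminator, x):
    """One discriminator update; R1/R2 penalties (discussed later) would be added here."""
    with torch.no_grad():
        x_hat = decoder(encoder(x))
    return (F.softplus(-discriminator(x)) + F.softplus(discriminator(x_hat))).mean()
```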
Precision Errors In Your VAE? More Likely Than You Think
Floating point errors get worse with larger numbers:

As such, more constrained activations are naturally desirable if you want to run at lower precision for training/inference speed while preserving stability and not getting NaN'd into the moon. Don't believe in the danger of NaNs? Check out these activations (absolute values) from the last decoder layer in the SDXL VAE:

Now this is baaaaad. At FP32 the image decodes fine, but if you try to cast the base model to FP16 and run a forward pass you will just straight up get NaNs instead of an image. This problem of exploding activations still exists in other models. Intriguingly, a big part of the DCAE paper is that they separated out the adversarial training, citing stability concerns. I see the same issue in the feature maps from DCAE's decoder:

It's far less severe, but it's still not ideal given what we know about floating point rounding errors. As for their stability concerns: when I tried finetuning DCAE's decoder with an adversarial term, it just NaN'd immediately. One point worth bringing up is that it seems to be common knowledge that BF16 is unstable for adversarial training, yet the majority of VAE papers make no mention of precision at all. To circumvent instability, many fall back to FP32, reducing training throughput significantly. To us, this isn't a good enough solution. We need more speed!
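A cheap way to spot this kind of trouble before it bites (FP16 overflows around 65,504; BF16 survives the range but loses precision on huge values) is to hook every module and record peak activation magnitudes. A minimal sketch, not tied to any particular VAE implementation:

```python
import torch

def max_abs_activations(model, x):
    """Run one forward pass and record the peak |activation| per module.
    Values approaching float16's max (~65504) are a red flag for low-precision inference."""
    peaks, handles = {}, []

    def make_hook(name):
        def hook(_module, _inputs, output):
            if torch.is_tensor(output):
                peaks[name] = max(peaks.get(name, 0.0), output.abs().max().item())
        return hook

    for name, module in model.named_modules():
        handles.append(module.register_forward_hook(make_hook(name)))
    with torch.no_grad():
        model(x)
    for h in handles:
        h.remove()
    return dict(sorted(peaks.items(), key=lambda kv: -kv[1]))

# usage: peaks = max_abs_activations(vae.decoder, latents); print(list(peaks.items())[:5])
```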
Following the recipe in R3GAN, we use Fix-Up initialization in the residual blocks within our autoencoder. This means scaling the convolution weights by a factor that shrinks with the overall number of residual blocks, and initializing the output convolution weights to zero. Effectively, the model starts out as a sort of "identity", with the residual stream passing through every layer untouched. Of course, when you're downsampling/upsampling you can't really leave the residual stream untouched, but pixel shuffle and light unbiased 1x1 convolutions can approximate this, so we take the space-to-channel and channel-to-space operations from DCAE and integrate them into our custom architecture. If you're interested in implementation details, I will leave a link to our residual block implementation here. The end result of this was far more constrained activations in our decoder:

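For reference, here is a stripped-down sketch of the idea (our actual block lives in the repo linked above; channel counts and the two-conv layout are illustrative):

```python
import torch
import torch.nn as nn

class FixupResBlock(nn.Module):
    """Sketch of a Fix-Up-initialized residual block. With the last conv zeroed,
    the block starts out as an identity map, which keeps activations constrained
    early in training."""
    def __init__(self, channels: int, num_blocks: int):
        super().__init__()
        self.conv1 = nn.Conv2d(channels, channels, 3, padding=1, bias=False)
        self.conv2 = nn.Conv2d(channels, channels, 3, padding=1, bias=False)
        self.act = nn.SiLU()

        # Fix-Up scales the inner weights by L^(-1/(2m-2)); with m=2 convs this is L^-0.5.
        nn.init.kaiming_normal_(self.conv1.weight)
        self.conv1.weight.data.mul_(num_blocks ** -0.5)
        # Zero-init the output conv so the residual branch contributes nothing at init.
        nn.init.zeros_(self.conv2.weight)

    def forward(self, x):
        return x + self.conv2(self.act(self.conv1(self.act(x))))
```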
Additionally, check out these smooth GAN training graphs:

No NaNs at BF16! Though I will say it seems the discriminator won and training collapsed. To alleviate this we added the R1 and R2 penalties from R3GAN and that seems to have stabilized training:

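For the curious: R1 penalizes the discriminator's gradient norm on real samples and R2 does the same on generated samples. A minimal sketch of the zero-centered penalty (the `gamma` coefficient is a placeholder, not our tuned value):

```python
import torch

def gradient_penalty(discriminator, x, gamma=1.0):
    """Zero-centered gradient penalty: gamma/2 * E[ ||grad_x D(x)||^2 ].
    Applied to real samples this is R1; applied to fake samples it is R2."""
    x = x.detach().requires_grad_(True)
    scores = discriminator(x)
    grads, = torch.autograd.grad(outputs=scores.sum(), inputs=x, create_graph=True)
    return (gamma / 2) * grads.pow(2).flatten(1).sum(dim=1).mean()

# inside the discriminator step:
# d_loss = adv_loss + gradient_penalty(D, real_images) + gradient_penalty(D, fake_images)
```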
Diffusability and Smoothness
Creating a good autoencoder to plug into diffusion models isn't just about reconstruction quality - the latent space also needs to be "diffusable", meaning it should be easy for diffusion models to learn and generate from. The properties of your latent space directly impact how well downstream generative models can work with your representations: if your latents are too noisy, have poor spectral characteristics, or contain artifacts from the encoding process, diffusion models will struggle to learn meaningful patterns and generate coherent samples. This is why we focus heavily on ensuring our latents have the right properties for diffusion-based generation!
One key aspect of our approach involved implementing spectral equivariance through careful downsampling strategies, following insights from the diffusability of autoencoders paper. This technique helps preserve the spectral characteristics that make latents easier for diffusion models to work with.
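One way to eyeball this is to compare the radially averaged power spectrum of the latents against that of the source images; a rough sketch of that diagnostic (not the exact analysis from the paper):

```python
import torch

def radial_power_spectrum(x, num_bins=64):
    """Radially averaged log power spectrum of a batch shaped (B, C, H, W).
    Useful for checking whether a latent space keeps an 'image-like'
    distribution of frequencies."""
    spec = torch.fft.fftshift(torch.fft.fft2(x.float()), dim=(-2, -1)).abs().pow(2)
    B, C, H, W = spec.shape
    yy, xx = torch.meshgrid(
        torch.arange(H) - H // 2, torch.arange(W) - W // 2, indexing="ij"
    )
    radius = (yy.float().pow(2) + xx.float().pow(2)).sqrt()
    bins = (radius / radius.max() * (num_bins - 1)).long().flatten()
    flat = spec.mean(dim=(0, 1)).flatten()
    profile = torch.zeros(num_bins).scatter_add_(0, bins, flat)
    counts = torch.zeros(num_bins).scatter_add_(0, bins, torch.ones_like(flat))
    return (profile / counts.clamp(min=1)).log()

# compare: radial_power_spectrum(images) vs radial_power_spectrum(latents)
```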
Interestingly, contrary to some literature suggesting that increasing the channel count makes autoencoders less diffusable, we found that a channel count of 128 works quite well in our experiments. The key insight from our PCA analysis is that latents should maintain spectral characteristics similar to the original images - they should preserve the distribution of frequencies that makes them "image-like" - see below!

PCA Map for SDXL's encoder

PCA map for our own encoder

PCA map for DCAE's encoder
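The PCA maps above come from projecting each latent's channels onto its top three principal components and rendering them as RGB; roughly speaking (this is a sketch, not our exact plotting code):

```python
import torch

def pca_rgb(latent):
    """Project a (C, H, W) latent onto its top-3 principal components and
    normalize to [0, 1] so it can be viewed as an RGB image."""
    C, H, W = latent.shape
    flat = latent.reshape(C, -1).T.float()          # (H*W, C): pixels as samples
    flat = flat - flat.mean(dim=0, keepdim=True)
    _, _, v = torch.pca_lowrank(flat, q=3)
    proj = flat @ v[:, :3]                          # (H*W, 3)
    proj = (proj - proj.min(0).values) / (proj.max(0).values - proj.min(0).values + 1e-8)
    return proj.T.reshape(3, H, W)
```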
Intuitively, we find that as you compress images to smaller latent sizes, the generative task starts to dominate more since the latent space becomes increasingly constrained. This means the model needs to hallucinate more details during reconstruction. Part of our approach involves being careful not to let GAN-specific details leak into the latent space, as this could make the representations less image-like and harder for diffusion models to work with.
It's also worth noting that KL divergence penalties are less commonly used in modern approaches; instead, many researchers now simply add noise to the decoder's inputs during training. We decided against using either approach, figuring that L2 regularization during decoder fine-tuning and post-training should accomplish similar goals without the added complexity.
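Concretely, the regularizer we have in mind is just a penalty on latent magnitude added to the reconstruction objective; a minimal sketch, with `lambda_l2` as a placeholder coefficient rather than a tuned value:

```python
import torch

def vae_loss_with_l2(encoder, decoder, x, lambda_l2=1e-4):
    """Reconstruction loss plus a simple L2 penalty on latent magnitude,
    standing in for a KL term or decoder-input noise."""
    z = encoder(x)
    x_hat = decoder(z)
    recon = torch.nn.functional.mse_loss(x_hat, x)
    return recon + lambda_l2 * z.pow(2).mean()
```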
We're hard at work understanding and thinking about the design of latent spaces to improve diffusability here at Open World Labs, and we're looking for contributors! If you're interested, feel free to join our Discord to get involved!
Post-Mortem: Proxy Models
Another promising approach was Proxy Models, introduced by the TiToK VQ-VAE image tokenizer paper. The idea is to first train a VQ-VAE to compress a 512x512 image down to 16x16 tokens. Then, train a separate transformer-based VAE that further compresses those 16x16 representations down to just 16 flat tokens (and decoding back to the 16x16 input representation), for a staggering 128x compression!
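A rough sketch of the 1D-compression stage (heavily simplified relative to TiToK; the names and sizes below are ours, not theirs): flatten the 16x16 grid of proxy tokens, append a handful of learnable latent queries, and keep only the transformer outputs at those query positions as the compressed representation.

```python
import torch
import torch.nn as nn

class OneDTokenizer(nn.Module):
    """Simplified sketch of a TiToK-style 1D encoder: compress a 16x16 grid of
    proxy-VAE tokens down to a handful of flat latent vectors."""
    def __init__(self, dim=512, grid_tokens=256, num_latents=16, depth=8, heads=8):
        super().__init__()
        self.latent_queries = nn.Parameter(torch.randn(num_latents, dim) * 0.02)
        self.pos_emb = nn.Parameter(torch.randn(grid_tokens + num_latents, dim) * 0.02)
        layer = nn.TransformerEncoderLayer(
            d_model=dim, nhead=heads, dim_feedforward=4 * dim,
            batch_first=True, norm_first=True,
        )
        self.transformer = nn.TransformerEncoder(layer, num_layers=depth)
        self.num_latents = num_latents

    def forward(self, grid_tokens):                 # (B, 256, dim) proxy-VAE tokens
        B = grid_tokens.shape[0]
        queries = self.latent_queries.unsqueeze(0).expand(B, -1, -1)
        x = torch.cat([grid_tokens, queries], dim=1) + self.pos_emb
        x = self.transformer(x)
        return x[:, -self.num_latents:]             # (B, 16, dim) compressed 1D latent
```

The decoder side runs in reverse: the latent vectors plus learnable mask tokens go through another transformer to reproduce the 16x16 grid, which the proxy model's decoder then turns back into pixels.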
Using Proxy Models, we achieved better image reconstruction than anything we could produce by using 4x4 latents directly. See below!

TiToK Reconstructions with the un-compressed VAE

TiToK reconstructions with the proxy model (128x compression, 16 total patches) - very high quality reconstructions!

TiToK generation with a rectified flow pixel reconstructor
Unfortunately, this approach didn't work well with diffusion models, as shown in these video generation samples below!
We also experimented with adding a sparsity loss term to the TiToK proxy models. Interestingly, the model learned to effectively ignore about 50% of the 128 latent channels - these values were essentially zeroed out due to the sparsity penalty. We could then manually zero out these unused channels during decoding with no loss in reconstruction quality. However, when we tried training a model with only 64 channels from the start (rather than learning to ignore half of 128 channels), it failed completely. This suggests that having the extra capacity during training is crucial, even if the final learned representation is sparser.

TiToK reconstruction with a sparsity term
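The sparsity term itself was nothing exotic; think of something along the lines of an L1 penalty on the latents (shown here as an illustration, with a placeholder coefficient):

```python
import torch

def sparsity_penalty(z, lambda_sparse=1e-3):
    """L1 penalty on latent activations, encouraging the model to drive
    unneeded channels toward zero. lambda_sparse is a placeholder value."""
    return lambda_sparse * z.abs().mean()

# total_loss = recon_loss + sparsity_penalty(latents)
```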
Audio Autoencoders
We also experimented with audio autoencoders, although we ran into challenges with the temporal compression ratios. Most existing audio autoencoders (like StableAudio's VAE) reconstruct audio very well, but they compress so aggressively along the time axis that they're unsuitable for our world-modeling use case. Listen to the reconstruction of the original audio below.
For our World Model, we need each video frame to have an associated audio sample, but these autoencoders would result in audio tokens only every third frame or so. It seems the temporal compression for audio is simply too aggressive for our frame-by-frame approach! (for now)
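To put rough numbers on the mismatch (the 44.1 kHz sample rate and ~2048x temporal downsampling are ballpark figures for a StableAudio-style VAE, not exact specs):

```python
# Rough arithmetic for the audio/video token-rate mismatch.
sample_rate = 44_100          # Hz, typical audio sample rate
downsample = 2048             # assumed temporal compression of a StableAudio-style VAE
video_fps = 60

audio_latents_per_second = sample_rate / downsample                     # ~21.5
video_frames_per_audio_latent = video_fps / audio_latents_per_second    # ~2.8

print(audio_latents_per_second, video_frames_per_audio_latent)
# -> one audio token roughly every third video frame
```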
A Look Ahead + Acknowledgements
If everything goes as planned, we should have some audio-video models to share with you next week, so stay tuned! Additionally we want to acknowledge @SwayStar123 for an efficient implementation of the R3GAN penalties, @neel04 for adding audio to Owl VAEs, and @autoregression for contributing to fruitful discussions on the server that substantially improved the image autoencoder.