Fast Audio Video World Models: Attempt 1

Intro
This week at OWL, things got kind of hectic. Last week we showed a diffusion world model trained on latents from an existing autoencoder. This week we took aim at making our own autoencoder for images AND audio, then training a few-step causal diffusion world model that could run at 60fps on a laptop. We were going to take this demo to CVPR and show it to people at a lunch event that we hosted. As of Sunday, we were starting from scratch. We had no prior experience with audio generation, did not have any trained VAEs that fit our specific requirements, and had never done any causal or distilled diffusion. We had three days. Did we make our demo under the time pressure? Read along to find out!
TL;DR
On Monday we started training a new landscape VAE and a new audio VAE architecture, and implemented diffusion distillation. On Tuesday we switched gears and implemented a different kind of distillation. Later on Tuesday we came up with a new architecture to fight the error accumulation you get with KV caching and hit 100+ fps. The last two things failed, but we still trained an audio-video (AV) world model that ran at 10fps on a laptop.
Going Into The Week: A Prelude
We shared some work we did on autoencoders last week. In the end, the specific models we trained were not ideal for a diffusion model.
Firstly, for the image autoencoder: the pixel shuffle and unshuffle used in DCAE only work if both side lengths are even. When compressing a 16:9 (9:16 if you read height first) image, you eventually hit a size of 45x80. So what do you do from here? Well, our first idea was to just resize to something that could keep dividing, leaving us with 10x16 latents, which felt close to 16:9 while still allowing us to do spatial equivariance with 5x8 latents, which we discussed in last week's post. A few problems... as we said before, more diffusable latents tend to be more "recognizable" when you do a dim reduce of the latent image. Check out this visualization of the feature maps late in the encoder from our model:

Along with the latents looking wrong, we were also sacrificing reconstruction quality to target 64 channels instead of 128. We imagined the token count increasing to 40 (5x8) from 16 (4x4) would alleviate this, but it seems like this model was even less diffusable than DCAE. The diffusion generated samples looked awful and had very clear "splotching" artifacts.

We also couldn't directly use the previous audio autoencoder. The compression was actually TOO strong. Stable Audio does 2048x. That means for 2s at 44.1kHz, each frame gets ~0.3 tokens, so... no idea how we'd train that properly. In an ideal case we'd want 1 audio token per frame. It would also need to be set up in such a way that new latent tokens are predictable given previous ones, so that the decoded generation sounds fine. 88.2k samples -> 120 latent steps works fine with a 735x compression ratio (strides 3->5->7->7). But we'd need it to be "causal-friendly" and equivariant for sliding windows along an hour of video.
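For concreteness, here's the back-of-envelope arithmetic behind those numbers, assuming the 60fps / 120-frame windows we use elsewhere in this post. This is just a sanity check, not training code:

```python
import math

sample_rate = 44_100                 # Hz
seconds = 2
fps = 60                             # 120 video frames per 2 s window
samples = sample_rate * seconds      # 88_200 audio samples

# Stable Audio style ~2048x temporal compression:
latent_steps = samples / 2048        # ~43 latent steps
print(latent_steps / (fps * seconds))  # ~0.36 audio tokens per video frame

# Our target: exactly one audio latent per video frame
strides = (3, 5, 7, 7)
ratio = math.prod(strides)           # 735x compression
print(samples / ratio)               # 120.0 latent steps -> 1 per frame
```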
Finally, the previous WM was slow. It was bidirectional, was not distilled, and only generated 256x256 1:1 images. We needed to fix this.
Monday: Let's Fix The Autoencoder
We figured that blindly resizing the latent super late into the model was a terrible idea and was naturally going to make the feature maps look less like the input image. Funnily enough, it seems this "looks like the image" quality matters a lot more for diffusability than anything else, and DCAE residuals seem to preserve it best for small latents. We were able to drop spatial equivariance and randomness in the latent entirely while still getting good latents and reconstructions. We found that a "fake" KL penalty, whereby we take the latent as the mean and pretend the variance is 0.1, works fine for keeping the magnitude of the latent constrained. For the aspect ratio issue, we used a learned resize right after the input convolution and right before the output convolution. To preserve generality, we simply find the nearest square resolution (in the encoder) or the nearest landscape resolution (in the decoder) and project to it with a learned resize + convolution. This kept the features "image-like" and allowed the DCAE residual stream to pass through untouched. While we had to cut training off early at 30k steps, the samples did not look bad!

Note that samples will appear blurred because we only used reconstructive losses and did no adversarial training. We're leaving this to decoder post-training, which we will do next week. It seems that for small latents (4x4, 128 channels), the generative task is even more important. We also found pretty clearly that ConvNeXt is WAY better than VGG for LPIPS, and that you should keep an eye on your LPIPS module to make sure the magnitude of the loss is 1:1 with the magnitude of MSE. For the VGG LPIPS we used, this meant scaling it by 0.5; for the ConvNeXt one, it meant scaling it by 12.0.
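To make the loss setup concrete, here is a minimal sketch of the "fake" KL penalty and the MSE + scaled LPIPS balance described above. The 0.5 / 12.0 scales are the ones quoted; the `lpips_module` interface and the KL weight are illustrative stand-ins, not our actual training code:

```python
import torch
import torch.nn.functional as F

LPIPS_SCALE = {"vgg": 0.5, "convnext": 12.0}  # keeps LPIPS roughly 1:1 with MSE

def fake_kl_penalty(z: torch.Tensor, fixed_var: float = 0.1) -> torch.Tensor:
    """Treat the deterministic latent z as the mean of a Gaussian with a fixed
    variance and compute KL(N(z, fixed_var) || N(0, 1)). With the variance
    frozen, only the z**2 term carries gradient, so this just keeps the
    latent's magnitude constrained."""
    var = torch.full_like(z, fixed_var)
    kl = 0.5 * (var + z.pow(2) - 1.0 - var.log())
    return kl.mean()

def ae_loss(pred, target, z, lpips_module, backbone="convnext", kl_weight=1e-4):
    mse = F.mse_loss(pred, target)
    lpips = lpips_module(pred, target).mean() * LPIPS_SCALE[backbone]
    kl = fake_kl_penalty(z) * kl_weight   # kl_weight is an illustrative value
    # Log all three terms separately to confirm MSE and LPIPS stay at the
    # same order of magnitude throughout training.
    return mse + lpips + kl
```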


Also Monday: The Other Autoencoder
We started from Stable Audio's Oobleck encoder, but we had to make it more stable (no pun intended), as training with our auxiliary losses always went haywire. Eventually we found that this was because of Muon:

Muon works great for diffusion models, but it seems we will no longer be using it for autoencoders! Let's talk more about the causal and equivariance properties we wanted. Firstly, since we are encoding hours of footage, we want to be able to sample 120-frame windows from anywhere in that hour of footage. With image latents we go from Nx3x360x640 -> Nx128x4x4, so the temporal dim is unaffected. For audio, this is not the case: the main thing being compressed IS the temporal dim. Stable Audio would do 88200x2 -> 40x64 (we do 88200x2 -> 120x64 for parity with frames). The property we want from the latents is summarized in the following diagram:

Slicing the latents should be equivalent to slicing the samples. This is a form of translation equivariance. To get this property, we did something similar to what the spatial equivariance paper did for images, just using cutting instead of resampling:

As a secondary objective, we ensure that slices of the latent decode into equivalent slices of the waveform. For the "causal-friendly" property, we took the model-based regularization approach used in the CRT paper. The CRT term did not seem to hurt the reconstructions, but it made every latent token predictable given the previous ones. We can't really say more on this until we actually run some ablations and sweeps, but we speculate that both of these design choices in conjunction should improve diffusability by making the latent more "audio-like", in the same way we want image latents to be more "image-like".
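Roughly, the two auxiliary objectives look like the sketch below. This is our paraphrase of the idea, with hypothetical `encoder` and `causal_predictor` modules and the 735x stride ratio from above; the real losses and weightings may differ:

```python
import torch
import torch.nn.functional as F

def slice_equivariance_loss(encoder, wave, ratio=735, latent_len=120):
    """Encoding a cut of the waveform should match the corresponding cut of the
    full window's latent (translation equivariance along time)."""
    z_full = encoder(wave)                                # (B, C, latent_len)
    t0 = int(torch.randint(0, latent_len // 2, (1,)))
    t1 = t0 + latent_len // 2
    z_cut = encoder(wave[..., t0 * ratio : t1 * ratio])   # same cut, sample space
    return F.mse_loss(z_cut, z_full[..., t0:t1])

def causal_predictability_loss(causal_predictor, z):
    """Our reading of the CRT-style regularizer: a small causal model predicts
    latent step t from steps < t, and the gradient flows into the encoder so
    the latent stays predictable from its own past."""
    pred = causal_predictor(z[..., :-1])                  # predicts steps 1..T-1
    return F.mse_loss(pred, z[..., 1:])
```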
Also Also Monday: Diffusion Distillation
The first thing we looked into was CausVid. DMD is actually quite elegant and not super difficult to implement. If it worked, it seemed to suggest an alternative to diffusion forcing that would work with KV caching, while at the same time giving you few-step generation. Initial tests with KV caching and 1-step diffusion on an untrained model showed we could hit 100+ fps on a laptop. This gave us hope.
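For reference, the rollout pattern we were timing looks roughly like this: each frame is denoised in a single step, and only the new frame's keys/values get appended to the cache. This is a toy sketch with random tensors and attention only, not our actual model or a faithful benchmark:

```python
import time
import torch
import torch.nn.functional as F

B, H, D, TOK = 1, 8, 64, 16                      # batch, heads, head dim, tokens/frame
device = "cuda" if torch.cuda.is_available() else "cpu"

k_cache = torch.empty(B, H, 0, D, device=device)
v_cache = torch.empty(B, H, 0, D, device=device)

n_frames = 120
start = time.perf_counter()
for _ in range(n_frames):
    # Stand-ins for the new frame's projections from a single denoising pass.
    q = torch.randn(B, H, TOK, D, device=device)
    k = torch.randn(B, H, TOK, D, device=device)
    v = torch.randn(B, H, TOK, D, device=device)
    k_cache = torch.cat([k_cache, k], dim=2)     # grow the cache by one frame
    v_cache = torch.cat([v_cache, v], dim=2)
    _ = F.scaled_dot_product_attention(q, k_cache, v_cache)
if device == "cuda":
    torch.cuda.synchronize()
print(f"{n_frames / (time.perf_counter() - start):.1f} frames/s (attention only)")
```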
Tuesday: The Trouble With Distillation
The autoencoders were training very slowly. We were only able to give them 24 hours before moving on to training the world model, which got the image model to 30k steps and the audio model to 20k steps. We then cut training short in order to encode our data for the WM. Control tensors took 10 min to generate. Audio extraction + embedding took 15 min. Video embedding was going to take... 8 hours?!?!
This put a bit of a dent in our plans. While waiting for the images to very, very slowly get encoded, we started testing on our old data (from the splotchy, bad 5x8 VAE without audio). We realized we did not have time to do two-step distillation as we were quickly approaching our deadline. So we took a big risk and implemented shortcut diffusion world models. This required some weird design changes, such as adding per-frame step embeddings alongside the timestep and control embeddings. We did this so that we could do the diffusion-forcing noising of history frames while generating the last frame in one step. This... did not work at all.
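The conditioning wiring we mean looked roughly like this. It is a hypothetical sketch, not our exact module: each frame gets its own noise level (for diffusion forcing), its own shortcut step size, and its own control embedding:

```python
import torch
import torch.nn as nn

class PerFrameConditioning(nn.Module):
    """Per-frame conditioning for a shortcut-style world model: history frames
    can carry their own diffusion-forcing noise levels while the newest frame
    is asked to jump to clean in one large step."""
    def __init__(self, dim: int, num_step_sizes: int = 8, ctrl_dim: int = 16):
        super().__init__()
        self.t_emb = nn.Sequential(nn.Linear(1, dim), nn.SiLU(), nn.Linear(dim, dim))
        self.d_emb = nn.Embedding(num_step_sizes, dim)   # shortcut step-size bucket
        self.c_emb = nn.Linear(ctrl_dim, dim)            # per-frame controls

    def forward(self, t, d, ctrl):
        # t: (B, F) noise levels in [0, 1], d: (B, F) step-size ids,
        # ctrl: (B, F, ctrl_dim) -> returns (B, F, dim) conditioning per frame
        return self.t_emb(t.unsqueeze(-1)) + self.d_emb(d) + self.c_emb(ctrl)
```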

It seemed like it was doing something near the start and then just getting worse, so in a YOLO move we designed a new architecture: a shortcut MMDiT world model that could take clean context keyframes (randomly sampled from elsewhere in the video). The thinking was that seeing clean frames would keep it from fixating on its own mistakes and snowballing errors. This ended up being even worse.

So, running short on time and with the dataset from our new autoencoders finally ready, we decided to take a step back and just focus on getting an audio-video model trained, even if it was not distilled. In hindsight, we made several key mistakes with the shortcut model. The samples look similar to what you'd get if you forgot to do positional encoding, so it's possible we got the positional encoding wrong; we later discovered a bug in our RoPE implementation that we have since corrected.
Wednesday Morning: Alright, Let's Just Get Audio In There
With distillation failing, we went back to basics. We stapled the audio tokens for each frame onto their image tokens, giving us 17 tokens per frame. We applied a 2D RoPE that treated every frame as if it were a row in a very tall image. Since we were panicking a little bit, we also added a learned positional encoding for good measure, though it was probably not needed.
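Concretely, the packing looks something like this, assuming 16 image tokens and 1 audio token per frame (which is how we get to 17). The RoPE itself is omitted; only the positions it would consume are built here:

```python
import torch

def pack_av_tokens(img_tokens: torch.Tensor, audio_tokens: torch.Tensor):
    """img_tokens: (B, frames, 16, D), audio_tokens: (B, frames, 1, D).
    Returns (B, frames*17, D) packed tokens plus (frames*17, 2) positions for a
    2D RoPE that treats each frame as one row of a very tall image."""
    B, n_frames, _, D = img_tokens.shape
    tokens = torch.cat([img_tokens, audio_tokens], dim=2)   # (B, frames, 17, D)
    tokens = tokens.reshape(B, n_frames * 17, D)

    rows = torch.arange(n_frames).repeat_interleave(17)     # frame index (row)
    cols = torch.arange(17).repeat(n_frames)                # slot within frame (col)
    positions = torch.stack([rows, cols], dim=-1)           # (frames*17, 2)
    return tokens, positions
```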
Finally, we got samples! They weren't amazing, but it at least looked like things were moving in the right direction.

While we did not know how to log it properly initially, we promise these are all full 360p with audio! We have since started training a much much bigger (200M -> 1B) audio visual world model (AVWM) and will share samples from that later in the post, but these samples are essentially the main thing we were showing people at CVPR.
Wednesday Evening and Thursday Morning: Let's Make It Playable!
To let people try out our diffusion world model, we built a simple webapp on a FastAPI webserver. It gets actions from the player and streams them directly to a locally hosted instance of the model, which then generates and sends back frames for the webapp to render. It also sends back the inputs for visualization and the audio for listening. To speed it up as much as possible, we removed CFG and reduced sampling to a single step, which still produced at least somewhat legible samples. We also compiled every model involved in the pipeline: the UViT, the image decoder, and the audio decoder. To reduce latency on the audio decoder and leverage our translation equivariance, we only decoded the final second of audio each render, rather than the full two seconds. Funnily enough, with these changes it actually ran faster on the laptop (10fps) than it had been running on our H200 training setup during logging (4fps). As for the audio, the only really discernible sound with this early checkpoint was a distinct screech when the fire button was pressed.
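The serving loop is conceptually just this. It is a minimal sketch: `world_model`, `decode_image`, and `decode_audio_last_second` are hypothetical stand-ins for our compiled modules, and the endpoint and message format are illustrative rather than our actual code:

```python
import base64
from fastapi import FastAPI, WebSocket

app = FastAPI()

@app.websocket("/play")
async def play(ws: WebSocket):
    await ws.accept()
    history = []                                       # latent frame history
    while True:
        action = await ws.receive_json()               # player inputs from the webapp
        latent = world_model.step(history, action)     # one sampling step, no CFG
        history.append(latent)
        frame_png = decode_image(latent)               # compiled image decoder
        audio_wav = decode_audio_last_second(history)  # only the newest second of audio
        await ws.send_json({
            "frame": base64.b64encode(frame_png).decode(),
            "audio": base64.b64encode(audio_wav).decode(),
            "action": action,                          # echoed back for visualization
        })
```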
Aftermath
In the literature we have seen, the most promising way of getting few-step sampling and KV caching working together seems to be self-forcing. Next week's blog post will likely be about getting this implemented and working for an AVWM. Additionally, we are going to do some post-training for the image decoder to fix the blurring you see in the samples we've shown. We also honestly need a better audio autoencoder, as the reconstruction quality is not quite as good as we'd like right now. On Thursday, we presented the laptop demo along with the above samples at our CVPR lunch event. The lunch went great and we wanna say thanks to everyone who came!

After the lunch, we launched a larger 1B model on the same data to see how scaling affects results. TL;DR: it affects them, a lot. We also dropped UViT, as it seems to be good for smaller models but hurts scaling. In hindsight, it should have struck us as suspicious that the UViT paper reported better results with a 13-layer model than with a 17-layer one. We have since fallen back to a simple DiT architecture. Check out the following samples at 50k steps!
Looking Ahead
For the next week we will focus on:
- Training a shortcut diffusion decoder for 720p reconstructions
- Training a better audio autoencoder
- Self-forcing for KV caching and few-step generation to finally hit our 60fps target
- Training on a mix of labelled and unlabelled data to see if one can do that (it would certainly make data collection easier if it did)
If you'd like to help with any of these projects or join the lively discussions on our server, please join here! While you can join as a volunteer whenever you'd like, we are also taking applications for full-time positions here. If you are interested in the future of entertainment experiences and want to join one of the fastest moving teams in the space, please do apply!