Diffusion Forcing: Next-token Prediction Meets Full-Sequence Diffusion

* Work done while a visiting student at MIT.

  • ¹ MIT

TL;DR: Diffusion Forcing combines the strengths of full-sequence diffusion models and next-token models, acting as either one or a mix of both at sampling time for different applications, without retraining.

Abstract

This paper presents Diffusion Forcing, a new training paradigm where a diffusion model is trained to denoise a set of tokens with independent per-token noise levels. We apply Diffusion Forcing to sequence generative modeling by training a causal next-token prediction model to generate one or several future tokens without fully diffusing past ones. Our approach is shown to combine the strengths of next-token prediction models, such as variable-length generation, with the strengths of full-sequence diffusion models, such as the ability to guide sampling to desirable trajectories. Our method offers a range of additional capabilities, such as (1) rolling out sequences of continuous tokens, such as video, with lengths past the training horizon, where baselines diverge, and (2) new sampling and guiding schemes that uniquely profit from Diffusion Forcing's variable-horizon and causal architecture, and which lead to marked performance gains in decision-making and planning tasks. In addition to its empirical success, our method is proven to optimize a variational lower bound on the likelihoods of all subsequences of tokens drawn from the true joint distribution.

Diffusion Forcing

The name "Diffusion Forcing" comes from "teacher forcing" and "diffusion models".

Diffusion Forcing enjoys key strengths of both next-token autoregressive models and full-sequence diffusion models. After training Diffusion Forcing once, one can control its behavior at sampling time to perform flexible and compositional generation like next-token models, and sequence-level guidance like full-sequence diffusion models.


Abilities of teacher forcing, full-sequence diffusion, and Diffusion Forcing.

Diffusion Forcing achieves this by training sequence diffusion while allowing each token to have a different noise level. One can view noise in diffusion as a varying level of masking, which gives a unified view: full-sequence diffusion denoises all frames at once with the same noise level, while next-token prediction denoises one frame at a time with zero noise on its past tokens.
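
The unified view above can be made concrete by writing down which per-token noise levels each training scheme uses. The following is a minimal sketch, assuming a standard DDPM-style forward process with a linear beta schedule and K = 1000 levels (k = 0 is clean, k = K is approximately pure noise); the names noise_levels and add_noise are hypothetical and this is not the authors' code.

```python
import numpy as np

K = 1000                                     # assumed number of noise levels
betas = np.linspace(1e-4, 0.02, K)           # assumed linear DDPM beta schedule
alphas_bar = np.concatenate([[1.0], np.cumprod(1.0 - betas)])  # index k = 0..K

def noise_levels(T, mode, rng):
    """Per-token noise levels k_t for one length-T training sequence."""
    if mode == "full_sequence":              # one shared level for every token
        return np.full(T, rng.integers(1, K + 1))
    if mode == "next_token":
        t = rng.integers(0, T)               # position of the token being predicted
        levels = np.full(T, K)               # later tokens fully masked (pure noise)
        levels[:t] = 0                       # past tokens clean
        levels[t] = rng.integers(1, K + 1)   # next token at a random level
        return levels
    if mode == "diffusion_forcing":          # independent level for every token
        return rng.integers(0, K + 1, size=T)
    raise ValueError(mode)

def add_noise(x, levels, rng):
    """Noise token x[t] to its own level: sqrt(abar_k) * x + sqrt(1 - abar_k) * eps."""
    ab = alphas_bar[levels][:, None]         # (T, 1), broadcast over token features
    eps = rng.standard_normal(x.shape)
    return np.sqrt(ab) * x + np.sqrt(1.0 - ab) * eps, eps

rng = np.random.default_rng(0)
x = rng.standard_normal((8, 4))              # 8 tokens with 4 features each
for mode in ["full_sequence", "next_token", "diffusion_forcing"]:
    print(mode, noise_levels(8, mode, rng))
x_noisy, eps = add_noise(x, noise_levels(8, "diffusion_forcing", rng), rng)
```

Level 0 leaves a token untouched and level K nearly erases it, so intermediate levels act as fractional masks; the first two modes are just special cases of the independent sampling in the third.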


Diffusion Forcing method.

As a result, one can use different noise levels across a sequence at sampling time to achieve flexible behaviors such as stabilizing autoregressive rollouts, guidance over long horizons, or planning with causal uncertainty.
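
One way to picture "different noise levels at sampling time" is a schedule matrix: each row lists the noise level of every token at one denoising step, and moving from one row to the next is one call to the denoiser. The sketch below assumes this representation and shows the two extreme schedules; the function names are hypothetical.

```python
import numpy as np

def full_sequence_schedule(T, K):
    """Full-sequence diffusion: all tokens share one level, stepping K, K-1, ..., 0."""
    return np.repeat(np.arange(K, -1, -1)[:, None], T, axis=1)

def autoregressive_schedule(T, K):
    """Next-token style: fully denoise token t before touching token t+1."""
    levels = np.full(T, K)
    rows = [levels.copy()]              # start: every token is pure noise
    for t in range(T):
        for k in range(K - 1, -1, -1):  # walk token t from K-1 down to clean
            levels[t] = k
            rows.append(levels.copy())
    return np.stack(rows)

# Rows are denoising steps, columns are tokens.
print(full_sequence_schedule(3, 2))
print(autoregressive_schedule(3, 2))
```

For T = 3 and K = 2, the autoregressive schedule is [2 2 2], [1 2 2], [0 2 2], [0 1 2], [0 0 2], [0 0 1], [0 0 0]. Because the model was trained on arbitrary per-token levels, any matrix between these two extremes is also a valid input.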


Diffusion Forcing usage.

Video Prediction

We provide a list of synthesized videos generated directly by the models (without VAE or super-resolution). The results below are sampled without cherry-picking.

Video Prediction by Diffusion Forcing (ours) and baselines on the DMLab dataset (0.25x speed). Teacher forcing easily blows up, while causal full-sequence diffusion models suffer from serious consistency issues. Diffusion Forcing achieves stable and consistent video prediction. PNG visualizations are provided below to reflect the original quality of generated samples.


Video Prediction by Diffusion Forcing (ours) and baselines on the Minecraft dataset (0.5x speed). Teacher forcing easily blows up, while causal full-sequence diffusion models suffer from serious consistency issues. Diffusion Forcing achieves stable and consistent video prediction. PNG visualizations are provided below to reflect the original quality of generated samples.


Diffusion Planning

Similar to prior work such as Diffuser, we can use test-time guidance to turn our sequence diffusion model into a planner. However, we explicitly model the causal relationship by defining each token as [a_t, o_{t+1}]. By doing so, we maintain a belief over the action to take and the observation it leads to, and we can also update this belief to a posterior estimate when a new observation arrives after the action is taken.
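
The token layout and belief update described above can be sketched as follows. This assumes actions and observations are flat vectors and that "updating the belief" means marking the executed step as clean (level 0) and re-denoising the remaining tokens; make_tokens and observe are hypothetical helpers, not the authors' API.

```python
import numpy as np

def make_tokens(actions, observations):
    """Build token t = [a_t, o_{t+1}].

    actions:      (T, da) array holding a_0 .. a_{T-1}
    observations: (T+1, do) array holding o_0 .. o_T
    """
    if len(observations) != len(actions) + 1:
        raise ValueError("need exactly one more observation than actions")
    return np.concatenate([actions, observations[1:]], axis=1)  # (T, da + do)

def observe(tokens, levels, t, action, obs):
    """After executing a_t and observing o_{t+1}, mark token t as known (level 0).

    Later tokens keep their noise levels; denoising them again while conditioning
    on the clean token t turns the prior plan into a posterior estimate.
    """
    tokens, levels = tokens.copy(), levels.copy()
    da = action.shape[0]
    tokens[t, :da] = action   # the action that was actually executed
    tokens[t, da:] = obs      # the observation that actually followed
    levels[t] = 0
    return tokens, levels

acts = np.ones((3, 2))
obs = np.arange(8.0).reshape(4, 2)
plan = make_tokens(acts, obs)                 # shape (3, 4)
plan, lv = observe(plan, np.full(3, 10), 0, np.zeros(2), np.full(2, 5.0))
```

Pairing a_t with o_{t+1} (rather than o_t) keeps each token causal: an action and the observation it produces are denoised together, and nothing in a token depends on the future.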

Visualization of the diffusion planning process of Diffusion Forcing as a decision-making framework. To model causal uncertainty about the future, Diffusion Forcing's plan can keep the near future at a lower noise level and the far future at a higher noise level.
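
A schedule with this "near future less noisy, far future more noisy" shape can be written as a pyramid, where token t lags t denoising steps behind token 0. This is a minimal sketch under that assumption; the exact schedule used in the paper's experiments may differ, and pyramid_schedule is a hypothetical name.

```python
import numpy as np

def pyramid_schedule(T, K):
    """Row m gives the noise level of each token at denoising step m.

    Token t lags t steps behind token 0, so at every step the near future
    is at most as noisy as the far future.
    """
    m = np.arange(K + T)[:, None]      # denoising steps 0 .. K+T-1
    t = np.arange(T)[None, :]          # token positions 0 .. T-1
    return np.clip(K + t - m, 0, K)

print(pyramid_schedule(3, 2))
```

For T = 3 and K = 2 the rows are [2 2 2], [1 2 2], [0 1 2], [0 0 1], [0 0 0]. The first action becomes executable early while later tokens still carry high uncertainty, and the whole plan finishes in K + T steps rather than the 1 + T·K steps of a strictly one-token-at-a-time schedule.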

Long Horizon Imitation Learning

Many real-world tasks are not Markovian and require long-horizon memory to accomplish. In our real-robot task, a robot arm is asked to swap the slots of two fruits using a third slot. Since the fruits are placed in random slots at the beginning, one cannot determine the next steps from a single observation without knowing the initial placement of the fruits.

We simply remove guidance from the planning setup and jointly diffuse action-observation sequences to perform feedback control.

The video above shows multiple consecutive successes before a failure occurs. The robot accomplishes the task even when the fruit locations have been randomized by the previous run. In contrast, we tried the state-of-the-art imitation learning technique Diffusion Policy, but it cannot perform the task because of the task's non-Markovian nature.

In addition, Diffusion Forcing can be prompted to treat incoming observations as noisy in order to be robust to unseen distractions at test time. In the video above, we illustrate our distraction method: randomly throwing a shopping bag into the field of view.
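
Treating an incoming observation as noisy could be implemented by corrupting it to some level k_obs > 0 and passing that level to the model, so the denoiser is free to discount details it has never seen. The sketch assumes a DDPM-style alphas_bar table and a hand-picked k_obs; as_noisy_context is a hypothetical helper, not the authors' implementation.

```python
import numpy as np

def as_noisy_context(obs, k_obs, alphas_bar, rng):
    """Return a copy of obs noised to level k_obs, plus the level to report to the model.

    k_obs = 0 conditions on the raw observation; larger k_obs tells the model to
    trust the observation less (e.g. to ignore an unseen distractor).
    """
    ab = alphas_bar[k_obs]
    noisy = np.sqrt(ab) * obs + np.sqrt(1.0 - ab) * rng.standard_normal(obs.shape)
    return noisy, k_obs

rng = np.random.default_rng(0)
alphas_bar = np.linspace(1.0, 0.0, 11)    # toy table for levels 0..10 (assumption)
ctx, k = as_noisy_context(np.arange(4.0), 3, alphas_bar, rng)
```

The choice of k_obs trades robustness against fidelity: too small and the distractor passes through, too large and the model ignores useful information in the observation.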

Stabilizing Infinite Rollout without Sliding Window

In addition, our method can roll out videos much longer than the maximum sequence length it was trained on. Remarkably, we can do this without a sliding window: we roll out the RNN without ever resetting the latent z to the initial latent z0, which demonstrates the stabilization effect of Diffusion Forcing. Videos are compressed for loading speed. The results are sampled without cherry-picking.
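
The structure of such a rollout can be sketched as a loop that never resets the latent and feeds each generated frame back labelled with a small noise level k_ctx instead of 0, so the model does not over-trust its own imperfect outputs. This is a structural sketch only: ToyRNNDenoiser, update, denoise, and k_ctx are hypothetical, the toy uses random weights rather than a trained network, and we assume the stabilization works through this small context noise level.

```python
import numpy as np

class ToyRNNDenoiser:
    """Hypothetical stand-in for a trained causal RNN denoiser (random weights)."""

    def __init__(self, dim, K, rng):
        self.K = K
        self.W = rng.standard_normal((dim, dim)) / np.sqrt(dim)
        self.U = rng.standard_normal((dim, dim)) / np.sqrt(dim)

    def update(self, z, x, k):
        # Fold token x, labelled with noise level k, into the latent z.
        # Toy rule: down-weight the token the noisier it is declared to be.
        return np.tanh(self.W @ z + (1.0 - k / self.K) * (self.U @ x))

    def denoise(self, z, x, k):
        # One step from level k to k-1. Toy rule: move a 1/k fraction of the way
        # toward a guess derived from z, so the step at k=1 lands on the guess.
        return x + (np.tanh(z) - x) / k

def rollout(model, z0, n_frames, k_ctx, rng):
    z = z0                                   # latent is never reset to z0
    frames = []
    for _ in range(n_frames):
        x = rng.standard_normal(z.shape)     # new frame starts as pure noise
        for k in range(model.K, 0, -1):      # denoise from level K down to 0
            x = model.denoise(z, x, k)
        frames.append(x)
        z = model.update(z, x, k_ctx)        # feed back as slightly noisy (k_ctx > 0)
    return np.stack(frames)

rng = np.random.default_rng(0)
model = ToyRNNDenoiser(dim=4, K=20, rng=rng)
frames = rollout(model, rng.standard_normal(4), n_frames=2000, k_ctx=2, rng=rng)
print(frames.shape)   # (2000, 4)
```

The only difference from plain autoregressive sampling is the k_ctx argument in the feedback step; setting it to 0 recovers teacher-forcing-style conditioning on generated frames.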

Video quality is reduced by mp4 compression of long videos. We provide PNG visualizations below that reflect the original quality of generated samples longer than the training horizon.

Diffusion Forcing (ours) trained on 36 frames can roll out for 2000 frames or more on the DMLab dataset without a sliding window, thanks to its stabilization effect. Videos are compressed for loading speed. The original dataset resolution is 64x64.


Diffusion Forcing (ours) trained on 72 frames rolls out for 2000 frames or more on the Minecraft dataset without blowing up and without a sliding window. The original dataset resolution is 128x128. In certain scenarios, the agent gets stuck in front of two-block-high dirt or stone blocks until it switches direction, which is an intrinsic issue of the dataset collection.

Suggested Directions for Future Work

Conditioning: When people extend a sequence diffusion model to a longer length than it was trained on, conditioning by replacement is often used. In the paper "Video Diffusion Models", Jonathan Ho discusses why this is wrong. Instead, Diffusion Forcing tells the model to treat context tokens as clean and future tokens as noisy, which is a more natural way to condition, though we haven't explored this in detail.

Noise as masking: Noise as masking achieves fractional masking of tokens instead of discrete binary masking. This is general enough to be used in many self-supervised learning methods such as MAE. Since adding noise has interesting interpretations in the frequency domain, this could be interesting to explore.

Compositionality: In our paper, we show that compositionality can be achieved by controlling the history length. However, with noise as masking, it may be possible for the model itself to figure out when to ignore unnecessary history and condition only on a shorter horizon.

Non-causal version: Diffusion Forcing is causal in this paper because causality is important for decision making. However, the idea of noise as masking is applicable to non-causal models as well. In fact, you can potentially train a non-causal version and make it causal at sampling time! To do so, simply mask the entries that you don't want a prediction to see with pure Gaussian noise.
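
The masking trick above could look like the following sketch, assuming the non-causal model takes a noisy sequence plus a per-token level vector and was trained with independent per-token levels (so a token at level K carries essentially no information). mask_future is a hypothetical helper.

```python
import numpy as np

def mask_future(x, levels, t, K, rng):
    """Hide tokens after position t from a non-causal denoiser.

    Positions > t are replaced by pure Gaussian noise and labelled with the
    maximum level K; positions <= t keep their current values and levels.
    """
    x, levels = x.copy(), levels.copy()
    x[t + 1:] = rng.standard_normal(x[t + 1:].shape)
    levels[t + 1:] = K
    return x, levels

rng = np.random.default_rng(0)
x = np.arange(12.0).reshape(6, 2)          # 6 tokens with 2 features each
levels = np.array([0, 0, 0, 5, 7, 9])
x_masked, levels_masked = mask_future(x, levels, 2, 100, rng)
```

Calling this before every denoising step, with t set to the token currently being generated, makes each prediction depend only on its past, which is what causality requires.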

Alternative Guidance: In our paper's proposed decision-making framework, we applied guidance on observations to keep the setting closer to Diffuser. We also proposed a version that applies guidance on a learned reward, but we haven't explored it in the paper.

Noise scheme: Independent per-token noise levels are designed to be general, but they are not necessarily optimal for every task. For example, they could retain too much redundancy if the data is highly locally correlated along the time axis, which can affect the overall signal-to-noise ratio. Exploring different noise schemes is an interesting direction.

Next-few-token prediction: Next-few-token prediction is used only in our planning experiments; the video experiments still use next-token prediction. It didn't work very well in the RNN version, but we find that it works very well in the transformer version of our code. One observation is that, with a causal model, next-few-token prediction can lead to inconsistency if "few" is very large. This happens less with a non-causal model. Why this is the case raises interesting scientific questions.

Latent & DiT version: We released a 3D U-Net version of Diffusion Forcing after the paper's release. However, Diffusion Forcing should be applicable to DiT as well, causal or non-causal. In addition, the stabilization scheme makes more sense in a VAE latent space, because corruption in pixel space is not necessarily Gaussian, while corruption in a VAE latent space should be closer to Gaussian.

BibTeX


@misc{chen2024diffusionforcingnexttokenprediction,
      title={Diffusion Forcing: Next-token Prediction Meets Full-Sequence Diffusion}, 
      author={Boyuan Chen and Diego Marti Monso and Yilun Du and Max Simchowitz and Russ Tedrake and Vincent Sitzmann},
      year={2024},
      eprint={2407.01392},
      archivePrefix={arXiv},
      primaryClass={cs.LG},
      url={https://arxiv.org/abs/2407.01392}, 
}