Diffusion Forcing: Next-token Prediction Meets Full-Sequence Diffusion


TL;DR: Diffusion Forcing combines the strengths of full-sequence diffusion models and next-token models, acting as either one or a mix of the two at sampling time for different applications, without retraining.

Abstract

This paper presents Diffusion Forcing, a new training paradigm where a diffusion model is trained to denoise a set of tokens with independent per-token noise levels. We apply Diffusion Forcing to sequence generative modeling by training a causal next-token prediction model to generate one or several future tokens without fully diffusing past ones. Our approach is shown to combine the strengths of next-token prediction models, such as variable-length generation, with the strengths of full-sequence diffusion models, such as the ability to guide sampling to desirable trajectories. Our method offers a range of additional capabilities, such as (1) rolling-out sequences of continuous tokens, such as video, with lengths past the training horizon, where baselines diverge and (2) new sampling and guiding schemes that uniquely profit from Diffusion Forcing's variable-horizon and causal architecture, and which lead to marked performance gains in decision-making and planning tasks. In addition to its empirical success, our method is proven to optimize a variational lower bound on the likelihoods of all subsequences of tokens drawn from the true joint distribution.

Diffusion Forcing

The name "Diffusion Forcing" comes from "teacher forcing" and "diffusion models".

Diffusion Forcing enjoys key strengths of both next-token autoregressive models and full-sequence diffusion models. By training Diffusion Forcing once, one can flexibly control its behavior at sampling time to simultaneously perform flexible, compositional generation like next-token models and sequence-level guidance like full-sequence diffusion models.


Abilities of teacher forcing, full-sequence diffusion, and Diffusion Forcing.

Diffusion Forcing achieves this by training sequence diffusion while allowing each token to have a different noise level. One can view noise in diffusion as a varying level of masking and establish a unified view: full-sequence diffusion denoises all frames at once at the same noise level, while next-token prediction denoises one frame at a time with zero noise on its past tokens.
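To make this concrete, here is a minimal sketch of a Diffusion Forcing training step in PyTorch. The names (model, alphas_cumprod) and tensor shapes are illustrative assumptions, not the authors' exact implementation.

import torch
import torch.nn.functional as F

def training_step(model, x, alphas_cumprod):
    """x: clean token sequence of shape (batch, seq_len, dim)."""
    B, T, _ = x.shape
    K = alphas_cumprod.shape[0]
    # Key idea: sample an INDEPENDENT noise level for every token,
    # instead of one shared level for the whole sequence.
    k = torch.randint(0, K, (B, T), device=x.device)
    noise = torch.randn_like(x)
    a = alphas_cumprod[k].unsqueeze(-1)                # (B, T, 1)
    x_noisy = a.sqrt() * x + (1 - a).sqrt() * noise    # per-token forward process
    # A causal model predicts each token's noise given the noisy past.
    pred = model(x_noisy, k)
    return F.mse_loss(pred, noise)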


Diffusion Forcing method.

As a result, one can use different noise levels across a sequence at sampling time to achieve flexible behaviors such as stabilizing autoregressive rollout, guidance over long horizons, or planning under causal uncertainty.
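These sampling-time behaviors correspond to different per-token noise-level schedules. Below is a toy illustration, assuming a discrete noise index k where 0 means fully denoised; the mode names are ours, not the paper's.

import torch

def noise_levels(T, mode, K=1000):
    """Per-token noise indices k_t in [0, K); 0 means fully denoised."""
    if mode == "full_sequence":   # all tokens denoised in lockstep
        return torch.full((T,), K - 1, dtype=torch.long)
    if mode == "next_token":      # past kept clean, next token starts from pure noise
        k = torch.zeros(T, dtype=torch.long)
        k[-1] = K - 1
        return k
    if mode == "pyramid":         # near future kept sharper than the far future
        return torch.linspace(0, K - 1, T).long()
    raise ValueError(mode)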


Diffusion Forcing usage.

Video Prediction

We provide a list of synthesized videos generated directly by the models (without VAE / super-resolution). The results below are sampled without cherry-picking.

Video prediction by Diffusion Forcing (ours) and baselines on the DMLab dataset (0.25x speed). Teacher forcing easily blows up, while causal full-sequence diffusion models suffer from serious consistency issues. Diffusion Forcing achieves stable and consistent video prediction. PNG visualizations are provided below to reflect the original quality of the generated samples.



Video prediction by Diffusion Forcing (ours) and baselines on the Minecraft dataset (0.5x speed). Teacher forcing easily blows up, while causal full-sequence diffusion models suffer from serious consistency issues. Diffusion Forcing achieves stable and consistent video prediction. PNG visualizations are provided below to reflect the original quality of the generated samples.


Stabilizing Infinite Rollout without Sliding Window

In addition, one can roll out much longer videos with our method than the maximum sequence length it is trained on. Remarkably, we can do this without a sliding window: we roll out the RNN without ever resetting the latent z to the initial latent z0, demonstrating the stabilization effect of Diffusion Forcing. Videos are compressed for loading speed. The results are sampled without cherry-picking.
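A hedged sketch of this rollout loop follows, assuming hypothetical helpers model.denoise and model.update (the RNN transition) and attribute model.token_dim; the point is that the latent z is carried forward indefinitely and never reset to z0.

import torch

@torch.no_grad()
def rollout(model, z, num_frames, num_steps=50, K=1000):
    """z: recurrent latent summarizing the entire past; never reset to z0."""
    frames = []
    for _ in range(num_frames):
        x = torch.randn(1, model.token_dim)         # start the next frame from noise
        for k in reversed(range(0, K, K // num_steps)):
            x = model.denoise(x, z, k)              # denoise conditioned on the latent
        z = model.update(z, x)                      # fold the clean frame into memory
        frames.append(x)
    return torch.stack(frames), z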

Video quality is reduced by the mp4 compression of long videos! We provide PNG visualizations below to reflect the original quality of generated samples longer than the training horizon.

Diffusion Forcing (ours) trained on 36 frames can roll out for 2000 frames or more on the DMLab dataset without a sliding window, thanks to its stabilization effect. The original dataset resolution is 64x64.


Diffusion Forcing (ours) trained on 72 frames rolls out for 2000 frames or more on the Minecraft dataset without blowing up and without a sliding window. The original dataset resolution is 128x128. In certain scenarios, the agent gets stuck in front of two-block-high dirt or stone blocks until it switches direction, which is an intrinsic issue of the dataset collection.


Diffusion Planning

Similar to prior works like Diffuser, we can use test-time guidance to turn our sequence diffusion model into a planner. However, we explicitly model the causal relationship by defining each token as [a_t, o_{t+1}]. By doing so, we maintain a belief over the action to take and the observation it leads to, and can update this belief to a posterior estimate when a new observation arrives after the action is taken.
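As a sketch, one reverse-diffusion step with reward guidance on such a token might look as follows; model.predict_clean, model.denoise, reward_fn, and scale are assumptions in the spirit of Diffuser-style classifier guidance, not the paper's exact procedure.

import torch

def guided_step(model, x, z, k, reward_fn, scale=1.0):
    """x: one plan token [a_t, o_{t+1}]; z: causal latent over the executed past."""
    x = x.detach().requires_grad_(True)
    x0_hat = model.predict_clean(x, z, k)     # model's estimate of the clean token
    grad = torch.autograd.grad(reward_fn(x0_hat).sum(), x)[0]
    with torch.no_grad():
        x_next = model.denoise(x, z, k)       # ordinary reverse-diffusion step
        return x_next + scale * grad          # nudge the plan toward higher reward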

Visualization of the diffusion planning process of Diffusion Forcing as a decision-making framework. To model the causal uncertainty of the future, Diffusion Forcing's plan can keep the near future at a lower noise level while keeping the far future at a higher noise level.

Long Horizon Imitation Learning

Many real-world tasks are not Markovian and require long-horizon memory to accomplish. In our real-robot task, a robot arm is asked to swap the slots of two fruits using a third slot. Since the fruits start in random slots, one cannot determine the next steps from a single observation without knowledge of the initial placement of the fruits.

We simply remove guidance from the planning experiments and jointly diffuse action-observation sequences to perform feedback control.

The above video shows multiple consecutive successes before a failure happens. One can observe that the robot is able to accomplish the task even when the fruit locations are randomized by the previous run. On the other hand, we tried Diffusion Policy, a SOTA imitation learning technique, but it cannot perform the task due to non-Markovianness.

In addition, Diffusion Forcing can be prompted to treat incoming observations as noisy, making it robust to unseen distractions at test time. In the video above, we illustrate our distraction method of randomly throwing a shopping bag into the field of view.
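Mechanically, this amounts to conditioning on the real observation at a nonzero noise level instead of noise level 0. A minimal sketch, with k_obs and the helper names as assumptions:

import torch

@torch.no_grad()
def robust_update(model, z, obs, k_obs=100):
    """Treat a real observation as if it sits at noise level k_obs > 0."""
    cleaned = model.denoise(obs, z, k_obs)   # distractors are partially denoised away
    return model.update(z, cleaned)          # fold the cleaned token into memory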

BibTeX


@misc{chen2024diffusionforcingnexttokenprediction,
      title={Diffusion Forcing: Next-token Prediction Meets Full-Sequence Diffusion}, 
      author={Boyuan Chen and Diego Marti Monso and Yilun Du and Max Simchowitz and Russ Tedrake and Vincent Sitzmann},
      year={2024},
      eprint={2407.01392},
      archivePrefix={arXiv},
      primaryClass={cs.LG},
      url={https://arxiv.org/abs/2407.01392}, 
}