From Transformers & Diffusion to TransFusion

In recent years, Transformers and Diffusion Models have each reshaped AI—from language to images. Today, they come together in TransFusion, a model designed to generate long, realistic time-series data. Let’s dive in.

---

1. Transformers: Understanding Long Sequences

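Transformers were introduced in 2017 by Vaswani et al. in the seminal paper “Attention Is All You Need”. They dispense with recurrence and instead use self-attention, which relates every position in a sequence directly to every other position, so distant dependencies are one step apart instead of many.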
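The core operation fits in a few lines of PyTorch. This is a minimal single-head sketch of scaled dot-product attention, softmax(QKᵀ/√d_k)·V. The projection matrices `w_q`, `w_k`, `w_v` are illustrative random tensors, not part of any library API; real layers add multiple heads, an output projection, and masking.

```python
import math
import torch

def self_attention(x, w_q, w_k, w_v):
    """Single-head scaled dot-product self-attention.
    x: [batch, seq_len, d_model]; w_q, w_k, w_v: [d_model, d_k] projections."""
    q, k, v = x @ w_q, x @ w_k, x @ w_v
    scores = q @ k.transpose(-2, -1) / math.sqrt(k.shape[-1])  # [batch, seq_len, seq_len]
    weights = torch.softmax(scores, dim=-1)  # each position attends to all positions; rows sum to 1
    return weights @ v, weights

torch.manual_seed(0)
x = torch.randn(2, 5, 16)                              # 2 sequences, 5 positions, width 16
w_q, w_k, w_v = (torch.randn(16, 8) for _ in range(3))
out, weights = self_attention(x, w_q, w_k, w_v)
print(out.shape, weights.shape)  # torch.Size([2, 5, 8]) torch.Size([2, 5, 5])
```

The `[seq_len, seq_len]` weight matrix is what lets position 1 and position 500 interact in a single layer.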

The Hugging Face Transformers library makes popular pretrained models easy to use:

  • BERT: a bidirectional encoder for language understanding (BERT docs).
  • GPT family: autoregressive decoders for text generation (GPT docs).
  • T5: an encoder-decoder model that casts every task as text-to-text (T5 docs).

Example: creating a Transformer encoder in PyTorch:

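import torch
import torch.nn as nn

# Create a Transformer encoder with 6 layers, 8 heads, model dimension 128.
# batch_first=True makes inputs and outputs [batch, seq_len, d_model].
encoder_layer = nn.TransformerEncoderLayer(d_model=128, nhead=8, batch_first=True)
transformer = nn.TransformerEncoder(encoder_layer, num_layers=6)

out = transformer(torch.randn(4, 50, 128))  # out.shape == torch.Size([4, 50, 128])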
---

2. Diffusion Models: Generative Powerhouses

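Diffusion models, popularized by image generators such as DALL·E 2, Imagen, and Stable Diffusion, gradually add Gaussian noise to data over many steps and train a network to reverse that process one denoising step at a time. Training reduces to a simple regression loss rather than an adversarial game, which makes it more stable than GAN training, and the resulting samples are diverse. The Hugging Face Diffusers library provides ready-made implementations.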
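Both the forward and the reverse process need a noise schedule. A minimal sketch, assuming the linear β schedule from the original DDPM paper (β rising from 1e-4 to 0.02 over T = 1000 steps); TransFusion's exact schedule may differ, and `betas`, `alphas`, `alpha_bars` are illustrative names:

```python
import torch

T = 1000
betas = torch.linspace(1e-4, 0.02, T)      # beta_t: variance of the noise added at step t
alphas = 1.0 - betas                        # alpha_t = 1 - beta_t
alpha_bars = torch.cumprod(alphas, dim=0)   # alpha_bar_t = product of alpha_s for s <= t

# alpha_bar_t decays from ~1 (almost clean data) to ~0 (almost pure noise)
print(f"{alpha_bars[0].item():.4f} -> {alpha_bars[-1].item():.2e}")
```

The cumulative product `alpha_bars` is the quantity that scales the clean signal when noising in one shot; the per-step `alphas` and `betas` are what the reverse step uses.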

Forward (adding noise):

import torch

def forward_diffusion_sample(x0, t, alpha_bar_t):
    """
    x0: clean time-series [batch, seq_len, dim]
    t: timestep indices [batch] (alpha_bar_t already encodes t)
    alpha_bar_t: cumulative noise-schedule product at t, broadcastable
                 to x0 (e.g. shape [batch, 1, 1])
    Returns the noisy x_t and the noise that was added
    """
    noise = torch.randn_like(x0)
    x_t = torch.sqrt(alpha_bar_t) * x0 + torch.sqrt(1 - alpha_bar_t) * noise
    return x_t, noise
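In standard DDPM notation, this function implements the closed-form forward process below, where its noise-scaling argument is the cumulative product ᾱ_t of the per-step factors α_s = 1 − β_s. The closed form is what lets training jump straight to any timestep without simulating the intermediate steps:

$$
q(x_t \mid x_0) = \mathcal{N}\!\left(x_t;\ \sqrt{\bar\alpha_t}\,x_0,\ (1-\bar\alpha_t)\,I\right),
\qquad \bar\alpha_t = \prod_{s=1}^{t} (1-\beta_s)
$$

$$
x_t = \sqrt{\bar\alpha_t}\,x_0 + \sqrt{1-\bar\alpha_t}\,\epsilon,
\qquad \epsilon \sim \mathcal{N}(0, I)
$$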

Sampling (reverse diffusion):

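import torch

@torch.no_grad()
def sample(model, T, shape, betas, alphas, alpha_bars):
    # betas, alphas = 1 - betas, alpha_bars = cumprod(alphas): schedule tensors of length T
    x = torch.randn(shape)  # start from pure Gaussian noise
    for t in reversed(range(T)):
        t_batch = torch.full((shape[0],), t)
        eps_pred = model(x, t_emb(t_batch, shape[-1]))  # t_emb: see Section 4.1
        # Reverse step: remove the scaled predicted noise, then rescale
        x = (x - (1 - alphas[t]) / torch.sqrt(1 - alpha_bars[t]) * eps_pred) / torch.sqrt(alphas[t])
        if t > 0:
            x = x + torch.sqrt(betas[t]) * torch.randn_like(x)  # re-inject noise (sigma_t^2 = beta_t)
    return x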
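The reverse step is the standard DDPM ancestral update (Ho et al., 2020), written here with σ_t² = β_t, one of the two variance choices in that paper. The predicted noise is scaled by (1 − α_t)/√(1 − ᾱ_t), and the noise term z is dropped at the final step:

$$
x_{t-1} = \frac{1}{\sqrt{\alpha_t}}\left(x_t - \frac{1-\alpha_t}{\sqrt{1-\bar\alpha_t}}\,\epsilon_\theta(x_t, t)\right) + \sigma_t z,
\qquad z \sim \mathcal{N}(0, I),\quad \sigma_t^2 = \beta_t
$$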
---

3. Why Combine Them?

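Each shines in its own domain:

  • Transformers: excellent at modeling long-range dependencies, but the architecture alone does not say how to sample new data; generating requires a framework around it, such as autoregressive decoding in GPT.
  • Diffusion models: strong at stable, diverse generation, but how well they capture sequence structure depends on the denoising network inside them.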

TransFusion merges them: it uses a Transformer as the denoising network inside the diffusion process, tailored for generating time series up to 384 time steps long (Sikder et al., 2023).

---

4. Inside TransFusion

TransFusion builds on DDPM (denoising diffusion probabilistic models) and uses a Transformer as the denoiser, conditioned on (a) the noisy input xₜ and (b) the timestep t.
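Written out, a denoiser ε_θ(xₜ, t) trained this way minimizes the simplified DDPM objective below. This sketch assumes TransFusion uses the usual unweighted noise-prediction loss, which is the MSE computed in the Section 4.3 training loop; t is drawn uniformly from {1, …, T} and ᾱ_t is the cumulative product of (1 − β_s):

$$
\mathcal{L}(\theta) = \mathbb{E}_{x_0,\,t,\,\epsilon\sim\mathcal{N}(0,I)}
\Big[\big\|\epsilon - \epsilon_\theta\big(\sqrt{\bar\alpha_t}\,x_0 + \sqrt{1-\bar\alpha_t}\,\epsilon,\ t\big)\big\|^2\Big]
$$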

4.1 Conditioning on Time

Time conditioning tells the model which noise level it is operating at. We embed the integer timestep into a vector and either add it to or concatenate it with the Transformer input:

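import math
import torch

def t_emb(t, dim):
    # Sinusoidal timestep embedding: integer timesteps t [B] -> [B, dim]
    half = dim // 2
    freqs = torch.exp(-math.log(1e4) * torch.arange(half, dtype=torch.float32) / half)
    args = t.float().unsqueeze(1) * freqs.unsqueeze(0)           # [B, half]
    emb = torch.cat([torch.sin(args), torch.cos(args)], dim=-1)  # [B, 2 * half]
    return torch.nn.functional.pad(emb, (0, dim % 2))           # zero-pad if dim is odd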

This embedding is often added token-wise or concatenated as a separate channel to xₜ.
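The two options differ mainly in tensor shapes. A minimal sketch with illustrative sizes (`B`, `L`, `D`, `E` are assumptions, and `t_vec` is a random stand-in for a real timestep embedding):

```python
import torch
import torch.nn as nn

B, L, D, E = 4, 24, 5, 8      # batch, sequence length, features, embedding size (illustrative)
x_t = torch.randn(B, L, D)    # noisy sequences
t_vec = torch.randn(B, E)     # random stand-in for a per-sample timestep embedding

# Option 1: add. The embedding must match the model width, so project x_t to width E first
proj = nn.Linear(D, E)
added = proj(x_t) + t_vec.unsqueeze(1)      # [B, 1, E] broadcasts over L -> [B, L, E]

# Option 2: concatenate the embedding as extra channels on every time step
concat = torch.cat([x_t, t_vec.unsqueeze(1).expand(B, L, E)], dim=-1)   # [B, L, D + E]
print(added.shape, concat.shape)  # torch.Size([4, 24, 8]) torch.Size([4, 24, 13])
```

Addition keeps the model width unchanged but requires the embedding to match it; concatenation keeps the signal and the timestep in separate channels at the cost of a wider input.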

4.2 Transformer-based Denoiser

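import math
import torch
import torch.nn as nn

class PositionalEncoding(nn.Module):
    """Adds fixed sinusoidal position signals to a [B, L, D] input (D must be even)."""
    def __init__(self, dim, max_len):
        super().__init__()
        pos = torch.arange(max_len, dtype=torch.float32).unsqueeze(1)  # [L, 1]
        freqs = torch.exp(-math.log(1e4) * torch.arange(0, dim, 2, dtype=torch.float32) / dim)
        pe = torch.zeros(max_len, dim)
        pe[:, 0::2] = torch.sin(pos * freqs)
        pe[:, 1::2] = torch.cos(pos * freqs)
        self.register_buffer("pe", pe)  # saved with the model, not trained

    def forward(self, x):
        return x + self.pe[: x.shape[1]]

class DiffusionTransformer(nn.Module):
    def __init__(self, dim, seq_len, layers=6, heads=8):
        super().__init__()
        # dim doubles as the model width, so it must be divisible by heads
        self.input_proj = nn.Linear(dim, dim)
        self.pos_enc = PositionalEncoding(dim, seq_len)
        encoder = nn.TransformerEncoderLayer(d_model=dim, nhead=heads, batch_first=True)
        self.transformer = nn.TransformerEncoder(encoder, num_layers=layers)
        self.output_proj = nn.Linear(dim, dim)

    def forward(self, x_t, t_emb):
        """
        x_t: noisy input [B, L, D]
        t_emb: time embedding [B, D]
        Returns the predicted noise [B, L, D]
        """
        x = self.input_proj(x_t)
        x = self.pos_enc(x)
        # Add the same time embedding to every time step (token)
        x = x + t_emb.unsqueeze(1)
        x = self.transformer(x)
        return self.output_proj(x)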

PositionalEncoding adds sinusoidal signals so the model knows the order within the sequence, and adding t_emb integrates timestep context.
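Why the positional encoding is needed can be checked directly: without it, a Transformer encoder is permutation-equivariant, so shuffling the time steps merely shuffles the outputs and the model cannot tell one ordering from another. A small sketch with arbitrary sizes, with dropout disabled so the comparison is deterministic:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
layer = nn.TransformerEncoderLayer(d_model=16, nhead=4, dropout=0.0, batch_first=True)
encoder = nn.TransformerEncoder(layer, num_layers=2).eval()

x = torch.randn(1, 10, 16)   # one sequence, 10 time steps, no positional encoding added
perm = torch.randperm(10)    # shuffle the time steps

with torch.no_grad():
    out = encoder(x)
    out_perm = encoder(x[:, perm])

# Shuffling the inputs only shuffles the outputs: order information is invisible
print(torch.allclose(out[:, perm], out_perm, atol=1e-5))  # True
```

For time series, where order is the whole point, this is why a positional signal must be injected before the encoder.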

4.3 Training Loop

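import torch
import torch.nn.functional as F

# Assumed defined elsewhere: model, optimizer, get_batch(), T, num_steps, and
# alpha_bars (1-D tensor of cumulative noise-schedule products, length T)
for step in range(num_steps):
    x0 = get_batch()                          # real data [B, L, D]
    B, L, D = x0.shape
    t = torch.randint(0, T, (B,))              # one random timestep per sequence
    alpha_bar = alpha_bars[t].view(B, 1, 1)    # broadcast over sequence and feature dims
    x_t, noise = forward_diffusion_sample(x0, t, alpha_bar)

    pred_noise = model(x_t, t_emb(t, D))       # condition on x_t and the timestep

    loss = F.mse_loss(pred_noise, noise)       # noise-prediction objective
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()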

The Transformer learns to predict the noise that was added at a randomly sampled timestep; this noise-prediction MSE is the standard DDPM training objective.
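Predicting the noise is enough because the forward equation can be inverted: given xₜ and the noise, the clean sequence follows as x̂₀ = (xₜ − √(1 − ᾱₜ)·ε)/√ᾱₜ. A quick numerical check, assuming a linear β schedule and a perfect noise prediction (the true noise stands in for the model output):

```python
import torch

torch.manual_seed(0)
betas = torch.linspace(1e-4, 0.02, 1000)            # assumed linear schedule
alpha_bars = torch.cumprod(1.0 - betas, dim=0)

x0 = torch.randn(4, 24, 3)                           # clean sequences [B, L, D]
t = torch.randint(0, 1000, (4,))                     # random timesteps
ab = alpha_bars[t].view(-1, 1, 1)
noise = torch.randn_like(x0)
x_t = torch.sqrt(ab) * x0 + torch.sqrt(1 - ab) * noise   # forward process

# With a perfect noise prediction, inverting the forward equation recovers x0
x0_hat = (x_t - torch.sqrt(1 - ab) * noise) / torch.sqrt(ab)
print(torch.allclose(x0_hat, x0, atol=1e-3))  # True
```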

4.4 Generating New Sequences

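import torch

@torch.no_grad()
def sample(model, T, shape, betas, alphas, alpha_bars):
    x = torch.randn(shape)  # start from full noise
    for t in reversed(range(T)):
        t_embedding = t_emb(torch.full((shape[0],), t), shape[2])
        eps_pred = model(x, t_embedding)
        # DDPM posterior mean: remove the scaled predicted noise, then rescale
        x = (x - (1 - alphas[t]) / torch.sqrt(1 - alpha_bars[t]) * eps_pred) / torch.sqrt(alphas[t])
        if t > 0:
            x = x + torch.sqrt(betas[t]) * torch.randn_like(x)  # add noise except on the last step
    return x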
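A useful sanity check for the reverse update: for a toy dataset in which every sequence equals a constant c, the exact noise predictor is known in closed form, so the sampler should return exactly c. The sketch below is a self-contained DDPM ancestral sampler (σ_t² = β_t) in which the hypothetical `oracle_eps` stands in for the trained Transformer; `ddpm_sample` is an illustrative name, and float64 keeps rounding error negligible.

```python
import torch

def ddpm_sample(eps_model, shape, betas):
    """Minimal DDPM ancestral sampler with sigma_t^2 = beta_t; eps_model(x, t) predicts the noise."""
    alphas = 1.0 - betas
    alpha_bars = torch.cumprod(alphas, dim=0)
    x = torch.randn(shape, dtype=betas.dtype)
    for t in reversed(range(len(betas))):
        eps = eps_model(x, t)
        x = (x - (1 - alphas[t]) / torch.sqrt(1 - alpha_bars[t]) * eps) / torch.sqrt(alphas[t])
        if t > 0:
            x = x + torch.sqrt(betas[t]) * torch.randn_like(x)  # no noise on the final step
    return x

# Toy "dataset": every training sequence equals the constant c (a point mass)
c = 0.7
betas = torch.linspace(1e-4, 0.02, 100, dtype=torch.float64)
alpha_bars = torch.cumprod(1.0 - betas, dim=0)

def oracle_eps(x, t):
    # Hypothetical perfect denoiser: invert x_t = sqrt(abar_t) * c + sqrt(1 - abar_t) * eps
    return (x - torch.sqrt(alpha_bars[t]) * c) / torch.sqrt(1 - alpha_bars[t])

torch.manual_seed(0)
out = ddpm_sample(oracle_eps, (2, 16, 3), betas)
print(torch.allclose(out, torch.full_like(out, c)))  # True
```

With the exact predictor, the final step (t = 0, no added noise) maps any input back to c, so a mismatch here points to a sign or scaling mistake in the update rule.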
---

5. Evaluating TransFusion

The paper introduces two novel scores:

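  • LDS (discriminative score): train a classifier to distinguish real from synthetic sequences; accuracy close to chance (50%) means the synthetic data is realistic.
  • LPS (predictive score): train a model on synthetic data and test it on real data; a lower MAE means the synthetic data preserves the structure needed for prediction.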
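The idea behind the discriminative score can be sketched as follows. This is a simplified stand-in, not the paper's LDS: it trains a linear classifier on flattened, standardized sequences (an assumption; benchmark implementations typically use a recurrent classifier) and reports |accuracy − 0.5|, the convention popularized by TimeGAN, so 0 means indistinguishable and 0.5 means trivially separable. `discriminative_score` is a hypothetical helper.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

def discriminative_score(real, synth, epochs=200, seed=0):
    """Simplified sketch: |test accuracy - 0.5| of a linear real-vs-synthetic classifier."""
    torch.manual_seed(seed)
    x = torch.cat([real, synth]).flatten(1)                  # [N, seq_len * dim]
    y = torch.cat([torch.ones(len(real)), torch.zeros(len(synth))])
    idx = torch.randperm(len(x))
    split = int(0.8 * len(x))
    tr, te = idx[:split], idx[split:]                        # 80/20 train/test split
    mu, sd = x[tr].mean(0), x[tr].std(0) + 1e-8              # standardize with train statistics
    x = (x - mu) / sd

    clf = nn.Linear(x.shape[1], 1)
    opt = torch.optim.Adam(clf.parameters(), lr=1e-2)
    for _ in range(epochs):
        loss = F.binary_cross_entropy_with_logits(clf(x[tr]).squeeze(1), y[tr])
        opt.zero_grad()
        loss.backward()
        opt.step()

    with torch.no_grad():
        pred = (clf(x[te]).squeeze(1) > 0).float()           # logit > 0 means "real"
        acc = (pred == y[te]).float().mean().item()
    return abs(acc - 0.5)

torch.manual_seed(0)
real = torch.randn(200, 24, 2)
same = torch.randn(200, 24, 2)            # same distribution as real
shifted = torch.randn(200, 24, 2) + 3.0   # clearly different distribution
print(discriminative_score(real, same), discriminative_score(real, shifted))
```

The predictive score follows the same train-on-synthetic pattern, but fits a forecaster and measures its MAE on real data.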

They also use PCA/t-SNE plots and diversity metrics (precision, recall, coverage), and report that TransFusion outperforms the GAN- and diffusion-based baselines they compare against.
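A PCA plot of this kind can be sketched as follows, assuming sequences are flattened to vectors before projection (a common choice; some implementations average over the feature dimension instead). `pca_2d` is a hypothetical helper; scatter-plotting its two outputs gives the familiar real-vs-synthetic overlay, where heavy overlap suggests the synthetic data covers the real distribution.

```python
import torch

def pca_2d(real, synth):
    """Project real and synthetic sequences onto the top-2 principal components of the real data."""
    r, s = real.flatten(1), synth.flatten(1)          # [N, seq_len * dim]
    mean = r.mean(dim=0, keepdim=True)
    # Rows of vh are principal directions, ordered by decreasing singular value
    _, _, vh = torch.linalg.svd(r - mean, full_matrices=False)
    components = vh[:2].T                              # [seq_len * dim, 2]
    return (r - mean) @ components, (s - mean) @ components

torch.manual_seed(0)
real2d, synth2d = pca_2d(torch.randn(100, 24, 3), torch.randn(80, 24, 3))
print(real2d.shape, synth2d.shape)  # torch.Size([100, 2]) torch.Size([80, 2])
```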

---

6. Final Thoughts

TransFusion combines:

  1. Transformers to capture sequence structure
  2. Diffusion’s step-wise generation with robust diversity
  3. Time-conditioning for effective denoising

For anyone generating synthetic time-series data in finance, healthcare, or IoT, TransFusion is a significant step forward.

Explore the full paper: TransFusion: Generating Long, High Fidelity Time Series using Diffusion Models with Transformers.
