
fix: Use Karras sigma schedule for Cosmos Predict2.5 SFT inference #33

Open
mvanhorn wants to merge 1 commit into NVlabs:main from mvanhorn:fix/18-cosmos-predict2-karras-sigma-schedule

Conversation

@mvanhorn


Summary

Cosmos-Predict2.5-2B video2world inference under FastGen produces noticeably degraded output compared to the official cosmos-predict2.5 codebase when using the same checkpoint, prompt, image, resolution, and step count. The reporter traced the primary cause to a wrong sigma/timestep schedule: the official Cosmos uses a Karras sigma schedule (sigma_max=200, sigma_min=0.01, rho=7) in its FlowUniPCMultistepScheduler.set_timesteps(use_kerras_sigma=True), while FastGen's CosmosPredict2DiT.sample() relies on the default diffusers flow-shift linear schedule from UniPCMultistepScheduler.

Changes

In fastgen/networks/cosmos_predict2/network.py, modify the sample() method (around lines 1148-1161) so that, for the Cosmos SFT sampling path, after constructing the UniPCMultistepScheduler and calling set_timesteps, it overrides sample_scheduler.sigmas and sample_scheduler.timesteps with a Karras sigma schedule (sigma_max=200, sigma_min=0.01, rho=7) exactly as the official Cosmos does, including the sigmas/(1+sigmas) reparameterization, int64 timestep truncation, and resetting the scheduler's internal solver state (model_outputs, lower_order_nums, last_sample, _step_index, _begin_index). Gate this behind a sensible default or config flag so it only affects the SFT inference path the maintainer called out, not distillation. Update the inline example command(s) in scripts/inference/video_model_inference.py to pass the Cosmos negative-prompt file and otherwise align them with scripts/README.md. Add a focused unit test in tests/test_network.py asserting the produced sigma/timestep schedule matches the reference Karras values.
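The schedule construction described above can be sketched in plain Python. This is a minimal illustration, not the PR's code: the function name `karras_flow_schedule` is hypothetical, the ramp follows the standard Karras formula, and `num_train_timesteps=1000` is taken from the call shown in the review diagram. The real helper works on torch tensors.

```python
def karras_flow_schedule(num_steps, sigma_max=200.0, sigma_min=0.01, rho=7.0,
                         num_train_timesteps=1000):
    """Return (sigmas, timesteps) for a Karras ramp mapped into flow time.

    Hypothetical stand-in for the PR's helper, using lists instead of tensors.
    """
    max_inv = sigma_max ** (1.0 / rho)
    min_inv = sigma_min ** (1.0 / rho)
    # num_steps + 1 ramp points, covering both endpoints 0 and 1.
    ramp = [i / num_steps for i in range(num_steps + 1)]
    # Karras ramp: interpolate in sigma^(1/rho) space, then raise back to rho.
    karras = [(max_inv + r * (min_inv - max_inv)) ** rho for r in ramp]
    # Reparameterize the sigmas into [0, 1) via sigma / (1 + sigma).
    flow = [s / (1.0 + s) for s in karras]
    # int() truncates toward zero, mirroring a cast to int64.
    timesteps = [int(s * num_train_timesteps) for s in flow]
    # Append the terminal sigma of zero.
    sigmas = flow + [0.0]
    return sigmas, timesteps


sigmas, timesteps = karras_flow_schedule(35)
```

Under these assumptions, `num_steps=35` gives 36 timesteps running from 995 down to 9, and 37 sigmas ending in 0.0.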

Fixes #18

@greptile-apps

greptile-apps Bot commented Jun 17, 2026


Greptile Summary

This PR fixes degraded Cosmos Predict2.5 SFT inference quality by overriding the default linear sigma schedule with the Karras schedule (σ_max=200, σ_min=0.01, ρ=7) actually used by the official Cosmos codebase, and aligns the inference script example commands with scripts/README.md.

  • Adds _build_cosmos_predict2_karras_schedule / _apply_cosmos_predict2_karras_schedule helpers that replicate the official Karras ramp and reparameterization, then reset the UniPC solver state; gated by a use_karras_sigma_schedule flag that defaults to True.
  • Updates all inline example commands in video_model_inference.py to include --num_steps, --fps, and --neg_prompt_file; the Cosmos V2W example gains model.input_shape and the Cosmos-specific negative prompt.
  • Adds three unit tests covering sigma value correctness, scheduler state reset, and a snapshot check on the example command.
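The solver-state reset mentioned in the first bullet might look roughly like the sketch below. The attribute names come from the PR description. The reset values (`None`, `0`, and a `model_outputs` list sized by `config.solver_order`) are an assumption about what a fresh `set_timesteps` call leaves behind and should be checked against the installed diffusers version. `reset_multistep_state` and the `fake` object are hypothetical.

```python
from types import SimpleNamespace


def reset_multistep_state(scheduler):
    """Clear history a multistep solver accumulated under the previous schedule."""
    # Assumed reset values; verify against the scheduler's own set_timesteps.
    scheduler.model_outputs = [None] * scheduler.config.solver_order
    scheduler.lower_order_nums = 0
    scheduler.last_sample = None
    scheduler._step_index = None
    scheduler._begin_index = None


# Stand-in object for illustration only; not a real diffusers scheduler.
fake = SimpleNamespace(
    config=SimpleNamespace(solver_order=2),
    model_outputs=["stale", "stale"],
    lower_order_nums=2,
    last_sample="stale",
    _step_index=5,
    _begin_index=0,
)
reset_multistep_state(fake)
```

Without a reset like this, the first step under the new schedule could reuse model outputs computed for the old sigmas.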

Confidence Score: 3/5

The core sigma-schedule logic is correct and the test coverage is solid, but the sigmas tensor is left on CPU after the Karras override, which will cause a device error the first time scheduler.step() is called on a GPU.

The Karras schedule math and state-reset are correctly implemented and well tested. However, _build_cosmos_predict2_karras_schedule never moves sigmas to device: it calls .to(dtype=torch.float32) but omits device=device, while the timesteps tensor is correctly moved. After the Karras override, scheduler.sigmas is a CPU tensor and scheduler.timesteps is a CUDA tensor, undoing the device placement that set_timesteps(device=noise.device) had just established. Every GPU inference run will hit a RuntimeError in scheduler.step(). The tests all use torch.device('cpu') so this is not caught.

fastgen/networks/cosmos_predict2/network.py — specifically _build_cosmos_predict2_karras_schedule (sigmas device placement) and the use_karras_sigma_schedule=True default which may affect distillation configs.

Important Files Changed

| Filename | Overview |
| --- | --- |
| fastgen/networks/cosmos_predict2/network.py | Adds Karras sigma schedule for Cosmos Predict2.5 SFT inference; the sigmas tensor is not moved to the requested device, which will crash GPU inference in scheduler.step(). The default-True flag may also silently change distillation eval behavior. |
| scripts/inference/video_model_inference.py | Updates inline example commands: adds --num_steps, --fps, and --neg_prompt_file to all examples; aligns Cosmos V2W example with scripts/README.md. |
| tests/test_network.py | Adds three focused unit tests for the Karras schedule. Tests run on CPU only, so the device-mismatch bug is not caught. |

Sequence Diagram

%%{init: {'theme': 'neutral'}}%%
sequenceDiagram
    participant C as Caller
    participant S as CosmosPredict2.sample()
    participant K as _apply_cosmos_predict2_karras_schedule()
    participant B as _build_cosmos_predict2_karras_schedule()
    participant U as UniPCMultistepScheduler

    C->>S: "sample(noise, condition, num_steps=35)"
    S->>U: "set_timesteps(35, device=cuda)"
    Note over U: sigmas on cuda, timesteps on cuda
    S->>K: apply schedule(scheduler, 35, cuda)
    K->>B: build schedule(35, 1000, cuda)
    Note over B: sigmas computed on CPU
    Note over B: timesteps moved to cuda
    Note over B: sigmas NOT moved to cuda
    B-->>K: sigmas (CPU), timesteps (cuda)
    K->>U: "scheduler.sigmas = CPU tensor"
    K->>U: "scheduler.timesteps = cuda tensor"
    loop 36 denoising steps
        S->>U: step(velocity_pred, timestep, sample)
        Note over U: self.sigmas[i] is CPU
        Note over U: RuntimeError on GPU
    end
    S-->>C: denoised latents

Reviews (1): Last reviewed commit: "fix: Use Karras sigma schedule for Cosmo..."

sigmas = sigmas / (1 + sigmas)

timesteps = (sigmas * num_train_timesteps).to(device=device, dtype=torch.int64)
sigmas = torch.cat([sigmas, sigmas.new_zeros(1)]).to(dtype=torch.float32)


P1 sigmas never moved to device, breaking GPU inference

ramp is always created on CPU via torch.arange(...), so all derived sigmas computations stay on CPU. timesteps is correctly moved with .to(device=device, ...), but the final sigmas tensor that is returned (and ultimately stored in scheduler.sigmas) stays on CPU. When scheduler.step() is called during GPU inference it will index into the CPU sigmas tensor and use the resulting CPU 0-dim tensors in arithmetic with CUDA latent tensors, raising a RuntimeError: Expected all tensors to be on the same device. The prior set_timesteps(device=noise.device) call correctly places sigmas on device — this Karras override silently undoes that.

Suggested change
- sigmas = torch.cat([sigmas, sigmas.new_zeros(1)]).to(dtype=torch.float32)
+ timesteps = (sigmas * num_train_timesteps).to(device=device, dtype=torch.int64)
+ sigmas = torch.cat([sigmas, sigmas.new_zeros(1)]).to(dtype=torch.float32, device=device)
+ return sigmas, timesteps

use_wan_fp32_strategy: bool = True,
# FPS for temporal position embeddings
fps: float = 24.0,
use_karras_sigma_schedule: bool = True,


P2 Default True may silently change distillation inference

The PR description says to "gate this behind a sensible default or config flag so it only affects the SFT inference path, not distillation." Setting use_karras_sigma_schedule=True as the class-level default means every CosmosPredict2 instance — including distillation models — will use the Karras schedule unless the corresponding experiment config explicitly passes use_karras_sigma_schedule=False. If the distillation configs don't set that flag, evaluation of distilled checkpoints will silently start using a different sigma schedule, making before/after comparisons invalid.
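One way to address this concern is to make the flag opt-in, so that only the SFT inference configuration enables it. The sketch below is illustrative only: `SamplingOptions`, `SFT_INFERENCE`, `DISTILLATION`, and `schedule_name` are hypothetical names, not FastGen APIs.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class SamplingOptions:
    """Hypothetical options object; not part of FastGen."""
    # Opt-in default: configs that never mention the flag keep their old schedule.
    use_karras_sigma_schedule: bool = False


# Only the SFT inference path turns the Karras schedule on explicitly.
SFT_INFERENCE = SamplingOptions(use_karras_sigma_schedule=True)
DISTILLATION = SamplingOptions()


def schedule_name(options):
    """Report which schedule a given configuration would select."""
    return "karras" if options.use_karras_sigma_schedule else "default"
```

With a `False` default, a distillation config that predates the flag evaluates as before, at the cost of requiring the SFT config to set the flag explicitly.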

)
scheduler.sigmas = sigmas
scheduler.timesteps = timesteps
scheduler.num_inference_steps = len(timesteps)


P2 num_inference_steps set to num_steps + 1, denoising loop runs one extra iteration

_build_cosmos_predict2_karras_schedule uses torch.arange(num_steps + 1) which produces num_steps + 1 ramp points (including both endpoints 0 and 1). This yields a timesteps tensor of length num_steps + 1, so scheduler.num_inference_steps is set to num_steps + 1 and the denoising loop in sample() iterates num_steps + 1 times. When a user passes --num_steps 35, they will actually get 36 denoising steps. If this matches the official Cosmos reference implementation this is intentional, but it should be documented explicitly.
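The step-count arithmetic in this comment can be checked with a small sketch. `ramp_inclusive` mirrors the `num_steps + 1` construction described above. `ramp_exact` is a hypothetical alternative shown only for comparison; nothing here claims it matches the official Cosmos reference.

```python
def ramp_inclusive(num_steps):
    """num_steps + 1 points in [0, 1], as described in the comment above."""
    return [i / num_steps for i in range(num_steps + 1)]


def ramp_exact(num_steps):
    """Hypothetical alternative: exactly num_steps points, still spanning [0, 1]."""
    if num_steps < 2:
        raise ValueError("ramp_exact needs at least 2 steps to span both endpoints")
    return [i / (num_steps - 1) for i in range(num_steps)]


inclusive = ramp_inclusive(35)  # one ramp point per timestep
exact = ramp_exact(35)
```

For `--num_steps 35`, the inclusive ramp has 36 points and therefore 36 loop iterations, while the alternative has 35. Either choice changes the spacing between consecutive sigmas, so the one that matches the reference should be stated in the docstring.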


Development

Successfully merging this pull request may close these issues.

[Bug] Cosmos Predict2.5 inference quality mismatch with official codebase — wrong sigma schedule
