fix: Use Karras sigma schedule for Cosmos Predict2.5 SFT inference #33
mvanhorn wants to merge 1 commit into
Greptile Summary
This PR fixes degraded Cosmos Predict2.5 SFT inference quality by overriding the default linear sigma schedule with the Karras schedule (σ_max=200, σ_min=0.01, ρ=7) actually used by the official Cosmos codebase, and aligns the inference script example commands with
Confidence Score: 3/5
The core sigma-schedule logic is correct and the test coverage is solid, but the sigmas tensor is left on CPU after the Karras override, which will cause a device error the first time scheduler.step() is called on a GPU.
The Karras schedule math and state-reset are correctly implemented and well tested. However, _build_cosmos_predict2_karras_schedule never moves sigmas to device: it calls .to(dtype=torch.float32) but omits device=device, while the timesteps tensor is correctly moved. After the Karras override, scheduler.sigmas is a CPU tensor and scheduler.timesteps is a CUDA tensor, undoing the device placement that set_timesteps(device=noise.device) had just established. Every GPU inference run will hit a RuntimeError in scheduler.step(). The tests all use torch.device('cpu'), so this is not caught.
fastgen/networks/cosmos_predict2/network.py: specifically _build_cosmos_predict2_karras_schedule (sigmas device placement) and the use_karras_sigma_schedule=True default, which may affect distillation configs.
Important Files Changed
Sequence Diagram
%%{init: {'theme': 'neutral'}}%%
sequenceDiagram
participant C as Caller
participant S as CosmosPredict2.sample()
participant K as _apply_cosmos_predict2_karras_schedule()
participant B as _build_cosmos_predict2_karras_schedule()
participant U as UniPCMultistepScheduler
C->>S: "sample(noise, condition, num_steps=35)"
S->>U: "set_timesteps(35, device=cuda)"
Note over U: sigmas on cuda, timesteps on cuda
S->>K: apply schedule(scheduler, 35, cuda)
K->>B: build schedule(35, 1000, cuda)
Note over B: sigmas computed on CPU
Note over B: timesteps moved to cuda
Note over B: sigmas NOT moved to cuda
B-->>K: sigmas (CPU), timesteps (cuda)
K->>U: "scheduler.sigmas = CPU tensor"
K->>U: "scheduler.timesteps = cuda tensor"
loop 36 denoising steps
S->>U: step(velocity_pred, timestep, sample)
Note over U: self.sigmas[i] is CPU
Note over U: RuntimeError on GPU
end
S-->>C: denoised latents
Reviews (1): Last reviewed commit: "fix: Use Karras sigma schedule for Cosmo..."
sigmas = sigmas / (1 + sigmas)

timesteps = (sigmas * num_train_timesteps).to(device=device, dtype=torch.int64)
sigmas = torch.cat([sigmas, sigmas.new_zeros(1)]).to(dtype=torch.float32)
sigmas never moved to device, breaking GPU inference
ramp is always created on CPU via torch.arange(...), so all derived sigmas computations stay on CPU. timesteps is correctly moved with .to(device=device, ...), but the final sigmas tensor that is returned (and ultimately stored in scheduler.sigmas) stays on CPU. When scheduler.step() is called during GPU inference it will index into the CPU sigmas tensor and use the resulting CPU 0-dim tensors in arithmetic with CUDA latent tensors, raising a RuntimeError: Expected all tensors to be on the same device. The prior set_timesteps(device=noise.device) call correctly places sigmas on device — this Karras override silently undoes that.
- sigmas = torch.cat([sigmas, sigmas.new_zeros(1)]).to(dtype=torch.float32)
+ timesteps = (sigmas * num_train_timesteps).to(device=device, dtype=torch.int64)
+ sigmas = torch.cat([sigmas, sigmas.new_zeros(1)]).to(dtype=torch.float32, device=device)
+ return sigmas, timesteps
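The suggested change can be sketched in isolation as follows. This is a minimal illustration, not the PR's code: `finalize_schedule` is a hypothetical helper name, and the input sigmas are placeholder values rather than the Karras schedule. It only shows both tensors being moved to the same device before they are returned. Whether a CPU sigmas tensor actually raises inside scheduler.step() depends on how that method consumes sigmas, which is not shown in this review; PyTorch generally accepts a zero-dimensional CPU tensor in arithmetic with a CUDA tensor.

```python
import torch


def finalize_schedule(sigmas_cpu, num_train_timesteps, device):
    """Hypothetical helper: append the terminal zero and place both outputs on `device`."""
    # Integer timesteps are derived from the sigmas before the terminal zero is appended.
    timesteps = (sigmas_cpu * num_train_timesteps).to(device=device, dtype=torch.int64)
    # Passing device= here is the one-argument change proposed in the review.
    sigmas = torch.cat([sigmas_cpu, sigmas_cpu.new_zeros(1)]).to(
        device=device, dtype=torch.float32
    )
    return sigmas, timesteps


# Use a GPU when one is available so the check is meaningful; fall back to CPU otherwise.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
placeholder_sigmas = torch.tensor([0.9, 0.5, 0.1])  # placeholder values, created on CPU
sigmas, timesteps = finalize_schedule(placeholder_sigmas, 1000, device)

# A device-consistency check of this kind only exercises the GPU path on a CUDA machine.
assert sigmas.device == timesteps.device
```

A test parametrized over the available devices, rather than hard-coded to torch.device('cpu'), would exercise the placement that the CPU-only tests skip.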
use_wan_fp32_strategy: bool = True,
# FPS for temporal position embeddings
fps: float = 24.0,
use_karras_sigma_schedule: bool = True,
Default True may silently change distillation inference
The PR description says to "gate this behind a sensible default or config flag so it only affects the SFT inference path, not distillation." Setting use_karras_sigma_schedule=True as the class-level default means every CosmosPredict2 instance — including distillation models — will use the Karras schedule unless the corresponding experiment config explicitly passes use_karras_sigma_schedule=False. If the distillation configs don't set that flag, evaluation of distilled checkpoints will silently start using a different sigma schedule, making before/after comparisons invalid.
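One way to address this concern is to make the flag opt-in, so only the SFT inference config enables it. The sketch below is an assumption-laden illustration: `SamplerConfig` and `pick_schedule` are hypothetical names, not classes or functions from FastGen, and the PR may instead keep the True default and pass use_karras_sigma_schedule=False in the distillation configs.

```python
from dataclasses import dataclass


@dataclass
class SamplerConfig:
    """Hypothetical stand-in for the network's constructor arguments."""

    fps: float = 24.0
    # Opt-in default: existing configs keep their current schedule unless they ask for Karras.
    use_karras_sigma_schedule: bool = False


def pick_schedule(cfg: SamplerConfig) -> str:
    """Return the name of the schedule a sampler built from `cfg` would use."""
    return "karras" if cfg.use_karras_sigma_schedule else "default"


# Only the SFT inference config sets the flag explicitly.
sft_config = SamplerConfig(use_karras_sigma_schedule=True)
# A distillation config that never mentions the flag is unaffected.
distill_config = SamplerConfig()

assert pick_schedule(sft_config) == "karras"
assert pick_schedule(distill_config) == "default"
```

With an opt-in default, a config that omits the flag behaves exactly as it did before the PR, which keeps before/after comparisons of distilled checkpoints valid.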
)
scheduler.sigmas = sigmas
scheduler.timesteps = timesteps
scheduler.num_inference_steps = len(timesteps)
num_inference_steps set to num_steps + 1, denoising loop runs one extra iteration
_build_cosmos_predict2_karras_schedule uses torch.arange(num_steps + 1) which produces num_steps + 1 ramp points (including both endpoints 0 and 1). This yields a timesteps tensor of length num_steps + 1, so scheduler.num_inference_steps is set to num_steps + 1 and the denoising loop in sample() iterates num_steps + 1 times. When a user passes --num_steps 35, they will actually get 36 denoising steps. If this matches the official Cosmos reference implementation this is intentional, but it should be documented explicitly.
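The off-by-one can be seen with plain arithmetic. The sketch below uses ordinary Python lists in place of torch.arange. The second ramp is a hypothetical alternative shown only for comparison, not something the PR or the official Cosmos code is confirmed to use.

```python
num_steps = 35

# Ramp as described in the comment: num_steps + 1 points, including both endpoints 0 and 1.
ramp_inclusive = [i / num_steps for i in range(num_steps + 1)]

# Hypothetical alternative: drop the final point so the loop runs exactly num_steps times.
# The ramp then stops short of 1, so the schedule never reaches sigma_min.
ramp_dropping_last = [i / num_steps for i in range(num_steps)]

print(len(ramp_inclusive))      # 36
print(len(ramp_dropping_last))  # 35
```

One timestep is produced per ramp point, so the inclusive ramp is what makes --num_steps 35 run 36 denoising iterations.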
Summary
Cosmos-Predict2.5-2B video2world inference under FastGen produces noticeably degraded output compared to the official cosmos-predict2.5 codebase when using the same checkpoint, prompt, image, resolution, and step count. The reporter traced the primary cause to a wrong sigma/timestep schedule: the official Cosmos uses a Karras sigma schedule (sigma_max=200, sigma_min=0.01, rho=7) in its FlowUniPCMultistepScheduler.set_timesteps(use_kerras_sigma=True), while FastGen's CosmosPredict2DiT.sample() relies on the default diffusers flow-shift linear schedule from UniPCMultistepScheduler.
Changes
- In fastgen/networks/cosmos_predict2/network.py, modify the sample() method (around lines 1148-1161) so that, for the Cosmos SFT sampling path, after constructing the UniPCMultistepScheduler and calling set_timesteps, it overrides sample_scheduler.sigmas and sample_scheduler.timesteps with a Karras sigma schedule (sigma_max=200, sigma_min=0.01, rho=7) exactly as the official Cosmos does, including the sigmas/(1+sigmas) reparameterization, int64 timestep truncation, and resetting the scheduler's internal solver state (model_outputs, lower_order_nums, last_sample, _step_index, _begin_index).
- Gate this behind a sensible default or config flag so it only affects the SFT inference path the maintainer called out, not distillation.
- Update the inline example command(s) in scripts/inference/video_model_inference.py to pass the Cosmos negative-prompt file and otherwise align them with scripts/README.md.
- Add a focused unit test in tests/test_network.py asserting the produced sigma/timestep schedule matches the reference Karras values.
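The schedule described above can be written out in plain Python to show which reference values a unit test might assert. This is a sketch under assumptions: `karras_flow_schedule` is a hypothetical name, num_train_timesteps=1000 is taken from the review's sequence diagram, and the interpolation is the standard Karras rho-schedule formula, which is assumed (not confirmed here) to match the official Cosmos scheduler.

```python
def karras_flow_schedule(num_steps, num_train_timesteps=1000,
                         sigma_max=200.0, sigma_min=0.01, rho=7.0):
    """Hypothetical pure-Python reference for the Karras schedule with sigma/(1+sigma)."""
    max_inv_rho = sigma_max ** (1.0 / rho)
    min_inv_rho = sigma_min ** (1.0 / rho)
    sigmas = []
    for i in range(num_steps + 1):
        ramp = i / num_steps  # runs from 0 to 1 inclusive
        # Karras interpolation in sigma^(1/rho) space, from sigma_max down to sigma_min.
        sigma = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho
        # Reparameterize into [0, 1) as described in the PR text.
        sigmas.append(sigma / (1.0 + sigma))
    # int() truncates toward zero, mirroring the int64 timestep truncation.
    timesteps = [int(s * num_train_timesteps) for s in sigmas]
    # A terminal zero sigma is appended, as in the diff under review.
    return sigmas + [0.0], timesteps


sigmas, timesteps = karras_flow_schedule(35)
print(timesteps[0], timesteps[-1])  # 995 9
```

The endpoints follow directly from the parameters: 200/201 ≈ 0.995 truncates to timestep 995, and 0.01/1.01 ≈ 0.0099 truncates to timestep 9.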
Fixes #18