Refine static NVFP4 MSE calibration#1536
Conversation
Signed-off-by: realAsma <akuriparambi@nvidia.com> minor Signed-off-by: realAsma <akuriparambi@nvidia.com> Cache NVFP4 MSE amax in fp32 Signed-off-by: realAsma <akuriparambi@nvidia.com>
|
Note Reviews pausedIt looks like this branch is under active development. To avoid overwhelming you with review comments due to an influx of new commits, CodeRabbit has automatically paused this review. You can configure this behavior by changing the Use the following commands to manage reviews:
Use the checkboxes below for quick actions:
📝 WalkthroughWalkthroughThis PR refactors NVFP4 MSE calibration to cache final per-block amax immediately as float32 in a one-shot cycle, centralizes max-stat collection and per-weight MSE calibrator dispatch, tightens NVFP4-static promotion and global-amax sync, updates TensorQuantizer amax buffer handling, and expands tests for dtype and one-shot semantics. ChangesNVFP4 MSE Calibration and Static Promotion Refactor
Estimated code review effort🎯 4 (Complex) | ⏱️ ~60 minutes Suggested reviewers
🚥 Pre-merge checks | ✅ 6✅ Passed checks (6 passed)
✏️ Tip: You can configure your own custom pre-merge checks in the settings. ✨ Finishing Touches🧪 Generate unit tests (beta)
Comment |
|
There was a problem hiding this comment.
Warning
CodeRabbit couldn't request changes on this pull request because it doesn't have sufficient GitHub permissions.
Please grant CodeRabbit Pull requests: Read and write permission and re-run the review.
Actionable comments posted: 3
🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
Inline comments:
In `@modelopt/torch/quantization/model_calib.py`:
- Around line 214-221: The stats lifecycle in _run_and_load_max_stats is not
guarded: call enable_stats_collection(model) then run the forward path (either
weight_only_quantize(model) or forward_loop(model)) inside a try block and call
finish_stats_collection(model) in a finally block so finish_stats_collection
always executes even if the forward path raises; re-raise any caught exception
after the finally to preserve behavior. Reference functions:
_run_and_load_max_stats, enable_stats_collection, weight_only_quantize,
forward_loop, finish_stats_collection.
In `@tests/gpu/torch/quantization/test_nvfp4_fp8_sweep_kernel.py`:
- Line 295: The local import of TensorQuantizer inside the
test_mse_calibrate_end_to_end function should be moved to module scope: remove
the in-function import and add "from modelopt.torch.quantization.nn import
TensorQuantizer" to the top of the test file with the other imports so import
failures surface at collection time; update any references in
test_mse_calibrate_end_to_end to use the now-module-level TensorQuantizer and
ensure there is no justification comment left for an inside-function import.
In `@tests/unit/torch/quantization/test_mse_calibrator.py`:
- Around line 686-700: Move the in-test imports of
_promote_nvfp4_static_quantizers_with_global_amax_sync out of the individual
test methods and place them in the module-level import block (i.e., import
_promote_nvfp4_static_quantizers_with_global_amax_sync from
modelopt.torch.quantization.model_calib at the top of the test file) so tests
follow the guideline that imports belong at file scope; only keep them inside a
test if there is a documented circular/optional dependency reason.
🪄 Autofix (Beta)
Fix all unresolved CodeRabbit comments on this PR:
- Push a commit to this branch (recommended)
- Create a new PR with the fixes
ℹ️ Review info
⚙️ Run configuration
Configuration used: Path: .coderabbit.yaml
Review profile: CHILL
Plan: Enterprise
Run ID: 39296e29-e4e1-4049-9048-513405b3ee9d
📒 Files selected for processing (12)
examples/llm_ptq/cast_mxfp4_to_nvfp4.pymodelopt/torch/quantization/calib/mse.pymodelopt/torch/quantization/config.pymodelopt/torch/quantization/model_calib.pymodelopt/torch/quantization/nn/modules/tensor_quantizer.pymodelopt/torch/quantization/utils/core_utils.pymodelopt_recipes/configs/ptq/presets/model/nvfp4_w4a4_weight_mse_fp8_sweep.yamltests/gpu/torch/quantization/test_gptq.pytests/gpu/torch/quantization/test_nvfp4_fp8_sweep_kernel.pytests/gpu/torch/quantization/test_nvfp4_static_quantizer_cuda.pytests/unit/torch/quantization/plugins/test_fused_experts.pytests/unit/torch/quantization/test_mse_calibrator.py
💤 Files with no reviewable changes (2)
- tests/gpu/torch/quantization/test_gptq.py
- modelopt/torch/quantization/utils/core_utils.py
Codecov Report✅ All modified and coverable lines are covered by tests.
Additional details and impacted files@@ Coverage Diff @@
## main #1536 +/- ##
==========================================
- Coverage 76.75% 69.13% -7.62%
==========================================
Files 476 477 +1
Lines 51811 53007 +1196
==========================================
- Hits 39767 36649 -3118
- Misses 12044 16358 +4314
Flags with carried forward coverage won't be shown. Click here to find out more. ☔ View full report in Codecov by Sentry. 🚀 New features to boost your workflow:
|
Signed-off-by: realAsma <akuriparambi@nvidia.com>
Regarding CodeRabbit’s stats-lifecycle suggestion at #1536 (comment): I am going to leave this as-is. The helper currently has a simple, linear stats lifecycle, and adding a |
There was a problem hiding this comment.
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (1)
modelopt/torch/quantization/model_calib.py (1)
523-536:⚠️ Potential issue | 🟠 Major | ⚡ Quick winRestore the original calibrator after the temporary MSE pass.
This loop leaves
weight_quantizer._calibratorpointing atcal, then immediately resets it. For the baseMseCalibrator,reset()clears_initial_amax, so a later calibration pass on the same model can hit a broken calibrator state instead of the original one.Suggested fix
if cal is None: continue - weight_quantizer._calibrator = cal - _run_and_load_max_stats(weight_quantizer, lambda q: q(weight)) - if hasattr(cal, "reset"): - cal.reset() + original_calibrator = weight_quantizer._calibrator + weight_quantizer._calibrator = cal + try: + _run_and_load_max_stats(weight_quantizer, lambda q: q(weight)) + finally: + weight_quantizer._calibrator = original_calibrator + if hasattr(cal, "reset"): + cal.reset() pbar.update(1)🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the rest with a brief reason, keep changes minimal, and validate. In `@modelopt/torch/quantization/model_calib.py` around lines 523 - 536, The loop temporarily replaces weight_quantizer._calibrator with a new MSE calibrator from _make_weight_mse_calibrator, runs _run_and_load_max_stats, then calls cal.reset(), but never restores the original calibrator; save the original (orig = weight_quantizer._calibrator) before assigning the temporary cal, run _run_and_load_max_stats with the temp calibrator, call cal.reset() if needed, and finally restore weight_quantizer._calibrator = orig so the original calibrator state is preserved for subsequent passes (references: weight_quantizer._calibrator, _make_weight_mse_calibrator, _run_and_load_max_stats, cal.reset()).
🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
Outside diff comments:
In `@modelopt/torch/quantization/model_calib.py`:
- Around line 523-536: The loop temporarily replaces
weight_quantizer._calibrator with a new MSE calibrator from
_make_weight_mse_calibrator, runs _run_and_load_max_stats, then calls
cal.reset(), but never restores the original calibrator; save the original (orig
= weight_quantizer._calibrator) before assigning the temporary cal, run
_run_and_load_max_stats with the temp calibrator, call cal.reset() if needed,
and finally restore weight_quantizer._calibrator = orig so the original
calibrator state is preserved for subsequent passes (references:
weight_quantizer._calibrator, _make_weight_mse_calibrator,
_run_and_load_max_stats, cal.reset()).
ℹ️ Review info
⚙️ Run configuration
Configuration used: Path: .coderabbit.yaml
Review profile: CHILL
Plan: Enterprise
Run ID: 38c07283-0685-4897-9033-c82711e6c56c
📒 Files selected for processing (5)
modelopt/torch/quantization/calib/mse.pymodelopt/torch/quantization/model_calib.pymodelopt/torch/quantization/nn/modules/tensor_quantizer.pytests/gpu/torch/quantization/test_nvfp4_fp8_sweep_kernel.pytests/unit/torch/quantization/test_mse_calibrator.py
Auto-replying to realAsma review comments:
|
/claude review |
cjluo-nv
left a comment
There was a problem hiding this comment.
Bot review — DM the bot to share feedback.
The actual diff goes well beyond the stated "sync on CUDA + remove _wait_for_cuda_amax" scope: it adds a new public-ish config field apply_mse_nvfp_static_only, removes promote_nvfp4_static_quantizers from utils/core_utils.py, restructures mse_calibrate, narrows _bootstrap_uncalibrated_weight_quantizers to NVFP4-static only (renamed accordingly), and changes load_calib_amax so that an existing _amax buffer is always replaced (inheriting calib_amax's dtype) instead of copy_'d in place (preserving the buffer's dtype). That last change is model-wide, not NVFP4-specific, and the PR body doesn't call it out — worth a human pass to confirm no downstream caller relies on the buffer dtype being preserved (e.g. an fp16 _amax registered earlier silently becomes fp32). Tests look reasonable: new fp32-amax preservation test, fast-path-cache one-shot test, and a static-only-skip test for the registry. Couple of smaller items also worth eyeballing:
_promote_nvfp4_static_quantizers_with_global_amax_syncis now called insidemax_calibrate, so the bootstrap (formerly only triggered frommse_calibrate) now runs for everymax_calibrateconsumer (gptq, awq postprocess, smoothquant, etc.). Intended? The_check_moe_calibration_completepath runs after promotion — fine since the check looks at_amaxon input quantizers too, but worth confirming nothing now warns where it didn't before.local_hessian_calibratelost its explicit_sync_grouped_weight_global_amax(model)call. It's now subsumed bymax_calibrate→ promotion path, but local_hessian also has its own per-module promotion later. Double-promotion is a no-op (covered byfrom_tensor_quantizer), just noting the redundancy._compute_candidate_amaxnow multiplies bytorch.ones_like(self._initial_amax, dtype=torch.float32)— theones_likeis purely for broadcasting and a plain(self._global_amax.to(torch.float32) * candidates).expand_as(self._initial_amax)(or simpler, just relying on broadcast) would be cheaper; minor.
There was a problem hiding this comment.
Claude review passed — no blocking issues found. LGTM
Findings: 0 CRITICAL, 0 IMPORTANT, 2 SUGGESTIONS
The refactor is internally consistent:
NVFP4MSECalibratorbecoming one-shot is fine — it's only constructed for weights in_make_weight_mse_calibrator, never for activations.- Bootstrap scope narrowing (
_bootstrap_uncalibrated_static_weight_quantizers) is intentional and reflected in the test rename; all in-tree MSE recipes are NVFP4-only. load_calib_amaxchange to preserve calibrator dtype on the buffer is covered by the newtest_load_calib_amax_preserves_fp32_result_dtyperegression test._promote_nvfp4_static_quantizers_with_global_amax_syncis idempotent (from_tensor_quantizeris a no-op for already-promoted modules) andpreprocess_linear_fusioncorrectly unifies groupedglobal_amaxafter promotion.
Two non-blocking suggestions left as inline comments:
- Likely-unnecessary
torch.cuda.synchronizein_run_reference_collect. load_calib_amaxcould keep the new-buffer branch going throughregister_bufferfor symmetry withamax.setter.
Regarding #1536 (review): leaving this unchanged per branch-owner review. The current |
Regarding the NVFP4 reference MSE sync comment: updated the source comment to clarify that the CPU-GPU sync is there to prevent reference MSE calibration for another weight from running in parallel. The comment now also documents the retained |
Correction to my previous note about the NVFP4 reference MSE sync comment: I updated the memory wording to express |
Signed-off-by: realAsma <akuriparambi@nvidia.com>
Regarding #1536 (review): Thanks for the careful pass. The broader scope is intentional: the fp32 amax preservation, static-NVFP4-only MSE gate, max-calibration promotion, and grouped global-amax sync are tied together so static NVFP4 weights are finalized consistently before downstream calibrators/export. I agree the PR body should make that scope clearer, especially the |
|
Is fp32 MSE scale preserved after save/restore? |
|
/claude review |
There was a problem hiding this comment.
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (1)
tests/unit/torch/quantization/test_mse_calibrator.py (1)
560-563:⚠️ Potential issue | 🟠 Major | ⚡ Quick winMove helper imports to module scope in tests.
These imports are inside a test helper without a circular/optional/heavy-import justification, which violates the test conventions and can defer import failures to runtime.
Proposed fix
import torch +import modelopt.torch.quantization as mtq from modelopt.torch.quantization import calib from modelopt.torch.quantization.config import QuantizerAttributeConfig from modelopt.torch.quantization.model_calib import ( + mse_calibrate, _make_weight_mse_calibrator, _promote_nvfp4_static_quantizers_with_global_amax_sync, ) from modelopt.torch.quantization.nn import NVFP4StaticQuantizer, TensorQuantizer +from modelopt.torch.quantization.nn.modules.tensor_quantizer import register_quant_backend from modelopt.torch.quantization.utils import enable_fake_quant @@ def _quantize_and_calibrate(self, backend_name, fp8_scale_sweep=True): """Quantize a small Linear with the given backend and run mse_calibrate.""" - import modelopt.torch.quantization as mtq - from modelopt.torch.quantization.model_calib import mse_calibrate - from modelopt.torch.quantization.nn.modules.tensor_quantizer import register_quant_backend - register_quant_backend(backend_name, lambda x, tq: x)As per coding guidelines, “tests/**/*.py: Imports inside functions or test methods without explicit justification... Imports belong at the top of the file...”.
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the rest with a brief reason, keep changes minimal, and validate. In `@tests/unit/torch/quantization/test_mse_calibrator.py` around lines 560 - 563, The test currently performs local imports of modelopt.torch.quantization as mtq, mse_calibrate (from modelopt.torch.quantization.model_calib) and register_quant_backend (from modelopt.torch.quantization.nn.modules.tensor_quantizer) inside a helper; move these import statements to module scope at the top of the test file and remove the in-function imports so import errors surface at collection time and follow the test import convention.
🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
Outside diff comments:
In `@tests/unit/torch/quantization/test_mse_calibrator.py`:
- Around line 560-563: The test currently performs local imports of
modelopt.torch.quantization as mtq, mse_calibrate (from
modelopt.torch.quantization.model_calib) and register_quant_backend (from
modelopt.torch.quantization.nn.modules.tensor_quantizer) inside a helper; move
these import statements to module scope at the top of the test file and remove
the in-function imports so import errors surface at collection time and follow
the test import convention.
ℹ️ Review info
⚙️ Run configuration
Configuration used: Path: .coderabbit.yaml
Review profile: CHILL
Plan: Enterprise
Run ID: d9b4f24a-c6cb-42a3-84e5-d1c8e44159db
📒 Files selected for processing (3)
modelopt/torch/quantization/config.pymodelopt/torch/quantization/model_calib.pytests/unit/torch/quantization/test_mse_calibrator.py
There was a problem hiding this comment.
Claude review summary
Findings: CRITICAL: 0, IMPORTANT: 1, SUGGESTION: 5
Most impactful
- [IMPORTANT Compatibility]
_bootstrap_uncalibrated_static_weight_quantizerswas narrowed to NVFP4-static only and the unconditional bootstrap step at the top ofmse_calibratewas removed. Non-NVFP4 MoE workloads (e.g. INT8/FP8 weight quantization) with dead experts now leave those quantizers with_amax = Noneaftermse_calibrate—_make_weight_mse_calibratorsilently skips them, and the renamed test no longer covers the non-NVFP4 case. This is a real regression for INT8/FP8 MoE MSE flows.
Suggestions
load_calib_amaxshape check usesassert(stripped under-O) and replaces the buffer instead of in-place copy; both worth a comment /raise.promote_nvfp4_static_quantizerswas importable frommodelopt.torch.quantization.utils; consider a deprecation shim or a CHANGELOG note since the PR claims backward compatible.- Stale bootstrap warning text (no longer mentions "static NVFP4").
lambda q: q(weight)captures the loop variable; safe today but cheap to harden aslambda q, w=weight: q(w).- PR description advertises an
apply_mse_nvfp_static_onlyfield onMseCalibConfigthat isn't present in the diff — please reconcile.
Risk
The refactor itself (one-shot caching, fp32 amax preservation, centralized max-stat collection, NVFP4 promotion inside max_calibrate) is a clean cleanup with good test coverage for the NVFP4 path. The main risk is the silent loss of dead-expert bootstrap for non-NVFP4 MoE configurations going through mse_calibrate; addressing that (or confirming it's intended and out of scope) would clear the blocker.
cjluo-nv
left a comment
There was a problem hiding this comment.
Bot review — DM the bot to share feedback.
A lot of progress on this re-review — the fp32-amax preservation, test-import scope, FP8-sweep config simplification (collapsing apply_mse_nvfp_static_only into fp8_scale_sweep), and the load_calib_amax simplification have all landed. Below are the points still worth a human pass before approval; ones the author has responded to are marked with 💬.
-
💬 Non-NVFP4 dead-expert MoE MSE regression —
_bootstrap_uncalibrated_weight_quantizerswas narrowed to_bootstrap_uncalibrated_static_weight_quantizers(NVFP4-only) and the unconditional bootstrap step at the top ofmse_calibratewas removed. Author's response (to cjluo) confirms scope expansion is intentional and tied to NVFP4 finalization, but doesn't directly address what happens to INT8/FP8 MoE configs with dead experts undermse: those weight quantizers will now stay at_amax=None,_make_weight_mse_calibratorsilently returnsNone, and the renamedtest_max_calibrate_populates_dead_static_nvfp4_expert_quantizersno longer exercises the non-NVFP4 path. Worth confirming this is intended (and not an unrelated MoE recipe regression). -
💬
load_calib_amaxbuffer identity change is model-wide — cjluo flagged that replacingself._buffers["_amax"]instead ofself._amax.data.copy_()changes the identity of the buffer for every quantizer, not just NVFP4. Author kept the new behavior to preserve fp32 dtype and added a regression test for the dtype, but the identity change for non-NVFP4 callers (anyone holding a stale_amaxreference acrossload_calib_amax) hasn't been audited — the PR body still doesn't call this out as a behavior change. A second pair of eyes on downstream callers would be reassuring. -
promote_nvfp4_static_quantizersremoved fromutils/core_utils.pywith no deprecation shim — the PR self-marks "backward compatible", butfrom modelopt.torch.quantization.utils import promote_nvfp4_static_quantizers(used by at leasttests/gpu/torch/quantization/test_gptq.pybefore this PR) will nowImportErrorfor any external caller. Either keep a thin deprecation shim that calls_promote_nvfp4_static_quantizers_with_global_amax_sync, or note the removal in CHANGELOG and uncheck the back-compat box. -
In-function imports in test helper —
tests/unit/torch/quantization/test_mse_calibrator.py::TestRegisterFP8SweepCalibrator._quantize_and_calibratestill doesimport modelopt.torch.quantization as mtq/mse_calibrate/register_quant_backendinside the method. Earlier in-function imports in the same file and intest_nvfp4_fp8_sweep_kernel.pywere moved to module scope; this one was missed. Style-only, not a blocker. -
Minor: the
lambda q: q(weight)inmse_calibratecaptures the loopweight; safe today because_run_and_load_max_statsinvokes synchronously, butlambda q, w=weight: q(w)would harden it against any future deferral. Thegetattr(self, "_best_amax_fast", None)inNVFP4MSECalibrator.compute_amaxis also redundant given__init__sets the attribute — a plainif self._best_amax_fast is Nonewould do.
Additional comments (outside the PR diff):
tests/unit/torch/quantization/test_mse_calibrator.py:562— > Bot comment.
Style nit: mtq, mse_calibrate, and register_quant_backend are still imported inside this method. Earlier in-function imports in this file and in test_nvfp4_fp8_sweep_kernel.py were moved to module scope in e980773b; this helper looks like it was missed. Not a blocker.
Signed-off-by: realAsma <akuriparambi@nvidia.com>
Pushed Addressed the bot-tagged feedback from the latest review:
Validation before push:
|
|
/claude review |
Signed-off-by: realAsma <akuriparambi@nvidia.com>
|
/claude review |
Signed-off-by: realAsma <akuriparambi@nvidia.com>
|
/claude review |
There was a problem hiding this comment.
Claude review passed — no blocking issues found. LGTM
Reviewed the NVFP4 MSE calibration refactor end-to-end (one-shot collect/cache, max-stat helper centralization, static NVFP4 promotion moved into max_calibrate, fp32-preserving load_calib_amax). Algorithm flow is correct: NVFP4MSECalibrator caches the final per-block amax in fp32 after either the Triton fast path or the reference 126-step sweep; multi-collect-after-final is now an explicit RuntimeError. The mse_calibrate dispatch (registry → ModelOpt static NVFP4 → MseCalibrator fallback) lines up with the docstring, and the local_hessian_calibrate path still calls collect once per weight so the new one-shot semantics don't regress it.
Posted three SUGGESTION-level comments only (none blocking):
- load_calib_amax docstring still says "updates" but now replaces the buffer; direct
_buffers["_amax"]access is inconsistent with the delattr+register_buffer pattern in export/quant_utils.py. - _compute_candidate_amax has a redundant ones_like multiplication that can be trimmed.
- _uses_modelopt_fp8_weight_scales relies on an implicit "max_calibrate already promoted" precondition before reading weight_quantizer.global_amax — making that explicit would harden the helper.
Risk assessment: low. Core changes are well-tested (fp32 dtype preservation, one-shot reset semantics, dead-expert static-NVFP4 bootstrap, grouped global-amax sync) and the behavior change is intentional and documented in the description.
| if hasattr(self, "_amax"): | ||
| if self._amax.shape != calib_amax.shape: | ||
| raise ValueError("Changing shape when loading calibration amax is not allowed.") | ||
| device = self._amax.device if hasattr(self, "_amax") else calib_amax.device | ||
| # Replace instead of copy_ so the stored amax keeps the calibrator result dtype. | ||
| self._buffers["_amax"] = calib_amax.clone().detach().to(device=device) |
There was a problem hiding this comment.
[SUGGESTION] Two small notes on this dtype-preserving rewrite:
- The docstring above (line 611) still says "Updates the amax buffer", but the new behavior replaces the buffer (so any external code holding
q._amaxdirectly will keep a stale tensor reference). Worth a one-line clarification. - Direct
self._buffers["_amax"] = ...bypassesregister_buffer, which is inconsistent with thedelattr+register_bufferpattern inmodelopt/torch/export/quant_utils.py:_set_amax_from_tensor. Functionally equivalent in PyTorch today, but if you want to keep the codebase using a single replacement idiom, callingdelattr(self, "_amax")(when it exists) and thenself.register_buffer("_amax", calib_amax.clone().detach().to(device=device))would match the existing pattern.
| initial_amax = weight_quantizer._amax.clone().detach() | ||
| axis = weight_quantizer._calibrator._axis | ||
| quant_func = partial(_mse_quant_func, quantizer=weight_quantizer) | ||
|
|
||
| if fp8_scale_sweep: | ||
| backend: str | None = getattr(weight_quantizer, "backend", None) | ||
| backend_factory = ( | ||
| _FP8_SWEEP_CALIBRATOR_REGISTRY.get(backend) if backend is not None else None | ||
| ) | ||
| if backend is not None and backend_factory is not None: | ||
| return backend_factory(initial_amax, axis, quant_func) | ||
| elif _uses_modelopt_fp8_weight_scales(weight_quantizer): | ||
| return NVFP4MSECalibrator( | ||
| amax=initial_amax, | ||
| axis=axis, | ||
| global_amax=weight_quantizer.global_amax, | ||
| quant_func=quant_func, | ||
| ) |
There was a problem hiding this comment.
[SUGGESTION] _uses_modelopt_fp8_weight_scales only checks is_nvfp4_static — it doesn't require the quantizer to already be promoted to NVFP4StaticQuantizer. Reading weight_quantizer.global_amax here will AttributeError for a base TensorQuantizer that happens to be is_nvfp4_static but hasn't been promoted yet.
In the current call graph this is safe because mse_calibrate runs max_calibrate first and max_calibrate now promotes via _promote_nvfp4_static_quantizers_with_global_amax_sync. But the precondition is implicit. Either an assert isinstance(weight_quantizer, NVFP4StaticQuantizer) here, or having _uses_modelopt_fp8_weight_scales also check isinstance(weight_quantizer, NVFP4StaticQuantizer), would make the contract explicit and produce a friendlier error if the helper is reused elsewhere.
Signed-off-by: realAsma <akuriparambi@nvidia.com>
Signed-off-by: realAsma <akuriparambi@nvidia.com>
cjluo-nv
left a comment
There was a problem hiding this comment.
Bot review — DM the bot to share feedback.
Re-review summary: most prior concerns are addressed, but one model-wide behavior change still merits human sign-off.
Addressed since the previous round:
- 💬 Non-NVFP4 dead-expert MoE bootstrap (cjluo) —
e3da774awidens_bootstrap_uncalibrated_weight_quantizersback to all quantizers and runs it insidemax_calibrate(defaultbootstrap_weight_quantizers=True), with a newtest_max_calibrate_bootstraps_non_nvfp4_dead_weight_quantizertest. Looks resolved. - 💬
promote_nvfp4_static_quantizersremoval (cjluo) —e3da774arestores it inutils/core_utils.pyas the public helper;from modelopt.torch.quantization.utils import promote_nvfp4_static_quantizersworks again via the star re-export. - 💬
load_calib_amaxassert→raise ValueErrorand new-buffer comment (claude) — applied. - 💬 In-method test imports (CodeRabbit) — moved to module scope in
test_mse_calibrator.pyandtest_nvfp4_fp8_sweep_kernel.py. - 💬
fp8_scale_sweep+apply_mse_nvfp_static_onlytwo-flag confusion (Fridah-nv) —419657d2collapses to a singlefp8_scale_sweepswitch, with dispatch handled in_make_weight_mse_calibratorvia_uses_modelopt_fp8_weight_scalesand the registered backend factory. - Lambda capture of loop
weightinmse_calibrate(claude) — implicitly fixed by switching topartial(_collect_weight_stats, weight=weight).
Still warrants a human pass before approval:
- 💬
load_calib_amaxnow replacesself._buffers["_amax"]instead ofcopy_'ing in-place (cjluo). Author kept the new behavior to preserve fp32 dtype and added an inline comment + a regression test, but this is a model-wide identity change for every quantizer (not just NVFP4). Any caller holding a staleq._amaxreference acrossload_calib_amax()will now silently see the pre-load tensor. The PR body still doesn't call this out as a behavior change. Quick audit of in-tree callers ofload_calib_amax(export paths, calibration restore,amax.setter) would be reassuring before sign-off. _run_and_load_max_statshas notry/finallyaroundenable_stats_collection/finish_stats_collection(CodeRabbit). Author explicitly declined; reasonable given the linear lifecycle, but worth a final sanity check that no new caller (e.g.max_calibrateconsumers) can leave quantizers stuck in_if_calib=Trueafter a forward exception.
No design-review concerns: this PR refactors existing calibration code rather than introducing a new subsystem. Test coverage looks solid (fp32-amax preservation, one-shot collect/reset, dead-expert bootstrap for NVFP4 and non-NVFP4, registered FP8-sweep dispatch, grouped global-amax sync).
Signed-off-by: realAsma <akuriparambi@nvidia.com>
Addressing the latest modelopt review note about the model-wide
The fp32 preservation requirement is now isolated to Focused validation passed: |
Signed-off-by: realAsma <akuriparambi@nvidia.com>
Signed-off-by: realAsma <akuriparambi@nvidia.com>
What does this PR do?
Type of change: Bug fix
Refines static NVFP4 MSE calibration so FP8-scale sweep calibration, max-calibration finalization, and export-facing static quantizer state stay consistent.
Main changes:
fp8_scale_sweep. Registered custom backends can provide an FP8 sweep calibrator, ModelOpt static NVFP4 weights useNVFP4MSECalibrator, and other weight quantizers keep the regular multiplier MSE search.NVFP4MSECalibratornow caches the final per-block amax astorch.float32for both the Triton fast path and reference sweep, andTensorQuantizer.load_calib_amax()preserves the loaded amax dtype after validating shape.NVFP4MSECalibratorone-shot between resets.collect()computes and caches the final per-block amax immediately, clears the large reference loss accumulator, andcompute_amax()returns the cached result.model_calib.py, and move static NVFP4 finalization intomax_calibrate(). Finalization bootstraps missing weight amax values for skipped/dead experts, promotes eligible quantizers toNVFP4StaticQuantizer, and synchronizes grouped Q/K/V and gate/up global amax values.promote_nvfp4_static_quantizers()as the public promotion helper, with the model-calibration path delegating to it before grouped global-amax sync.Usage
For static NVFP4 weight-MSE calibration, use the NVFP4 W4A4 FP8-sweep preset or set the MSE algorithm field directly:
fp8_scale_sweepapplies to ModelOpt static NVFP4 weight quantizers and registered custom backends with FP8 sweep support. Other weight quantizers use the regular multiplier search.Testing
Focused validation used while developing this PR included:
GitHub CI is also running the standard unit, regression, GPU, example, docs, DCO, and code-quality workflows for the PR.
Before your PR is "Ready for review"
Make sure you read and follow Contributor guidelines and your commits are signed (
git commit -s -S).Make sure you read and follow the Security Best Practices (e.g. avoiding hardcoded
trust_remote_code=True,torch.load(..., weights_only=False),pickle, etc.).CONTRIBUTING.md: N/AAdditional Information
N/A