Skip to content

fix(losses): register buffers in GlobalMutualInformationLoss#8872

Open
AlexanderSanin wants to merge 2 commits into
Project-MONAI:devfrom
AlexanderSanin:fix/global-mutual-information-register-buffer
Open

fix(losses): register buffers in GlobalMutualInformationLoss#8872
AlexanderSanin wants to merge 2 commits into
Project-MONAI:devfrom
AlexanderSanin:fix/global-mutual-information-register-buffer

Conversation

@AlexanderSanin
Copy link
Copy Markdown
Contributor

Summary

  • GlobalMutualInformationLoss stored preterm and bin_centers as plain tensor attributes when kernel_type="gaussian", so calling loss.to("cuda") or loss.cuda() did not move them to the target device
  • Replace the plain assignments with register_buffer(..., persistent=False), consistent with the pattern already applied to LocalNormalizedCrossCorrelationLoss in fix: use register_buffer for kernel and kernel_vol in LocalNormalizedCrossCorrelationLoss #8818
  • The .to(img) calls in parzen_windowing_gaussian are retained for dtype coercion (e.g. float16 inference)

Test plan

  • python -m pytest tests/losses/image_dissimilarity/test_global_mutual_information_loss.py -v — all existing tests still pass
  • TestGlobalMutualInformationLossBuffers::test_gaussian_kernel_registers_bufferspreterm and bin_centers are in _buffers and have requires_grad=False
  • TestGlobalMutualInformationLossBuffers::test_bspline_kernel_has_no_gaussian_buffers — b-spline mode is unaffected
  • TestGlobalMutualInformationLossBuffers::test_gaussian_kernel_forward_correct — forward pass returns a scalar loss

Closes #8819

When kernel_type="gaussian", `preterm` and `bin_centers` were stored
as plain tensor attributes via simple assignment. This means they are
not registered in PyTorch's module buffer system, so calling
`loss.to("cuda")` or `loss.cuda()` does not move these tensors to the
target device. Each forward pass had to call `.to(img)` to patch the
device mismatch at runtime, which is both redundant and misleading.

Use `register_buffer(..., persistent=False)` so that both tensors are
properly tracked by the module and automatically move with `.to()` /
`.cuda()` / `.cpu()` calls, consistent with the pattern already used
by `LocalNormalizedCrossCorrelationLoss`.

The `.to(img)` calls in `parzen_windowing_gaussian` are retained for
dtype coercion (e.g. float16 inference).

Adds `TestGlobalMutualInformationLossBuffers` to verify buffer
registration and that b-spline mode does not create gaussian buffers.

Closes Project-MONAI#8819

Signed-off-by: Oleksandr Sanin <alexaaander.sanin@gmail.com>
@AlexanderSanin
Copy link
Copy Markdown
Contributor Author

Hey @ericspod @aymuos15. Could you, please, have a look at this?

@coderabbitai
Copy link
Copy Markdown
Contributor

coderabbitai Bot commented May 25, 2026

📝 Walkthrough

Walkthrough

GlobalMutualInformationLoss now registers preterm and bin_centers as non-persistent buffers before the kernel-type dispatch and populates them when kernel_type == "gaussian". Tests were added to verify buffer registration and properties for the gaussian kernel, absence for b-spline, that a gaussian forward returns a scalar tensor, and that the gaussian buffers move with the module to CUDA.

Estimated code review effort

🎯 2 (Simple) | ⏱️ ~8 minutes

🚥 Pre-merge checks | ✅ 5
✅ Passed checks (5 passed)
Check name Status Explanation
Title check ✅ Passed Title accurately summarizes the main change: registering buffers in GlobalMutualInformationLoss to fix device placement issues.
Description check ✅ Passed Description covers the problem, solution, and test plan clearly. However, the template checkbox section is incomplete—only one box is checked and others are left unchecked/unmarked.
Linked Issues check ✅ Passed Changes directly address #8819 requirements: register_buffer replaces plain attributes for preterm and bin_centers, persistent=False used, device movement works, existing tests pass, new tests verify buffer registration and device behavior.
Out of Scope Changes check ✅ Passed All changes are scoped to GlobalMutualInformationLoss buffer registration and corresponding test coverage; no unrelated modifications detected.
Docstring Coverage ✅ Passed Docstring coverage is 83.33% which is sufficient. The required threshold is 80.00%.

✏️ Tip: You can configure your own custom pre-merge checks in the settings.

✨ Finishing Touches
🧪 Generate unit tests (beta)
  • Create PR with unit tests

Thanks for using CodeRabbit! It's free for OSS, and your support helps us grow. If you like it, consider giving us a shout-out.

❤️ Share

Comment @coderabbitai help to get the list of available commands and usage tips.

Copy link
Copy Markdown
Contributor

@coderabbitai coderabbitai Bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actionable comments posted: 1

🧹 Nitpick comments (4)
tests/losses/image_dissimilarity/test_global_mutual_information_loss.py (3)

158-161: ⚡ Quick win

Add docstring per coding guidelines.

📝 Suggested docstring
 def test_bspline_kernel_has_no_gaussian_buffers(self):
+    """Verify b-spline kernel does not register gaussian-specific buffers."""
     loss = GlobalMutualInformationLoss(kernel_type="b-spline")
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@tests/losses/image_dissimilarity/test_global_mutual_information_loss.py`
around lines 158 - 161, The test function
test_bspline_kernel_has_no_gaussian_buffers is missing a docstring; add a short
descriptive docstring at the top of the function explaining that it verifies
GlobalMutualInformationLoss(kernel_type="b-spline") does not populate
Gaussian-specific buffers (specifically asserting "preterm" and "bin_centers"
are not in loss._buffers). Keep it concise and follow existing test docstring
style.

163-168: ⚡ Quick win

Add docstring per coding guidelines.

📝 Suggested docstring
 def test_gaussian_kernel_forward_correct(self):
+    """Verify gaussian kernel forward pass returns scalar loss tensor."""
     pred = torch.rand(2, 1, 8, 8, dtype=torch.float32)
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@tests/losses/image_dissimilarity/test_global_mutual_information_loss.py`
around lines 163 - 168, Add a docstring to the unit test function
test_gaussian_kernel_forward_correct that briefly describes what the test
verifies (e.g., that GlobalMutualInformationLoss with kernel_type="gaussian"
returns a scalar tensor and preserves shape), placing it directly under the def
line in that function; reference the function name
test_gaussian_kernel_forward_correct and the class/constructor
GlobalMutualInformationLoss(kernel_type="gaussian") so reviewers can locate and
confirm the new docstring.

149-156: ⚡ Quick win

Add docstring per coding guidelines.

Docstrings required for all test methods describing purpose and expectations. As per coding guidelines, "Docstrings should be present for all definition which describe each variable, return value, and raised exception in the appropriate section of the Google-style of docstrings."

📝 Suggested docstring
 def test_gaussian_kernel_registers_buffers(self):
+    """Verify gaussian kernel registers preterm and bin_centers as non-trainable buffers."""
     loss = GlobalMutualInformationLoss(kernel_type="gaussian")
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@tests/losses/image_dissimilarity/test_global_mutual_information_loss.py`
around lines 149 - 156, Add a Google-style docstring to the test method
test_gaussian_kernel_registers_buffers describing what is being tested (that
GlobalMutualInformationLoss with kernel_type="gaussian" registers preterm and
bin_centers as non-trainable buffers, that they move with .to(), and that
bin_centers has ndim == 3), including a short "Args" if needed and an "Expected"
or "Raises" note for the assertions; update the docstring inside the test
function definition (test_gaussian_kernel_registers_buffers) so it clearly
states the purpose and the expected conditions checked by the assertions.
monai/losses/image_dissimilarity.py (1)

236-237: 💤 Low value

Type annotations declared unconditionally but attributes are conditionally assigned.

These annotations are defined outside the gaussian conditional block, but the actual attributes are only created when kernel_type == "gaussian". While runtime behavior is correct (attributes only accessed in gaussian path), static type checkers may flag potential AttributeError for b-spline mode.

Consider either:

  • Moving annotations inside the conditional, or
  • Initializing to None and using Optional[torch.Tensor] type
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@monai/losses/image_dissimilarity.py` around lines 236 - 237, The attributes
self.preterm and self.bin_centers are only created when kernel_type ==
"gaussian" but currently annotated unconditionally; update their declarations to
reflect conditional creation by typing them as Optional[torch.Tensor] and
initialize them to None in the non-gaussian branch (or before the conditional)
so static type checkers know they may be absent, and ensure any gaussian-only
use sites (e.g., inside the gaussian branch) treat them as non-None; reference
the attributes self.preterm, self.bin_centers and the kernel_type == "gaussian"
conditional when making the change.
🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

Inline comments:
In `@tests/losses/image_dissimilarity/test_global_mutual_information_loss.py`:
- Around line 148-169: Add a test that verifies Gaussian kernel buffers actually
move with the module: instantiate
GlobalMutualInformationLoss(kernel_type="gaussian"), if
torch.cuda.is_available() call loss_cuda = loss.to("cuda") (or loss.cuda()),
then assert loss_cuda.preterm.device.type == "cuda" and
loss_cuda.bin_centers.device.type == "cuda", create CUDA tensors for pred and
target and run result = loss_cuda(pred, target) and assert result.device.type ==
"cuda"; reference the GlobalMutualInformationLoss class and its buffers preterm
and bin_centers and add this as a new test method (e.g.,
test_gaussian_buffers_move_with_module) alongside the existing tests.

---

Nitpick comments:
In `@monai/losses/image_dissimilarity.py`:
- Around line 236-237: The attributes self.preterm and self.bin_centers are only
created when kernel_type == "gaussian" but currently annotated unconditionally;
update their declarations to reflect conditional creation by typing them as
Optional[torch.Tensor] and initialize them to None in the non-gaussian branch
(or before the conditional) so static type checkers know they may be absent, and
ensure any gaussian-only use sites (e.g., inside the gaussian branch) treat them
as non-None; reference the attributes self.preterm, self.bin_centers and the
kernel_type == "gaussian" conditional when making the change.

In `@tests/losses/image_dissimilarity/test_global_mutual_information_loss.py`:
- Around line 158-161: The test function
test_bspline_kernel_has_no_gaussian_buffers is missing a docstring; add a short
descriptive docstring at the top of the function explaining that it verifies
GlobalMutualInformationLoss(kernel_type="b-spline") does not populate
Gaussian-specific buffers (specifically asserting "preterm" and "bin_centers"
are not in loss._buffers). Keep it concise and follow existing test docstring
style.
- Around line 163-168: Add a docstring to the unit test function
test_gaussian_kernel_forward_correct that briefly describes what the test
verifies (e.g., that GlobalMutualInformationLoss with kernel_type="gaussian"
returns a scalar tensor and preserves shape), placing it directly under the def
line in that function; reference the function name
test_gaussian_kernel_forward_correct and the class/constructor
GlobalMutualInformationLoss(kernel_type="gaussian") so reviewers can locate and
confirm the new docstring.
- Around line 149-156: Add a Google-style docstring to the test method
test_gaussian_kernel_registers_buffers describing what is being tested (that
GlobalMutualInformationLoss with kernel_type="gaussian" registers preterm and
bin_centers as non-trainable buffers, that they move with .to(), and that
bin_centers has ndim == 3), including a short "Args" if needed and an "Expected"
or "Raises" note for the assertions; update the docstring inside the test
function definition (test_gaussian_kernel_registers_buffers) so it clearly
states the purpose and the expected conditions checked by the assertions.
🪄 Autofix (Beta)

Fix all unresolved CodeRabbit comments on this PR:

  • Push a commit to this branch (recommended)
  • Create a new PR with the fixes

ℹ️ Review info
⚙️ Run configuration

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Pro

Run ID: 22b238ee-90d5-4b39-9f22-3df62dfea05d

📥 Commits

Reviewing files that changed from the base of the PR and between 0a8d945 and f20d3f6.

📒 Files selected for processing (2)
  • monai/losses/image_dissimilarity.py
  • tests/losses/image_dissimilarity/test_global_mutual_information_loss.py

…uffer fix

Use register_buffer("preterm", None) / register_buffer("bin_centers", None)
unconditionally so that both buffers are always present in _buffers (with None
for b-spline). This avoids a KeyError that occurred when plain instance
attribute assignment conflicted with a subsequent register_buffer call.

Also add docstrings to the new test methods and a device-movement test that
verifies buffers follow the module when .cuda() is called.

Signed-off-by: Oleksandr Sanin <alexaaander.sanin@gmail.com>
Copy link
Copy Markdown
Contributor

@coderabbitai coderabbitai Bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🧹 Nitpick comments (2)
tests/losses/image_dissimilarity/test_global_mutual_information_loss.py (2)

149-185: ⚡ Quick win

Use full Google-style docstrings for new test methods.

Current one-line docstrings don’t include the required sections (Args, Returns, Raises) from the repo guideline.

As per coding guidelines, "Docstrings should be present for all definition which describe each variable, return value, and raised exception in the appropriate section of the Google-style of docstrings."

🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@tests/losses/image_dissimilarity/test_global_mutual_information_loss.py`
around lines 149 - 185, The tests use one-line docstrings; update each test's
docstring (test_gaussian_kernel_registers_buffers,
test_bspline_kernel_has_no_gaussian_buffers,
test_gaussian_kernel_forward_correct, test_gaussian_buffers_move_with_module) to
full Google-style docstrings that include a short summary plus Args (describe
pred/target shapes or when the test constructs the loss), Returns (what the test
asserts, e.g., None or scalar loss), and Raises (any expected exceptions, if
none state "None"); keep the existing descriptive text as the summary and add
the three sections to meet the repo guideline.

149-163: ⚡ Quick win

Assert non-persistent buffer contract explicitly.

Please also verify preterm and bin_centers are excluded from state_dict() to lock in persistent=False behavior.

Proposed test additions
 def test_gaussian_kernel_registers_buffers(self):
     """preterm and bin_centers are registered as non-persistent buffers for gaussian kernel."""
     loss = GlobalMutualInformationLoss(kernel_type="gaussian")
     self.assertIn("preterm", loss._buffers)
     self.assertIn("bin_centers", loss._buffers)
     self.assertFalse(loss.preterm.requires_grad)
     self.assertFalse(loss.bin_centers.requires_grad)
     self.assertEqual(loss.bin_centers.ndim, 3)
+    state = loss.state_dict()
+    self.assertNotIn("preterm", state)
+    self.assertNotIn("bin_centers", state)

 def test_bspline_kernel_has_no_gaussian_buffers(self):
     """b-spline kernel does not register gaussian-specific buffers."""
     loss = GlobalMutualInformationLoss(kernel_type="b-spline")
     self.assertIsNone(loss.preterm)
     self.assertIsNone(loss.bin_centers)
+    state = loss.state_dict()
+    self.assertNotIn("preterm", state)
+    self.assertNotIn("bin_centers", state)
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@tests/losses/image_dissimilarity/test_global_mutual_information_loss.py`
around lines 149 - 163, Update the tests for GlobalMutualInformationLoss to
assert the non-persistent buffer contract by checking that gaussian-specific
buffers do not appear in the module state dict: in
test_gaussian_kernel_registers_buffers (for kernel_type="gaussian") after
asserting preterm and bin_centers exist and have correct properties, also call
loss.state_dict() and assert "preterm" and "bin_centers" are not keys;
similarly, in test_bspline_kernel_has_no_gaussian_buffers (for
kernel_type="b-spline") confirm state_dict() also does not contain those keys
(and that loss.preterm and loss.bin_centers remain None) so persistent=False
behavior is enforced.
🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

Nitpick comments:
In `@tests/losses/image_dissimilarity/test_global_mutual_information_loss.py`:
- Around line 149-185: The tests use one-line docstrings; update each test's
docstring (test_gaussian_kernel_registers_buffers,
test_bspline_kernel_has_no_gaussian_buffers,
test_gaussian_kernel_forward_correct, test_gaussian_buffers_move_with_module) to
full Google-style docstrings that include a short summary plus Args (describe
pred/target shapes or when the test constructs the loss), Returns (what the test
asserts, e.g., None or scalar loss), and Raises (any expected exceptions, if
none state "None"); keep the existing descriptive text as the summary and add
the three sections to meet the repo guideline.
- Around line 149-163: Update the tests for GlobalMutualInformationLoss to
assert the non-persistent buffer contract by checking that gaussian-specific
buffers do not appear in the module state dict: in
test_gaussian_kernel_registers_buffers (for kernel_type="gaussian") after
asserting preterm and bin_centers exist and have correct properties, also call
loss.state_dict() and assert "preterm" and "bin_centers" are not keys;
similarly, in test_bspline_kernel_has_no_gaussian_buffers (for
kernel_type="b-spline") confirm state_dict() also does not contain those keys
(and that loss.preterm and loss.bin_centers remain None) so persistent=False
behavior is enforced.

ℹ️ Review info
⚙️ Run configuration

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Pro

Run ID: 13a2592f-5cd6-4641-a30a-def8e3d5df80

📥 Commits

Reviewing files that changed from the base of the PR and between f20d3f6 and 20c702a.

📒 Files selected for processing (2)
  • monai/losses/image_dissimilarity.py
  • tests/losses/image_dissimilarity/test_global_mutual_information_loss.py

@aymuos15
Copy link
Copy Markdown
Contributor

Happy to go through this. Any idea why the CI is failing?

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

[Bug] LocalNormalizedCrossCorrelationLoss: kernel not registered as buffer — silent gradient tracking + wrong device placement

2 participants