Register GlobalMutualInformationLoss bin_centers as buffer#8869
Register GlobalMutualInformationLoss bin_centers as buffer#8869ugbotueferhire wants to merge 1 commit into
Conversation
Signed-off-by: Anthonyushie <anthonytwan75official@gmail.com>
📝 WalkthroughWalkthrough
Estimated code review effort🎯 1 (Trivial) | ⏱️ ~3 minutes 🚥 Pre-merge checks | ✅ 4 | ❌ 1❌ Failed checks (1 warning)
✅ Passed checks (4 passed)
✏️ Tip: You can configure your own custom pre-merge checks in the settings. ✨ Finishing Touches🧪 Generate unit tests (beta)
Thanks for using CodeRabbit! It's free for OSS, and your support helps us grow. If you like it, consider giving us a shout-out. Comment |
There was a problem hiding this comment.
Actionable comments posted: 2
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (1)
monai/losses/image_dissimilarity.py (1)
202-227: 🛠️ Refactor suggestion | 🟠 Major | ⚡ Quick winDocument the
bin_centersbuffer.The
__init__docstring should mentionbin_centersas registered module state. Add a note explaining it holds bin centers for Gaussian kernel Parzen windowing (shape:(1, 1, num_bins)for broadcasting). As per coding guidelines, docstrings should describe all variables and state.🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the rest with a brief reason, keep changes minimal, and validate. In `@monai/losses/image_dissimilarity.py` around lines 202 - 227, Update the __init__ docstring to document the registered buffer bin_centers: state that bin_centers is stored as a module buffer (registered via self.register_buffer) containing the Gaussian-Parzen window bin centers used for intensity Parzen windowing, its purpose (used by the Gaussian kernel to compute soft histogram/probability per bin), and its expected shape for broadcasting (1, 1, num_bins). Mention it is created in the constructor alongside other args (kernel_type, num_bins, sigma_ratio, etc.) so readers know it is part of the module state.
🧹 Nitpick comments (1)
tests/losses/image_dissimilarity/test_global_mutual_information_loss.py (1)
119-127: ⚡ Quick winAdd docstring and consider shape verification.
The test method lacks a docstring explaining what aspects of buffer registration it verifies. Also consider checking
bin_centers.shapeequals(1, 1, 16)to confirm the shape expansion. As per coding guidelines, docstrings should be present for all definitions.📋 Suggested improvements
def test_gaussian_bin_centers_registered_buffer(self): + """Verify bin_centers is a registered buffer with correct dtype/device behavior.""" loss = GlobalMutualInformationLoss(kernel_type="gaussian", num_bins=16) self.assertIn("bin_centers", dict(loss.named_buffers())) self.assertFalse(loss.bin_centers.requires_grad) + self.assertEqual(loss.bin_centers.shape, (1, 1, 16)) loss = loss.to(dtype=torch.float64)🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the rest with a brief reason, keep changes minimal, and validate. In `@tests/losses/image_dissimilarity/test_global_mutual_information_loss.py` around lines 119 - 127, Add a docstring to the test method test_gaussian_bin_centers_registered_buffer describing that it verifies buffer registration, dtype preservation on .to(), and expected shape; then add an assertion that loss.bin_centers.shape == (1, 1, 16) to verify the expanded shape, and keep the existing checks for presence in named_buffers(), requires_grad False, and dtype after loss.to(dtype=torch.float64).
🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
Inline comments:
In `@monai/losses/image_dissimilarity.py`:
- Around line 236-239: The type annotation for self.bin_centers is declared
unconditionally but the buffer is only created when self.kernel_type ==
"gaussian", so move the annotation inside that conditional: remove the top-level
"self.bin_centers: torch.Tensor" and add the annotation immediately before or
alongside the register_buffer call within the if block that sets self.preterm
and calls self.register_buffer("bin_centers", ...); this ensures
self.bin_centers is only typed/defined when the gaussian branch executes (refer
to the conditional using self.kernel_type, the attribute self.preterm, and the
register_buffer call for locating the code).
In `@tests/losses/image_dissimilarity/test_global_mutual_information_loss.py`:
- Around line 119-127: Update the test_gaussian_bin_centers_registered_buffer to
also assert that the registered buffer moves devices when the module is moved:
create a CUDA device if torch.cuda.is_available(), call loss = loss.to(device),
then assert loss.bin_centers.device == device; keep existing dtype/assertions
and use the same GlobalMutualInformationLoss and bin_centers symbols so the
check validates device placement after .to(device).
---
Outside diff comments:
In `@monai/losses/image_dissimilarity.py`:
- Around line 202-227: Update the __init__ docstring to document the registered
buffer bin_centers: state that bin_centers is stored as a module buffer
(registered via self.register_buffer) containing the Gaussian-Parzen window bin
centers used for intensity Parzen windowing, its purpose (used by the Gaussian
kernel to compute soft histogram/probability per bin), and its expected shape
for broadcasting (1, 1, num_bins). Mention it is created in the constructor
alongside other args (kernel_type, num_bins, sigma_ratio, etc.) so readers know
it is part of the module state.
---
Nitpick comments:
In `@tests/losses/image_dissimilarity/test_global_mutual_information_loss.py`:
- Around line 119-127: Add a docstring to the test method
test_gaussian_bin_centers_registered_buffer describing that it verifies buffer
registration, dtype preservation on .to(), and expected shape; then add an
assertion that loss.bin_centers.shape == (1, 1, 16) to verify the expanded
shape, and keep the existing checks for presence in named_buffers(),
requires_grad False, and dtype after loss.to(dtype=torch.float64).
🪄 Autofix (Beta)
Fix all unresolved CodeRabbit comments on this PR:
- Push a commit to this branch (recommended)
- Create a new PR with the fixes
ℹ️ Review info
⚙️ Run configuration
Configuration used: Path: .coderabbit.yaml
Review profile: CHILL
Plan: Pro
Run ID: 4e9d79f5-6f44-4caa-b664-3f5dcd5cf326
📒 Files selected for processing (2)
monai/losses/image_dissimilarity.pytests/losses/image_dissimilarity/test_global_mutual_information_loss.py
| self.bin_centers: torch.Tensor | ||
| if self.kernel_type == "gaussian": | ||
| self.preterm = 1 / (2 * sigma**2) | ||
| self.bin_centers = bin_centers[None, None, ...] | ||
| self.register_buffer("bin_centers", bin_centers[None, None, ...], persistent=False) |
There was a problem hiding this comment.
Move type annotation inside the conditional block.
The type annotation declares bin_centers unconditionally, but the buffer is only registered for Gaussian kernels. For b-spline mode, bin_centers won't exist. Move the annotation to line 238 (inside the if block) to accurately reflect the conditional initialization.
📝 Proposed fix
- self.bin_centers: torch.Tensor
if self.kernel_type == "gaussian":
+ self.bin_centers: torch.Tensor
self.preterm = 1 / (2 * sigma**2)
self.register_buffer("bin_centers", bin_centers[None, None, ...], persistent=False)📝 Committable suggestion
‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.
| self.bin_centers: torch.Tensor | |
| if self.kernel_type == "gaussian": | |
| self.preterm = 1 / (2 * sigma**2) | |
| self.bin_centers = bin_centers[None, None, ...] | |
| self.register_buffer("bin_centers", bin_centers[None, None, ...], persistent=False) | |
| if self.kernel_type == "gaussian": | |
| self.bin_centers: torch.Tensor | |
| self.preterm = 1 / (2 * sigma**2) | |
| self.register_buffer("bin_centers", bin_centers[None, None, ...], persistent=False) |
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
In `@monai/losses/image_dissimilarity.py` around lines 236 - 239, The type
annotation for self.bin_centers is declared unconditionally but the buffer is
only created when self.kernel_type == "gaussian", so move the annotation inside
that conditional: remove the top-level "self.bin_centers: torch.Tensor" and add
the annotation immediately before or alongside the register_buffer call within
the if block that sets self.preterm and calls
self.register_buffer("bin_centers", ...); this ensures self.bin_centers is only
typed/defined when the gaussian branch executes (refer to the conditional using
self.kernel_type, the attribute self.preterm, and the register_buffer call for
locating the code).
| def test_gaussian_bin_centers_registered_buffer(self): | ||
| loss = GlobalMutualInformationLoss(kernel_type="gaussian", num_bins=16) | ||
|
|
||
| self.assertIn("bin_centers", dict(loss.named_buffers())) | ||
| self.assertFalse(loss.bin_centers.requires_grad) | ||
|
|
||
| loss = loss.to(dtype=torch.float64) | ||
| self.assertEqual(loss.bin_centers.dtype, torch.float64) | ||
|
|
There was a problem hiding this comment.
Add device movement verification.
The PR fixes device placement (issue #8866's primary concern), but the test doesn't verify .to(device) moves bin_centers. Add a check that confirms buffer device changes when the module is moved to CUDA (if available).
🧪 Suggested test addition
loss = loss.to(dtype=torch.float64)
self.assertEqual(loss.bin_centers.dtype, torch.float64)
+
+ if torch.cuda.is_available():
+ loss_cuda = loss.to(device='cuda')
+ self.assertEqual(loss_cuda.bin_centers.device.type, 'cuda')📝 Committable suggestion
‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.
| def test_gaussian_bin_centers_registered_buffer(self): | |
| loss = GlobalMutualInformationLoss(kernel_type="gaussian", num_bins=16) | |
| self.assertIn("bin_centers", dict(loss.named_buffers())) | |
| self.assertFalse(loss.bin_centers.requires_grad) | |
| loss = loss.to(dtype=torch.float64) | |
| self.assertEqual(loss.bin_centers.dtype, torch.float64) | |
| def test_gaussian_bin_centers_registered_buffer(self): | |
| loss = GlobalMutualInformationLoss(kernel_type="gaussian", num_bins=16) | |
| self.assertIn("bin_centers", dict(loss.named_buffers())) | |
| self.assertFalse(loss.bin_centers.requires_grad) | |
| loss = loss.to(dtype=torch.float64) | |
| self.assertEqual(loss.bin_centers.dtype, torch.float64) | |
| if torch.cuda.is_available(): | |
| loss_cuda = loss.to(device='cuda') | |
| self.assertEqual(loss_cuda.bin_centers.device.type, 'cuda') |
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
In `@tests/losses/image_dissimilarity/test_global_mutual_information_loss.py`
around lines 119 - 127, Update the test_gaussian_bin_centers_registered_buffer
to also assert that the registered buffer moves devices when the module is
moved: create a CUDA device if torch.cuda.is_available(), call loss =
loss.to(device), then assert loss.bin_centers.device == device; keep existing
dtype/assertions and use the same GlobalMutualInformationLoss and bin_centers
symbols so the check validates device placement after .to(device).
Fixes #8866.
Description
This PR fixes a device placement bug in
GlobalMutualInformationLoss.__init__wherebin_centerswas assigned as a plain Python attribute instead of being registered as a buffer.bin_centersas a non-persistent buffer inimage_dissimilarity.py, ensuring thatGlobalMutualInformationLossnow follows normalnn.Modulebuffer semantics for.to(device)/dtypemoves and avoids silent gradient tracking.test_global_mutual_information_loss.pyto verify thatbin_centersis properly exposed throughnamed_buffers(), does not require gradients, and successfully changes dtype when the module is moved.Verification: Passed
python -m pytest tests/losses/image_dissimilarity/test_global_mutual_information_loss.py -q -k gaussian_bin_centers_registered_bufferand-k ill_opts.Types of changes
./runtests.sh -f -u --net --coverage../runtests.sh --quick --unittests --disttests.make htmlcommand in thedocs/folder.