diff --git a/compressai/latent_codecs/__init__.py b/compressai/latent_codecs/__init__.py index 82a41947..ceada0b1 100644 --- a/compressai/latent_codecs/__init__.py +++ b/compressai/latent_codecs/__init__.py @@ -29,7 +29,6 @@ from .base import LatentCodec from .channel_groups import ChannelGroupsLatentCodec -from .channel_slice import ChannelSliceLatentCodec from .checkerboard import CheckerboardLatentCodec from .entropy_bottleneck import EntropyBottleneckLatentCodec from .gain import GainHyperLatentCodec, GainHyperpriorLatentCodec @@ -41,7 +40,6 @@ __all__ = [ "LatentCodec", "ChannelGroupsLatentCodec", - "ChannelSliceLatentCodec", "CheckerboardLatentCodec", "EntropyBottleneckLatentCodec", "GainHyperLatentCodec", diff --git a/compressai/latent_codecs/channel_groups.py b/compressai/latent_codecs/channel_groups.py index dd8956c8..b493dc04 100644 --- a/compressai/latent_codecs/channel_groups.py +++ b/compressai/latent_codecs/channel_groups.py @@ -28,7 +28,7 @@ # ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
from itertools import accumulate -from typing import Any, Dict, List, Mapping, Tuple +from typing import Any, Dict, List, Mapping, Optional, Tuple import torch import torch.nn as nn @@ -74,12 +74,18 @@ def __init__( channel_context: Mapping[str, nn.Module], *, groups: List[int], + support_slices: Optional[List[List[int]]] = None, **kwargs, ): super().__init__() self._kwargs = kwargs self.groups = list(groups) self.groups_acc = list(accumulate(self.groups, initial=0)) + if support_slices is None: + support_slices = [range(k) for k in range(len(self.groups))] + assert len(support_slices) == len(self.groups) + assert all(all(0 <= j < k for j in s) for k, s in enumerate(support_slices)) + self.support_slices = [tuple(support_slice) for support_slice in support_slices] self.channel_context = nn.ModuleDict(channel_context) self.latent_codec = nn.ModuleDict(latent_codec) @@ -121,14 +127,14 @@ def compress(self, y: Tensor, side_params: Tensor) -> Dict[str, Any]: return { "strings": [s for ss in y_strings_groups for s in ss], - "shape": [y_out["shape"] for y_out in y_out_], + "shape": y.shape[1:], "y_hat": y_hat, } def decompress( self, strings: List[List[bytes]], - shape: List[Tuple[int, ...]], + shape: Tuple[int, ...], side_params: Tensor, **kwargs, ) -> Dict[str, Any]: @@ -137,15 +143,14 @@ def decompress( strings_per_group = len(strings) // len(self.groups) y_out_ = [{}] * len(self.groups) - y_shape = (sum(s[0] for s in shape), *shape[0][1:]) - y_hat = torch.zeros((n, *y_shape), device=side_params.device) + y_hat = torch.zeros((n, *shape), device=side_params.device) y_hat_ = y_hat.split(self.groups, dim=1) for k in range(len(self.groups)): params = self._get_ctx_params(k, side_params, y_hat_) y_out_[k] = self.latent_codec[f"y{k}"].decompress( strings[strings_per_group * k : strings_per_group * (k + 1)], - shape[k], + (self.groups[k], *shape[1:]), params, ) y_hat_[k][:] = y_out_[k]["y_hat"] @@ -165,5 +170,6 @@ def _get_ctx_params( ) -> Tensor: if k == 0: return 
side_params - ch_ctx_params = self.channel_context[f"y{k}"](self.merge_y(*y_hat_[:k])) + support = [y_hat_[i] for i in self.support_slices[k]] + ch_ctx_params = self.channel_context[f"y{k}"](self.merge_y(*support)) return self.merge_params(ch_ctx_params, side_params) diff --git a/compressai/latent_codecs/channel_slice.py b/compressai/latent_codecs/channel_slice.py deleted file mode 100644 index 73e32d6e..00000000 --- a/compressai/latent_codecs/channel_slice.py +++ /dev/null @@ -1,269 +0,0 @@ -# Copyright (c) 2021-2025, InterDigital Communications, Inc -# All rights reserved. - -# Redistribution and use in source and binary forms, with or without -# modification, are permitted (subject to the limitations in the disclaimer -# below) provided that the following conditions are met: - -# * Redistributions of source code must retain the above copyright notice, -# this list of conditions and the following disclaimer. -# * Redistributions in binary form must reproduce the above copyright notice, -# this list of conditions and the following disclaimer in the documentation -# and/or other materials provided with the distribution. -# * Neither the name of InterDigital Communications, Inc nor the names of its -# contributors may be used to endorse or promote products derived from this -# software without specific prior written permission. - -# NO EXPRESS OR IMPLIED LICENSES TO ANY PARTY'S PATENT RIGHTS ARE GRANTED BY -# THIS LICENSE. THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND -# CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT -# NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A -# PARTICULAR PURPOSE ARE DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR -# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, -# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, -# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; -# OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, -# WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR -# OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF -# ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - -from __future__ import annotations - -from typing import Any, Dict, List, Optional, Sequence, Tuple - -import torch -import torch.nn as nn - -from torch import Tensor - -from compressai.ans import BufferedRansEncoder, RansDecoder -from compressai.entropy_models import GaussianConditional -from compressai.ops import quantize_ste -from compressai.registry import register_module - -from .base import LatentCodec - -__all__ = [ - "ChannelSliceLatentCodec", -] - - -@register_module("ChannelSliceLatentCodec") -class ChannelSliceLatentCodec(LatentCodec): - """Channel-conditional entropy model with separate scale/mean heads and LRP. - - Splits ``y`` into equal-sized slices along the channel axis. For each - slice ``k`` the previously decoded slices (truncated to - ``max_support_slices``) are concatenated with ``latent_means`` / - ``latent_scales`` and pushed through ``cc_mean_transforms[k]`` and - ``cc_scale_transforms[k]`` to obtain ``mu`` / ``scale``. After the - Gaussian conditional step, an optional latent residual prediction - (LRP) head refines ``y_hat``. - - This is the channel-autoregressive entropy model from [Minnen2020] - with the LRP refinement variant used in [Zhu2022] (STF / WACNN), - [He2022] (ELIC) and many follow-up papers (MLIC++, TCM, ...). - - [Minnen2020]: `"Channel-wise Autoregressive Entropy Models for - Learned Image Compression" `_, by - David Minnen and Saurabh Singh, ICIP 2020. 
- - [Zhu2022]: `"Transformer-based Transform Coding" - `_, by Yinhao Zhu, - Yang Yang and Taco Cohen, ICLR 2022. - """ - - cc_mean_transforms: nn.ModuleList - cc_scale_transforms: nn.ModuleList - lrp_transforms: nn.ModuleList - gaussian_conditional: GaussianConditional - - def __init__( - self, - cc_mean_transforms: nn.ModuleList, - cc_scale_transforms: nn.ModuleList, - lrp_transforms: Optional[nn.ModuleList] = None, - gaussian_conditional: Optional[GaussianConditional] = None, - mean_support_transforms: Optional[nn.ModuleList] = None, - scale_support_transforms: Optional[nn.ModuleList] = None, - *, - num_slices: Optional[int] = None, - max_support_slices: int = -1, - quantizer: str = "ste", - lrp_scale: float = 0.5, - **kwargs: Any, - ) -> None: - super().__init__() - self._kwargs = kwargs - - inferred_num_slices = len(cc_mean_transforms) - if num_slices is None: - num_slices = inferred_num_slices - if inferred_num_slices != num_slices: - raise ValueError( - "cc_mean_transforms must have num_slices entries " - f"(got {inferred_num_slices}, expected {num_slices})" - ) - if len(cc_scale_transforms) != num_slices: - raise ValueError("cc_scale_transforms must have num_slices entries") - if lrp_transforms is not None and len(lrp_transforms) != num_slices: - raise ValueError("lrp_transforms must have num_slices entries") - if ( - mean_support_transforms is not None - and len(mean_support_transforms) != num_slices - ): - raise ValueError("mean_support_transforms must have num_slices entries") - if ( - scale_support_transforms is not None - and len(scale_support_transforms) != num_slices - ): - raise ValueError("scale_support_transforms must have num_slices entries") - if quantizer not in ("ste", "noise"): - raise ValueError(f"unknown quantizer {quantizer!r}") - - self.num_slices = int(num_slices) - self.max_support_slices = int(max_support_slices) - self.quantizer = quantizer - self.lrp_scale = float(lrp_scale) - self.cc_mean_transforms = cc_mean_transforms - 
self.cc_scale_transforms = cc_scale_transforms - self.mean_support_transforms = mean_support_transforms or nn.ModuleList( - nn.Identity() for _ in range(num_slices) - ) - self.scale_support_transforms = scale_support_transforms or nn.ModuleList( - nn.Identity() for _ in range(num_slices) - ) - self.lrp_transforms = lrp_transforms or nn.ModuleList( - nn.Identity() for _ in range(num_slices) - ) - self.gaussian_conditional = gaussian_conditional or GaussianConditional(None) - - def _support_slices(self, y_hat_slices: Sequence[Tensor]) -> List[Tensor]: - if self.max_support_slices < 0: - return list(y_hat_slices) - return list(y_hat_slices[: self.max_support_slices]) - - def _slice_params( - self, - slice_index: int, - latent_means: Tensor, - latent_scales: Tensor, - y_hat_slices: Sequence[Tensor], - spatial_shape: Tuple[int, int], - ) -> Tuple[Tensor, Tensor, Tensor]: - support = self._support_slices(y_hat_slices) - mean_support = torch.cat([latent_means, *support], dim=1) - mean_support = self.mean_support_transforms[slice_index](mean_support) - mu = self.cc_mean_transforms[slice_index](mean_support) - mu = mu[:, :, : spatial_shape[0], : spatial_shape[1]] - scale_support = torch.cat([latent_scales, *support], dim=1) - scale_support = self.scale_support_transforms[slice_index](scale_support) - scale = self.cc_scale_transforms[slice_index](scale_support) - scale = scale[:, :, : spatial_shape[0], : spatial_shape[1]] - return mu, scale, mean_support - - def _apply_lrp( - self, slice_index: int, mean_support: Tensor, y_hat_slice: Tensor - ) -> Tensor: - lrp = self.lrp_transforms[slice_index]( - torch.cat([mean_support, y_hat_slice], dim=1) - ) - return y_hat_slice + self.lrp_scale * torch.tanh(lrp) - - def forward( - self, - y: Tensor, - latent_means: Tensor, - latent_scales: Tensor, - ) -> Dict[str, Any]: - spatial_shape = (y.shape[2], y.shape[3]) - y_hat_slices: List[Tensor] = [] - y_likelihoods_slices: List[Tensor] = [] - - for slice_index, y_slice in 
enumerate(y.chunk(self.num_slices, dim=1)): - mu, scale, mean_support = self._slice_params( - slice_index, latent_means, latent_scales, y_hat_slices, spatial_shape - ) - _, y_slice_likelihoods = self.gaussian_conditional(y_slice, scale, means=mu) - if self.quantizer == "ste": - y_hat_slice = quantize_ste(y_slice - mu) + mu - else: - y_hat_slice = self.gaussian_conditional.quantize( - y_slice, "noise" if self.training else "dequantize", mu - ) - y_hat_slice = self._apply_lrp(slice_index, mean_support, y_hat_slice) - y_hat_slices.append(y_hat_slice) - y_likelihoods_slices.append(y_slice_likelihoods) - - return { - "y_hat": torch.cat(y_hat_slices, dim=1), - "likelihoods": {"y": torch.cat(y_likelihoods_slices, dim=1)}, - } - - def compress( - self, - y: Tensor, - latent_means: Tensor, - latent_scales: Tensor, - ) -> Dict[str, Any]: - spatial_shape = (y.shape[2], y.shape[3]) - cdf = self.gaussian_conditional.quantized_cdf.tolist() - cdf_lengths = self.gaussian_conditional.cdf_length.reshape(-1).int().tolist() - offsets = self.gaussian_conditional.offset.reshape(-1).int().tolist() - encoder = BufferedRansEncoder() - symbols_list: List[int] = [] - indexes_list: List[int] = [] - y_hat_slices: List[Tensor] = [] - - for slice_index, y_slice in enumerate(y.chunk(self.num_slices, dim=1)): - mu, scale, mean_support = self._slice_params( - slice_index, latent_means, latent_scales, y_hat_slices, spatial_shape - ) - indexes = self.gaussian_conditional.build_indexes(scale) - y_q_slice = self.gaussian_conditional.quantize(y_slice, "symbols", mu) - y_hat_slice = y_q_slice + mu - symbols_list.extend(y_q_slice.reshape(-1).tolist()) - indexes_list.extend(indexes.reshape(-1).tolist()) - y_hat_slice = self._apply_lrp(slice_index, mean_support, y_hat_slice) - y_hat_slices.append(y_hat_slice) - - encoder.encode_with_indexes( - symbols_list, indexes_list, cdf, cdf_lengths, offsets - ) - return { - "strings": [encoder.flush()], - "shape": spatial_shape, - "y_hat": torch.cat(y_hat_slices, 
dim=1), - } - - def decompress( - self, - strings: Sequence[bytes], - shape: Tuple[int, int], - latent_means: Tensor, - latent_scales: Tensor, - **kwargs: Any, - ) -> Dict[str, Any]: - cdf = self.gaussian_conditional.quantized_cdf.tolist() - cdf_lengths = self.gaussian_conditional.cdf_length.reshape(-1).int().tolist() - offsets = self.gaussian_conditional.offset.reshape(-1).int().tolist() - decoder = RansDecoder() - decoder.set_stream(strings[0]) - y_hat_slices: List[Tensor] = [] - - for slice_index in range(self.num_slices): - mu, scale, mean_support = self._slice_params( - slice_index, latent_means, latent_scales, y_hat_slices, shape - ) - indexes = self.gaussian_conditional.build_indexes(scale) - values = decoder.decode_stream( - indexes.reshape(-1).tolist(), cdf, cdf_lengths, offsets - ) - y_q_slice = torch.tensor(values, device=mu.device, dtype=mu.dtype).reshape( - mu.shape - ) - y_hat_slice = self.gaussian_conditional.dequantize(y_q_slice, mu) - y_hat_slice = self._apply_lrp(slice_index, mean_support, y_hat_slice) - y_hat_slices.append(y_hat_slice) - - return {"y_hat": torch.cat(y_hat_slices, dim=1)} diff --git a/compressai/latent_codecs/checkerboard.py b/compressai/latent_codecs/checkerboard.py index 1de992c5..0be77f4d 100644 --- a/compressai/latent_codecs/checkerboard.py +++ b/compressai/latent_codecs/checkerboard.py @@ -276,7 +276,7 @@ def decompress( assert all(len(x) == n for x in y_strings_) c, h, w = shape - y_i_shape = (h, w // 2) + y_i_shape = (c, h, w // 2) y_hat_ = side_params.new_zeros((2, n, c, h, w // 2)) side_params_ = self.unembed(side_params) diff --git a/compressai/latent_codecs/entropy_bottleneck.py b/compressai/latent_codecs/entropy_bottleneck.py index f1496f58..cca1048a 100644 --- a/compressai/latent_codecs/entropy_bottleneck.py +++ b/compressai/latent_codecs/entropy_bottleneck.py @@ -80,14 +80,16 @@ def forward(self, y: Tensor) -> Dict[str, Any]: return {"likelihoods": {"y": y_likelihoods}, "y_hat": y_hat} def compress(self, y: Tensor) 
-> Dict[str, Any]: - shape = y.size()[-2:] + shape = y.shape[1:] y_strings = self.entropy_bottleneck.compress(y) - y_hat = self.entropy_bottleneck.decompress(y_strings, shape) + y_hat = self.entropy_bottleneck.decompress(y_strings, shape[1:]) + assert y_hat.shape[1:] == shape return {"strings": [y_strings], "shape": shape, "y_hat": y_hat} def decompress( - self, strings: List[List[bytes]], shape: Tuple[int, int], **kwargs + self, strings: List[List[bytes]], shape: Tuple[int, ...], **kwargs ) -> Dict[str, Any]: (y_strings,) = strings - y_hat = self.entropy_bottleneck.decompress(y_strings, shape) + y_hat = self.entropy_bottleneck.decompress(y_strings, shape[1:]) + assert y_hat.shape[1:] == shape return {"y_hat": y_hat} diff --git a/compressai/latent_codecs/gain/hyper.py b/compressai/latent_codecs/gain/hyper.py index 1035c658..cf76067f 100644 --- a/compressai/latent_codecs/gain/hyper.py +++ b/compressai/latent_codecs/gain/hyper.py @@ -105,9 +105,9 @@ def forward(self, y: Tensor, gain: Tensor, gain_inv: Tensor) -> Dict[str, Any]: def compress(self, y: Tensor, gain: Tensor, gain_inv: Tensor) -> Dict[str, Any]: z = self.h_a(y) z = z * gain - shape = z.size()[-2:] + shape = z.shape[1:] z_strings = self.entropy_bottleneck.compress(z) - z_hat = self.entropy_bottleneck.decompress(z_strings, shape) + z_hat = self.entropy_bottleneck.decompress(z_strings, shape[1:]) z_hat = z_hat * gain_inv params = self.h_s(z_hat) return {"strings": [z_strings], "shape": shape, "params": params} @@ -115,12 +115,12 @@ def compress(self, y: Tensor, gain: Tensor, gain_inv: Tensor) -> Dict[str, Any]: def decompress( self, strings: List[List[bytes]], - shape: Tuple[int, int], + shape: Tuple[int, ...], gain_inv: Tensor, **kwargs, ) -> Dict[str, Any]: (z_strings,) = strings - z_hat = self.entropy_bottleneck.decompress(z_strings, shape) + z_hat = self.entropy_bottleneck.decompress(z_strings, shape[1:]) z_hat = z_hat * gain_inv params = self.h_s(z_hat) return {"params": params} diff --git 
a/compressai/latent_codecs/gaussian_conditional.py b/compressai/latent_codecs/gaussian_conditional.py index e422f681..3f12423d 100644 --- a/compressai/latent_codecs/gaussian_conditional.py +++ b/compressai/latent_codecs/gaussian_conditional.py @@ -112,12 +112,12 @@ def compress(self, y: Tensor, ctx_params: Tensor) -> Dict[str, Any]: y_hat = self.gaussian_conditional.decompress( y_strings, indexes, means=means_hat ) - return {"strings": [y_strings], "shape": y.shape[2:4], "y_hat": y_hat} + return {"strings": [y_strings], "shape": y.shape[1:], "y_hat": y_hat} def decompress( self, strings: List[List[bytes]], - shape: Tuple[int, int], + shape: Tuple[int, ...], ctx_params: Tensor, **kwargs, ) -> Dict[str, Any]: @@ -128,7 +128,7 @@ def decompress( y_hat = self.gaussian_conditional.decompress( y_strings, indexes, means=means_hat ) - assert y_hat.shape[2:4] == shape + assert y_hat.shape[1:] == shape return {"y_hat": y_hat} def _chunk(self, params: Tensor) -> Tuple[Tensor, Tensor]: diff --git a/compressai/latent_codecs/hyper.py b/compressai/latent_codecs/hyper.py index ac53afa7..0154245b 100644 --- a/compressai/latent_codecs/hyper.py +++ b/compressai/latent_codecs/hyper.py @@ -106,16 +106,16 @@ def forward(self, y: Tensor) -> Dict[str, Any]: def compress(self, y: Tensor) -> Dict[str, Any]: z = self.h_a(y) - shape = z.size()[-2:] + shape = z.shape[1:] z_strings = self.entropy_bottleneck.compress(z) - z_hat = self.entropy_bottleneck.decompress(z_strings, shape) + z_hat = self.entropy_bottleneck.decompress(z_strings, shape[1:]) params = self.h_s(z_hat) return {"strings": [z_strings], "shape": shape, "params": params} def decompress( - self, strings: List[List[bytes]], shape: Tuple[int, int], **kwargs + self, strings: List[List[bytes]], shape: Tuple[int, ...], **kwargs ) -> Dict[str, Any]: (z_strings,) = strings - z_hat = self.entropy_bottleneck.decompress(z_strings, shape) + z_hat = self.entropy_bottleneck.decompress(z_strings, shape[1:]) params = self.h_s(z_hat) return 
{"params": params} diff --git a/compressai/latent_codecs/rasterscan.py b/compressai/latent_codecs/rasterscan.py index c1b32a51..a0a18222 100644 --- a/compressai/latent_codecs/rasterscan.py +++ b/compressai/latent_codecs/rasterscan.py @@ -123,17 +123,17 @@ def compress(self, y: Tensor, ctx_params: Tensor) -> Dict[str, Any]: ) y_strings = encoder.flush() ds.append({"strings": [y_strings], "y_hat": y_hat.squeeze(0)}) - return {**default_collate(ds), "shape": y.shape[2:4]} + return {**default_collate(ds), "shape": y.shape[1:]} def decompress( self, strings: List[List[bytes]], - shape: Tuple[int, int], + shape: Tuple[int, ...], ctx_params: Tensor, **kwargs, ) -> Dict[str, Any]: (y_strings,) = strings - y_height, y_width = shape + _, y_height, y_width = shape ds = [] for i in range(len(y_strings)): decoder = RansDecoder() diff --git a/compressai/losses/__init__.py b/compressai/losses/__init__.py index b0863e85..9e8f8ea6 100644 --- a/compressai/losses/__init__.py +++ b/compressai/losses/__init__.py @@ -28,10 +28,12 @@ # ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. from . import pointcloud +from .cca import CCARateDistortionLoss from .pointcloud import * from .rate_distortion import RateDistortionLoss __all__ = [ *pointcloud.__all__, + "CCARateDistortionLoss", "RateDistortionLoss", ] diff --git a/compressai/losses/cca.py b/compressai/losses/cca.py new file mode 100644 index 00000000..3a4b7009 --- /dev/null +++ b/compressai/losses/cca.py @@ -0,0 +1,134 @@ +# Copyright (c) 2021-2025, InterDigital Communications, Inc +# All rights reserved. + +# Redistribution and use in source and binary forms, with or without +# modification, are permitted (subject to the limitations in the disclaimer +# below) provided that the following conditions are met: + +# * Redistributions of source code must retain the above copyright notice, +# this list of conditions and the following disclaimer. 
+# * Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# * Neither the name of InterDigital Communications, Inc nor the names of its +# contributors may be used to endorse or promote products derived from this +# software without specific prior written permission. + +# NO EXPRESS OR IMPLIED LICENSES TO ANY PARTY'S PATENT RIGHTS ARE GRANTED BY +# THIS LICENSE. THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND +# CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT +# NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A +# PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR +# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, +# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, +# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; +# OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, +# WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR +# OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF +# ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +"""Causal Context Adjustment rate-distortion loss (Han et al., NeurIPS 2024). + +Companion criterion for :class:`compressai.models.cca.CCAModel`. Requires +the model's ``forward`` to return ``aux_likelihoods = {"y_aux", "y_cca"}`` +(populated when ``cca_training=True``) so this loss can add the auxiliary +CCA terms on top of the standard rate-distortion objective. 
+""" + +from __future__ import annotations + +import math + +import torch +import torch.nn as nn + +from pytorch_msssim import ms_ssim + +from compressai.registry import register_criterion + + +@register_criterion("CCARateDistortionLoss") +class CCARateDistortionLoss(nn.Module): + r"""Causal Context Adjustment rate-distortion loss from M. Han, S. Jiang, + S. Li, X. Deng, M. Xu, C. Zhu, S. Gu: `"Causal Context Adjustment Loss + for Learned Image Compression" <https://arxiv.org/abs/2410.04847>`_, + Adv. in Neural Information Processing Systems 37 (NeurIPS), 2024. + + Combines the standard rate (``bpp``) and distortion (MSE / MS-SSIM) + terms with the CCA term that measures the gap between the main and + auxiliary causal-context likelihoods produced by + :class:`compressai.models.cca.CCAModel` (with ``cca_training=True``). + + Args: + lmbda: Distortion weight. + metric: Distortion metric, ``"mse"`` or ``"ms-ssim"``. + return_type: ``"all"`` returns the dict of components; otherwise + return the named scalar component (e.g. ``"loss"``). + alpha: Weight on the CCA loss term. + beta: Weight on the bit-rate term. 
+ """ + + def __init__( + self, + lmbda: float = 0.01, + metric: str = "mse", + return_type: str = "all", + alpha: float = 1.0, + beta: float = 1.0, + ) -> None: + super().__init__() + if metric == "mse": + self.metric = nn.MSELoss() + elif metric == "ms-ssim": + self.metric = ms_ssim + else: + raise NotImplementedError(f"{metric} is not implemented!") + + self.lmbda = float(lmbda) + self.alpha = float(alpha) + self.beta = float(beta) + self.return_type = return_type + + def forward(self, output, target): + if "aux_likelihoods" not in output or output["aux_likelihoods"] is None: + raise KeyError( + "output must contain aux_likelihoods for CCARateDistortionLoss; " + "ensure CCAModel was constructed with cca_training=True" + ) + + aux_likelihoods = output["aux_likelihoods"] + if "y_aux" not in aux_likelihoods or "y_cca" not in aux_likelihoods: + raise KeyError("aux_likelihoods must contain y_aux and y_cca") + + batch_size, _, height, width = target.size() + num_pixels = batch_size * height * width + out = {} + + out["cca_loss"] = ( + torch.log(output["likelihoods"]["y"]).sum() / (-math.log(2)) + - torch.log(aux_likelihoods["y_cca"]).sum() / (-math.log(2)) + ) / num_pixels + out["aux2_loss"] = torch.sum( + aux_likelihoods["y_cca"] * torch.log(aux_likelihoods["y_aux"]) + ) / (-math.log(2) * num_pixels) + out["bpp_loss"] = sum( + (torch.log(likelihoods).sum() / (-math.log(2) * num_pixels)) + for likelihoods in output["likelihoods"].values() + ) + + if self.metric == ms_ssim: + out["ms_ssim_loss"] = self.metric(output["x_hat"], target, data_range=1) + distortion = 1 - out["ms_ssim_loss"] + else: + out["mse_loss"] = self.metric(output["x_hat"], target) + distortion = 255**2 * out["mse_loss"] + + out["loss"] = ( + self.lmbda * distortion + + self.beta * out["bpp_loss"] + + self.alpha * out["cca_loss"] + + out["aux2_loss"] + ) + if self.return_type == "all": + return out + return out[self.return_type] diff --git a/compressai/models/_bases/__init__.py 
b/compressai/models/_bases/__init__.py deleted file mode 100644 index 065119c2..00000000 --- a/compressai/models/_bases/__init__.py +++ /dev/null @@ -1,24 +0,0 @@ -"""Abstract base classes shared by multiple slice-based LIC models. - -These were historically hidden behind ``stf_support`` / ``dcae_support`` file -names which obscured the fact that they're real abstract :class:`CompressionModel` -subclasses inherited by 3-4 models each. -""" - -from .slice_entropy import ( - SliceEntropyCompressionModel, - infer_max_support_slices, - infer_num_slices, - lrp_support_channels, - make_entropy_transform, - slice_support_channels, -) - -__all__ = [ - "SliceEntropyCompressionModel", - "infer_max_support_slices", - "infer_num_slices", - "lrp_support_channels", - "make_entropy_transform", - "slice_support_channels", -] diff --git a/compressai/models/_bases/slice_entropy.py b/compressai/models/_bases/slice_entropy.py deleted file mode 100644 index 7159468f..00000000 --- a/compressai/models/_bases/slice_entropy.py +++ /dev/null @@ -1,260 +0,0 @@ -"""Slice-conditional entropy backbone shared by WACNN / SymmetricalTransFormer / MambaVC. - -Promoted out of the historical ``models/stf_support.py`` so the abstract base -class is discoverable by name. Channel-counting helpers and a parameterised -entropy-transform factory live here too — they used to be duplicated across -``stf_support`` / ``ssm_support`` / ``weconvene_support``. 
-""" - -from __future__ import annotations - -from typing import Dict, Optional, Sequence, Tuple - -import torch.nn as nn - -from torch import Tensor - -from compressai.entropy_models import EntropyBottleneck -from compressai.latent_codecs import ChannelSliceLatentCodec -from compressai.models.utils import conv - -from ..base import CompressionModel - -__all__ = [ - "SliceEntropyCompressionModel", - "infer_max_support_slices", - "infer_num_slices", - "lrp_support_channels", - "make_entropy_transform", - "slice_support_channels", -] - - -_DEFAULT_NUM_SLICES_PREFIX = "latent_codec.cc_mean_transforms." -_KEY_SUFFIX = ".0.weight" - - -def slice_support_channels( - latent_channels: int, - slice_channels: int, - index: int, - max_support_slices: int, -) -> int: - if max_support_slices < 0: - return latent_channels + slice_channels * index - return latent_channels + slice_channels * min(index, max_support_slices) - - -def lrp_support_channels( - latent_channels: int, - slice_channels: int, - index: int, - max_support_slices: int, -) -> int: - if max_support_slices < 0: - return latent_channels + slice_channels * (index + 1) - return latent_channels + slice_channels * min(index + 1, max_support_slices + 1) - - -def make_entropy_transform( - in_channels: int, - out_channels: int, - *, - widths: Sequence[int] = (224, 128), -) -> nn.Sequential: - """Stack of stride-1 3x3 convs with GELU between, used by every slice - entropy model. ``widths`` specifies hidden conv widths; defaults to the - Mamba/WeConvene 3-conv stack. 
Pass ``widths=(224, 176, 128, 64)`` for the - STF/WACNN 5-conv stack.""" - layers: list[nn.Module] = [] - prev = in_channels - for width in widths: - layers.append(conv(prev, width, stride=1, kernel_size=3)) - layers.append(nn.GELU()) - prev = width - layers.append(conv(prev, out_channels, stride=1, kernel_size=3)) - return nn.Sequential(*layers) - - -def infer_num_slices( - state_dict: Dict[str, Tensor], - *, - prefix: str = _DEFAULT_NUM_SLICES_PREFIX, - suffix: str = _KEY_SUFFIX, -) -> int: - slice_indices = { - int(key[len(prefix) :].split(".", 1)[0]) - for key in state_dict - if key.startswith(prefix) and key.endswith(suffix) - } - return len(slice_indices) - - -def infer_max_support_slices( - state_dict: Dict[str, Tensor], - latent_channels: int, - num_slices: int, - *, - prefix: str = _DEFAULT_NUM_SLICES_PREFIX, - suffix: str = _KEY_SUFFIX, - extra_factor: int = 1, -) -> int: - """Infer ``max_support_slices`` from the input width of the first - cc_mean transform conv. ``extra_factor`` accounts for models like DCAE/SAAF - that prepend additional copies of the latent (``M*3 + slice_channels*N``); - pass ``extra_factor=3`` there. Slice-only models (STF/Mamba*) keep the - default ``extra_factor=1``.""" - slice_channels = latent_channels // num_slices - matching = [ - tensor.size(1) - for key, tensor in state_dict.items() - if key.startswith(prefix) and key.endswith(suffix) - ] - if not matching: - return 0 - max_input_channels = max(matching) - return max( - 0, (max_input_channels - extra_factor * latent_channels) // slice_channels - ) - - -class SliceEntropyCompressionModel(CompressionModel): - """Channel-conditional entropy backbone shared by WACNN, SymmetricalTransFormer, MambaVC. - - Subclasses must populate ``g_a``, ``g_s``, ``h_a``, ``h_mean_s`` and - ``h_scale_s``, then call :meth:`_init_slice_entropy` to wire up the - entropy bottleneck for ``z`` and the :class:`ChannelSliceLatentCodec` - for ``y``. 
- """ - - h_a: nn.Module - h_mean_s: nn.Module - h_scale_s: nn.Module - entropy_bottleneck: EntropyBottleneck - latent_codec: ChannelSliceLatentCodec - - def _init_slice_entropy( - self, - latent_channels: int, - entropy_bottleneck_channels: int, - num_slices: int, - max_support_slices: int, - mean_support_transforms: Optional[nn.ModuleList] = None, - scale_support_transforms: Optional[nn.ModuleList] = None, - ) -> None: - if latent_channels % num_slices != 0: - raise ValueError("latent_channels must be divisible by num_slices") - if ( - mean_support_transforms is not None - and len(mean_support_transforms) != num_slices - ): - raise ValueError("mean_support_transforms must have num_slices entries") - if ( - scale_support_transforms is not None - and len(scale_support_transforms) != num_slices - ): - raise ValueError("scale_support_transforms must have num_slices entries") - - slice_channels = latent_channels // num_slices - widths = (224, 176, 128, 64) - cc_mean_transforms = nn.ModuleList( - make_entropy_transform( - slice_support_channels( - latent_channels, slice_channels, index, max_support_slices - ), - slice_channels, - widths=widths, - ) - for index in range(num_slices) - ) - cc_scale_transforms = nn.ModuleList( - make_entropy_transform( - slice_support_channels( - latent_channels, slice_channels, index, max_support_slices - ), - slice_channels, - widths=widths, - ) - for index in range(num_slices) - ) - lrp_transforms = nn.ModuleList( - make_entropy_transform( - lrp_support_channels( - latent_channels, slice_channels, index, max_support_slices - ), - slice_channels, - widths=widths, - ) - for index in range(num_slices) - ) - - self.entropy_bottleneck = EntropyBottleneck(entropy_bottleneck_channels) - self.latent_codec = ChannelSliceLatentCodec( - cc_mean_transforms=cc_mean_transforms, - cc_scale_transforms=cc_scale_transforms, - lrp_transforms=lrp_transforms, - mean_support_transforms=mean_support_transforms, - 
scale_support_transforms=scale_support_transforms, - num_slices=num_slices, - max_support_slices=max_support_slices, - quantizer="ste", - ) - - @property - def num_slices(self) -> int: - return self.latent_codec.num_slices - - @property - def max_support_slices(self) -> int: - return self.latent_codec.max_support_slices - - def _hyper_priors(self, y: Tensor) -> Tuple[Tensor, Tensor, Tensor, Tensor]: - z = self.h_a(y) - z_hat, z_likelihoods = self.entropy_bottleneck(z) - latent_means = self.h_mean_s(z_hat) - latent_scales = self.h_scale_s(z_hat) - return z, z_likelihoods, latent_means, latent_scales - - def _forward_latent_output( - self, y: Tensor - ) -> Dict[str, Dict[str, Tensor] | Tensor]: - _, z_likelihoods, latent_means, latent_scales = self._hyper_priors(y) - y_out = self.latent_codec(y, latent_means, latent_scales) - output: Dict[str, Dict[str, Tensor] | Tensor] = { - "y_hat": y_out["y_hat"], - "likelihoods": {"y": y_out["likelihoods"]["y"], "z": z_likelihoods}, - } - return output - - def _forward_latent(self, y: Tensor) -> Tuple[Tensor, Tensor, Tensor]: - output = self._forward_latent_output(y) - return output["y_hat"], output["likelihoods"]["y"], output["likelihoods"]["z"] - - def _compress_latent(self, y: Tensor) -> Dict[str, object]: - z = self.h_a(y) - z_strings = self.entropy_bottleneck.compress(z) - z_hat = self.entropy_bottleneck.decompress(z_strings, z.size()[-2:]) - latent_means = self.h_mean_s(z_hat) - latent_scales = self.h_scale_s(z_hat) - y_out = self.latent_codec.compress(y, latent_means, latent_scales) - return { - "strings": [[y_out["strings"][0]], z_strings], - "shape": z.size()[-2:], - } - - def _decompress_latent( - self, - strings: Sequence[Sequence[bytes]], - shape: Tuple[int, int], - ) -> Tensor: - if len(strings) != 2: - raise ValueError("strings must contain [y_strings, z_strings]") - - z_hat = self.entropy_bottleneck.decompress(strings[1], shape) - latent_means = self.h_mean_s(z_hat) - latent_scales = self.h_scale_s(z_hat) - 
y_shape = (z_hat.shape[2] * 4, z_hat.shape[3] * 4) - y_out = self.latent_codec.decompress( - strings[0], y_shape, latent_means, latent_scales - ) - return y_out["y_hat"] diff --git a/compressai/models/_helpers/__init__.py b/compressai/models/_helpers/__init__.py new file mode 100644 index 00000000..33ef5a35 --- /dev/null +++ b/compressai/models/_helpers/__init__.py @@ -0,0 +1,36 @@ +# Copyright (c) 2021-2025, InterDigital Communications, Inc +# All rights reserved. + +# Redistribution and use in source and binary forms, with or without +# modification, are permitted (subject to the limitations in the disclaimer +# below) provided that the following conditions are met: + +# * Redistributions of source code must retain the above copyright notice, +# this list of conditions and the following disclaimer. +# * Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# * Neither the name of InterDigital Communications, Inc nor the names of its +# contributors may be used to endorse or promote products derived from this +# software without specific prior written permission. + +# NO EXPRESS OR IMPLIED LICENSES TO ANY PARTY'S PATENT RIGHTS ARE GRANTED BY +# THIS LICENSE. THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND +# CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT +# NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A +# PARTICULAR PURPOSE ARE DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR +# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, +# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, +# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; +# OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, +# WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR +# OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF +# ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +"""Application-layer helpers for assembling model-specific entropy stacks. + +The helpers in this package support STF / WACNN / TCM / CCA wiring, including +split mean / scale channel-context heads and state-dict inference utilities. +They live outside ``compressai.latent_codecs`` because they are model-layer +assembly details, not codec primitives. +""" diff --git a/compressai/models/_helpers/channel_context.py b/compressai/models/_helpers/channel_context.py new file mode 100644 index 00000000..b9e6e600 --- /dev/null +++ b/compressai/models/_helpers/channel_context.py @@ -0,0 +1,166 @@ +# Copyright (c) 2021-2025, InterDigital Communications, Inc +# All rights reserved. + +# Redistribution and use in source and binary forms, with or without +# modification, are permitted (subject to the limitations in the disclaimer +# below) provided that the following conditions are met: + +# * Redistributions of source code must retain the above copyright notice, +# this list of conditions and the following disclaimer. +# * Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# * Neither the name of InterDigital Communications, Inc nor the names of its +# contributors may be used to endorse or promote products derived from this +# software without specific prior written permission. 
+ +# NO EXPRESS OR IMPLIED LICENSES TO ANY PARTY'S PATENT RIGHTS ARE GRANTED BY +# THIS LICENSE. THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND +# CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT +# NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A +# PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR +# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, +# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, +# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; +# OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, +# WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR +# OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF +# ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +"""Mean / scale split channel-context heads for model-specific entropy stacks. + +The :class:`MeanScaleContextHead` keeps a separate ``mean_cc`` and +``scale_cc`` Sequential — matching the historical ``cc_mean_transforms`` / +``cc_scale_transforms`` ModuleList layout used by STF / WACNN / TCM / +CCA — and concatenates their outputs to form the +``channel_context.y{k}`` entry used by those models. +""" + +from __future__ import annotations + +from typing import Literal, Optional, Union + +import torch +import torch.nn as nn + +from torch import Tensor + +__all__ = [ + "MeanScaleContextHead", +] + + +class MeanScaleContextHead(nn.Module): + """Channel-context head with separate mean / scale sub-networks. + + Internal layout:: + + mean_cc: in_channels -> ... -> slice_ch + scale_cc: in_channels -> ... -> slice_ch + + Forward output is ``cat([scale_cc(...), mean_cc(...)], dim=1)`` of shape + ``(B, 2 * slice_ch, H, W)`` — order matches + :class:`GaussianConditionalLatentCodec` ``chunks=("scales", "means")``. 
+ Optional ``mean_support_transform`` / ``scale_support_transform`` run + independently on the input before the sub-networks (used for SWAtten in + TCM and NAFTransform in CCA). + + When ``side_split > 0`` the head expects its input to be the + concatenation ``cat(latent_means(side_split), latent_scales(side_split), + *prev_y_hat)`` produced by the side-parameter channel-groups path. The + head splits the leading + ``2 * side_split`` channels back into ``latent_means`` / + ``latent_scales`` and routes: + + - ``mean_cc(cat(latent_means, *prev_y_hat))`` + - ``scale_cc(cat(latent_scales, *prev_y_hat))`` + + so each sub-network sees the same input shape it would have under the + pre-refactor STF / WACNN / TCM / CCA wiring (``cc_mean_transforms[k]`` / + ``cc_scale_transforms[k]``). This keeps state-dict weights compatible + with the legacy layout when migrating via + ``convert_*_checkpoint.py``. + + When ``side_split == 0`` (default) the head is generic: ``mean_cc`` and + ``scale_cc`` both see the full input, no internal split. + + When ``emit_mean_support`` is truthy (only meaningful with + ``side_split > 0``) the head appends a copy of the mean-path tensor to + the output, producing + ``cat(scale, mean, mean_support)`` of shape + ``(B, 2*slice_ch + side_split + sum(prev_groups), H, W)``. Two flavours: + + - ``"pre"`` (legacy ``True``) — emit the raw ``mean_in = + cat(latent_means, *prev_y_hat)`` (i.e., before + ``mean_support_transform``). STF / WACNN / TCM use this because their + ``mean_support_transform`` is :class:`Identity` (or the upstream LRP + input is the un-transformed mean_in). + - ``"post"`` — emit ``mean_support_transform(mean_in)`` (the same tensor + that feeds ``mean_cc``). CCA-main / CCA-aux use this because their + upstream ``lrp_transforms`` consume the *post*-NAFTransform mean + support; emitting "pre" would produce wrong LRP outputs even though + the channel widths match. 
+ + The trailing block is consumed by the model-local LRP Gaussian leaf to + recover the upstream LRP input layout (``cat(mean_support, y_hat)``), + enabling byte-for-byte transfer of upstream LRP weights. + """ + + mean_cc: nn.Module + scale_cc: nn.Module + mean_support_transform: nn.Module + scale_support_transform: nn.Module + + def __init__( + self, + mean_cc: nn.Module, + scale_cc: nn.Module, + mean_support_transform: Optional[nn.Module] = None, + scale_support_transform: Optional[nn.Module] = None, + *, + side_split: int = 0, + emit_mean_support: Union[bool, Literal["pre", "post"]] = False, + ) -> None: + super().__init__() + self.mean_cc = mean_cc + self.scale_cc = scale_cc + self.mean_support_transform = mean_support_transform or nn.Identity() + self.scale_support_transform = scale_support_transform or nn.Identity() + self.side_split = int(side_split) + self.emit_mean_support: Literal[False, "pre", "post"] + if emit_mean_support is True: + self.emit_mean_support = "pre" + elif emit_mean_support is False: + self.emit_mean_support = False + elif emit_mean_support in ("pre", "post"): + self.emit_mean_support = emit_mean_support + else: + raise ValueError( + f"emit_mean_support must be False, True, 'pre', or 'post'; " + f"got {emit_mean_support!r}" + ) + if self.emit_mean_support and self.side_split <= 0: + raise ValueError( + "emit_mean_support requires side_split > 0 to recover the " + "legacy mean_support layout cat(latent_means, *prev_y_hat)." 
+ ) + + def forward(self, x: Tensor) -> Tensor: + if self.side_split > 0: + split = self.side_split + latent_means = x[:, :split] + latent_scales = x[:, split : 2 * split] + prev_y_hat = x[:, 2 * split :] + mean_in = torch.cat([latent_means, prev_y_hat], dim=1) + scale_in = torch.cat([latent_scales, prev_y_hat], dim=1) + else: + mean_in = scale_in = x + mean_support = self.mean_support_transform(mean_in) + mean = self.mean_cc(mean_support) + scale = self.scale_cc(self.scale_support_transform(scale_in)) + out = torch.cat([scale, mean], dim=1) + if self.emit_mean_support == "pre": + out = torch.cat([out, mean_in], dim=1) + elif self.emit_mean_support == "post": + out = torch.cat([out, mean_support], dim=1) + return out diff --git a/compressai/models/_helpers/slice_helpers.py b/compressai/models/_helpers/slice_helpers.py new file mode 100644 index 00000000..01fe3bf1 --- /dev/null +++ b/compressai/models/_helpers/slice_helpers.py @@ -0,0 +1,169 @@ +# Copyright (c) 2021-2025, InterDigital Communications, Inc +# All rights reserved. + +# Redistribution and use in source and binary forms, with or without +# modification, are permitted (subject to the limitations in the disclaimer +# below) provided that the following conditions are met: + +# * Redistributions of source code must retain the above copyright notice, +# this list of conditions and the following disclaimer. +# * Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# * Neither the name of InterDigital Communications, Inc nor the names of its +# contributors may be used to endorse or promote products derived from this +# software without specific prior written permission. + +# NO EXPRESS OR IMPLIED LICENSES TO ANY PARTY'S PATENT RIGHTS ARE GRANTED BY +# THIS LICENSE. 
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND +# CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT +# NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A +# PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR +# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, +# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, +# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; +# OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, +# WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR +# OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF +# ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +"""State-dict inference helpers for model-local channel-groups wiring. + +``_DEFAULT_NUM_SLICES_PREFIX`` reflects the containerised state-dict layout +used by :class:`~compressai.latent_codecs.ChannelGroupsLatentCodec`. +""" + +from __future__ import annotations + +from typing import Dict, Sequence + +import torch.nn as nn + +from torch import Tensor + +from compressai.models.utils import conv + +__all__ = [ + "infer_max_support_slices", + "infer_num_slices", + "lrp_support_channels", + "make_entropy_transform", + "slice_support_channels", +] + + +# Post-refactor state-dict layout: ``HyperpriorLatentCodec`` exposes +# ``ChannelGroupsLatentCodec`` as ``self.y`` (the inner ``self.latent_codec`` +# dict is not a registered nn.Module), so the channel-context entries live +# under ``latent_codec.y.channel_context.y{k}``. Slice 0 has no channel +# context entry in the ELIC layout; side-parameter channel-context models add +# a ``y0`` entry whose presence triggers the auto-detection in +# :func:`infer_num_slices`. 
+_DEFAULT_NUM_SLICES_PREFIX = "latent_codec.y.channel_context.y" +_DEFAULT_KEY_SUFFIX = ".mean_cc.0.weight" + + +def slice_support_channels( + latent_channels: int, + slice_channels: int, + index: int, + max_support_slices: int, +) -> int: + if max_support_slices < 0: + return latent_channels + slice_channels * index + return latent_channels + slice_channels * min(index, max_support_slices) + + +def lrp_support_channels( + latent_channels: int, + slice_channels: int, + index: int, + max_support_slices: int, +) -> int: + if max_support_slices < 0: + return latent_channels + slice_channels * (index + 1) + return latent_channels + slice_channels * min(index + 1, max_support_slices + 1) + + +def make_entropy_transform( + in_channels: int, + out_channels: int, + *, + widths: Sequence[int] = (224, 128), +) -> nn.Sequential: + """Stack of stride-1 3x3 convs with GELU activations. + + Used as the ``mean_cc`` / ``scale_cc`` per-group heads and as local + ``lrp_transform`` modules in STF / WACNN / TCM / CCA. ``widths`` + specifies hidden conv widths and defaults to the TCM / CCA 3-conv stack + ``(224, 128)``; pass ``widths=(224, 176, 128, 64)`` for the STF / WACNN + 5-conv stack. + """ + layers: list[nn.Module] = [] + prev = in_channels + for width in widths: + layers.append(conv(prev, width, stride=1, kernel_size=3)) + layers.append(nn.GELU()) + prev = width + layers.append(conv(prev, out_channels, stride=1, kernel_size=3)) + return nn.Sequential(*layers) + + +def infer_num_slices( + state_dict: Dict[str, Tensor], + *, + prefix: str = _DEFAULT_NUM_SLICES_PREFIX, + suffix: str = _DEFAULT_KEY_SUFFIX, +) -> int: + """Count distinct ``y{k}`` channel-context entries in ``state_dict``. + + Two layouts are supported: + + - ELIC default: channel_context starts at ``y1`` (slice 0 bypasses it), + so the count returned is ``num_slices - 1`` and we add ``1`` to recover + ``num_slices``. 
+ - Side-parameter channel-context layout: channel_context covers every + slice including ``y0``, so the count is already ``num_slices``. + + The two cases are auto-detected by whether ``y0`` appears in the matched + keys. + """ + slice_indices = { + int(key[len(prefix) :].split(".", 1)[0]) + for key in state_dict + if key.startswith(prefix) and key.endswith(suffix) + } + if not slice_indices: + return 0 + if 0 in slice_indices: + return len(slice_indices) + return len(slice_indices) + 1 + + +def infer_max_support_slices( + state_dict: Dict[str, Tensor], + latent_channels: int, + num_slices: int, + *, + prefix: str = _DEFAULT_NUM_SLICES_PREFIX, + suffix: str = _DEFAULT_KEY_SUFFIX, + extra_factor: int = 1, +) -> int: + """Infer ``max_support_slices`` from the input width of the ``mean_cc`` + first conv. ``extra_factor`` accounts for model-layer heads that prepend + additional copies of the latent (``M*extra + slice_channels*N``); default + ``1`` covers STF / WACNN / TCM / CCA heads whose ``mean_cc`` sees one + latent-mean block plus previous-group support. + """ + slice_channels = latent_channels // num_slices + matching = [ + tensor.size(1) + for key, tensor in state_dict.items() + if key.startswith(prefix) and key.endswith(suffix) + ] + if not matching: + return 0 + max_input_channels = max(matching) + return max( + 0, (max_input_channels - extra_factor * latent_channels) // slice_channels + ) diff --git a/compressai/models/cca.py b/compressai/models/cca.py new file mode 100644 index 00000000..106c519f --- /dev/null +++ b/compressai/models/cca.py @@ -0,0 +1,769 @@ +# Copyright (c) 2021-2025, InterDigital Communications, Inc +# All rights reserved. +# +# This file adapts code from https://github.com/CVL-UESTC/CCA +# (originally distributed under the MIT License). The upstream copyright +# notice is preserved in that repository; modifications by InterDigital +# Communications, Inc. are released under the BSD 3-Clause Clear License +# terms below. 
+ +# Redistribution and use in source and binary forms, with or without +# modification, are permitted (subject to the limitations in the disclaimer +# below) provided that the following conditions are met: + +# * Redistributions of source code must retain the above copyright notice, +# this list of conditions and the following disclaimer. +# * Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# * Neither the name of InterDigital Communications, Inc nor the names of its +# contributors may be used to endorse or promote products derived from this +# software without specific prior written permission. + +# NO EXPRESS OR IMPLIED LICENSES TO ANY PARTY'S PATENT RIGHTS ARE GRANTED BY +# THIS LICENSE. THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND +# CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT +# NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A +# PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR +# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, +# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, +# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; +# OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, +# WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR +# OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF +# ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +"""Causal Context Adjustment (CCA) standalone autoencoder. + +Mirror of the upstream ``LICAutoencoder`` from +M. Han, S. Jiang, S. Li, X. Deng, M. Xu, C. Zhu, S. Gu: +`"Causal Context Adjustment Loss for Learned Image Compression" +`_, NeurIPS 2024. 
+ +The main entropy stack is a fully containerised +:class:`HyperpriorLatentCodec` with variable-length slice groups and +per-slice :class:`_NAFTransform` support transforms. The optional auxiliary +CCA branch (:class:`_CCAAuxEntropyModel`) is a separate ``nn.Module`` that +re-encodes ``y`` with the skip-most-recent support selection used by +:class:`compressai.losses.CCARateDistortionLoss` to align the causal +context with the rate-distortion objective. +""" + +from __future__ import annotations + +import math + +from itertools import accumulate +from typing import Any, Dict, List, Optional, Sequence, Tuple + +import torch +import torch.nn as nn + +from torch import Tensor + +from compressai.entropy_models import EntropyBottleneck +from compressai.latent_codecs import ( + ChannelGroupsLatentCodec, + EntropyBottleneckLatentCodec, + GaussianConditionalLatentCodec, + HyperpriorLatentCodec, +) +from compressai.layers.layers import conv1x1 +from compressai.models._helpers.channel_context import MeanScaleContextHead +from compressai.models._helpers.slice_helpers import make_entropy_transform +from compressai.models.base import CompressionModel, get_scale_table +from compressai.models.sensetime import ResidualBottleneckBlock +from compressai.models.utils import conv, deconv +from compressai.registry import register_model + +__all__ = [ + "CCAModel", +] + + +class _DualHyperSynthesis(nn.Module): + h_mean_s: nn.Module + h_scale_s: nn.Module + + def __init__(self, h_mean_s: nn.Module, h_scale_s: nn.Module) -> None: + super().__init__() + self.h_mean_s = h_mean_s + self.h_scale_s = h_scale_s + + def forward(self, z_hat: Tensor) -> Tensor: + return torch.cat([self.h_mean_s(z_hat), self.h_scale_s(z_hat)], dim=1) + + +class _LRPGaussianLatentCodec(GaussianConditionalLatentCodec): + lrp_transform: nn.Module + + def __init__( + self, + lrp_transform: nn.Module, + *, + lrp_scale: float = 0.5, + mean_support_trail_channels: int = 0, + **gc_kwargs: Any, + ) -> None: + 
super().__init__(**gc_kwargs) + self.lrp_transform = lrp_transform + self.lrp_scale = float(lrp_scale) + self.mean_support_trail_channels = int(mean_support_trail_channels) + + def _split_ctx_params(self, ctx_params: Tensor) -> Tuple[Tensor, Tensor]: + if self.mean_support_trail_channels <= 0: + return ctx_params, ctx_params + trail = self.mean_support_trail_channels + gaussian_params = ctx_params[:, :-trail] + mean_support = ctx_params[:, -trail:] + return gaussian_params, mean_support + + def _apply_lrp(self, mean_support: Tensor, y_hat: Tensor) -> Tensor: + lrp = self.lrp_scale * torch.tanh( + self.lrp_transform(torch.cat([mean_support, y_hat], dim=1)) + ) + return y_hat + lrp + + def forward(self, y: Tensor, ctx_params: Tensor) -> Dict[str, Any]: + gaussian_params, mean_support = self._split_ctx_params(ctx_params) + out = super().forward(y, gaussian_params) + out["y_hat"] = self._apply_lrp(mean_support, out["y_hat"]) + return out + + def compress(self, y: Tensor, ctx_params: Tensor) -> Dict[str, Any]: + gaussian_params, mean_support = self._split_ctx_params(ctx_params) + out = super().compress(y, gaussian_params) + out["y_hat"] = self._apply_lrp(mean_support, out["y_hat"]) + return out + + def decompress( + self, + strings: List[List[bytes]], + shape: Tuple[int, ...], + ctx_params: Tensor, + **kwargs: Any, + ) -> Dict[str, Any]: + gaussian_params, mean_support = self._split_ctx_params(ctx_params) + out = super().decompress(strings, shape, gaussian_params, **kwargs) + out["y_hat"] = self._apply_lrp(mean_support, out["y_hat"]) + return out + + +class _SideContextChannelGroupsLatentCodec(ChannelGroupsLatentCodec): + def __init__(self, *args, **kwargs) -> None: + super().__init__(*args, **kwargs) + if "y0" not in self.channel_context: + raise ValueError("side-parameter channel groups require channel_context.y0") + + def _get_ctx_params( + self, k: int, side_params: Tensor, y_hat_: List[Tensor] + ) -> Tensor: + if k == 0: + return 
self.channel_context["y0"](side_params) + support = [y_hat_[i] for i in self.support_slices[k]] + if not support: + return self.channel_context[f"y{k}"](side_params) + return self.channel_context[f"y{k}"]( + self.merge_params(side_params, self.merge_y(*support)) + ) + + +# ---------------------------------------------------------------------------- +# Slice-size resolver. +# ---------------------------------------------------------------------------- + + +def _resolve_slice_sizes( + latent_channels: int, slice_proportions: Sequence[int] +) -> List[int]: + if len(slice_proportions) == 0: + raise ValueError("slice_proportions must contain at least one entry") + total = sum(slice_proportions) + if total <= 0: + raise ValueError("slice_proportions must sum to a positive integer") + sizes = [ + int(math.floor(latent_channels * proportion / total)) + for proportion in slice_proportions + ] + sizes[-1] += latent_channels - sum(sizes) + if any(size <= 0 for size in sizes): + raise ValueError("resolved slice sizes must all be positive") + return sizes + + +# ---------------------------------------------------------------------------- +# NAF (Non-linear Activation Free) building blocks +# ---------------------------------------------------------------------------- + + +class _SimpleGate(nn.Module): + def forward(self, input_tensor: Tensor) -> Tensor: + gate_tensor, value_tensor = input_tensor.chunk(2, dim=1) + return gate_tensor * value_tensor + + +class _NAFBlock(nn.Module): + """Non-linear Activation Free residual block. + + Used by both the CCA entropy-model auxiliary transforms and the CCA + image-compression model's analysis / synthesis stacks. State-dict keys + (``norm1`` / ``pointwise_depthwise`` / ``channel_attention`` / + ``project`` / ``feed_forward`` / ``beta`` / ``gamma``) match upstream + after ``convert_upstream_cca_state_dict`` (in + ``examples/convert_cca_checkpoint.py``) so released checkpoints load 1:1. 
+ """ + + def __init__(self, channels: int) -> None: + super().__init__() + from timm.layers import LayerNorm2d + + expanded_channels = channels * 2 + self.norm1 = LayerNorm2d(channels) + self.pointwise_depthwise = nn.Sequential( + conv1x1(channels, expanded_channels), + nn.Conv2d( + expanded_channels, + expanded_channels, + kernel_size=3, + padding=1, + groups=expanded_channels, + ), + ) + self.gate = _SimpleGate() + self.channel_attention = nn.Sequential( + nn.AdaptiveAvgPool2d(1), + conv1x1(channels, channels), + ) + self.project = conv1x1(channels, channels) + self.norm2 = LayerNorm2d(channels) + self.feed_forward = nn.Sequential( + conv1x1(channels, expanded_channels), + _SimpleGate(), + conv1x1(channels, channels), + ) + self.beta = nn.Parameter(torch.zeros(1, channels, 1, 1)) + self.gamma = nn.Parameter(torch.zeros(1, channels, 1, 1)) + + def forward(self, input_tensor: Tensor) -> Tensor: + output = self.norm1(input_tensor) + output = self.pointwise_depthwise(output) + output = self.gate(output) + output = output * self.channel_attention(output) + output = self.project(output) + output = input_tensor + self.beta * output + return output + self.gamma * self.feed_forward(self.norm2(output)) + + +class _NAFTransform(nn.Module): + """``Conv1x1 -> NAFBlock x N -> Conv1x1`` per-slice support transform.""" + + def __init__( + self, + input_channels: int, + output_channels: int, + hidden_channels: int, + num_layers: int, + ) -> None: + super().__init__() + if num_layers < 1: + raise ValueError("num_layers must be positive") + + self.input_projection = conv1x1(input_channels, hidden_channels) + self.blocks = nn.Sequential( + *(_NAFBlock(hidden_channels) for _ in range(num_layers)) + ) + self.output_projection = conv1x1(hidden_channels, output_channels) + + def forward(self, input_tensor: Tensor) -> Tensor: + output = self.input_projection(input_tensor) + return self.output_projection(output + self.blocks(output)) + + +# 
---------------------------------------------------------------------------- +# Analysis / synthesis transforms (NAFBlock + ResidualBottleneckBlock). +# ---------------------------------------------------------------------------- + + +class _CCAEncoder(nn.Module): + """NAFBlock + ResidualBottleneckBlock analysis transform (4 strides).""" + + def __init__( + self, + in_channels: int, + latent_channels: int, + stage_dims: Sequence[int], + stage_layers: Sequence[int], + ) -> None: + super().__init__() + if len(stage_dims) != len(stage_layers): + raise ValueError("stage_dims and stage_layers must have matching length") + self.depth = len(stage_dims) + all_dims = [in_channels, *stage_dims, latent_channels] + self.down = nn.ModuleList( + conv(all_dims[index], all_dims[index + 1], kernel_size=5, stride=2) + for index in range(self.depth + 1) + ) + self.blocks = nn.ModuleList( + nn.Sequential( + *( + ResidualBottleneckBlock(stage_dims[index], stage_dims[index]) + for _ in range(3) + ), + *(_NAFBlock(stage_dims[index]) for _ in range(stage_layers[index])), + ) + for index in range(self.depth) + ) + + def forward(self, x: Tensor) -> Tensor: + for index in range(self.depth): + x = self.down[index](x) + x = self.blocks[index](x) + return self.down[self.depth](x) + + +class _CCADecoder(nn.Module): + """NAFBlock + ResidualBottleneckBlock synthesis transform (4 strides).""" + + def __init__( + self, + out_channels: int, + latent_channels: int, + stage_dims: Sequence[int], + stage_layers: Sequence[int], + ) -> None: + super().__init__() + if len(stage_dims) != len(stage_layers): + raise ValueError("stage_dims and stage_layers must have matching length") + self.depth = len(stage_dims) + all_dims = [out_channels, *stage_dims, latent_channels] + self.up = nn.ModuleList( + deconv(all_dims[index + 1], all_dims[index], kernel_size=5, stride=2) + for index in reversed(range(self.depth + 1)) + ) + self.blocks = nn.ModuleList( + nn.Sequential( + *(_NAFBlock(stage_dims[index]) for _ in 
range(stage_layers[index])), + *( + ResidualBottleneckBlock(stage_dims[index], stage_dims[index]) + for _ in range(3) + ), + ) + for index in reversed(range(self.depth)) + ) + + def forward(self, x: Tensor) -> Tensor: + for index in range(self.depth): + x = self.up[index](x) + x = self.blocks[index](x) + return self.up[self.depth](x) + + +# ---------------------------------------------------------------------------- +# Auxiliary CCA entropy branch. +# ---------------------------------------------------------------------------- + + +class _CCAAuxEntropyModel(nn.Module): + """Auxiliary CCA entropy branch (skip-most-recent-slice support). + + Produces the ``y_aux`` (factorised) and ``y_cca`` (Gaussian-conditional) + likelihoods used by :class:`compressai.losses.CCARateDistortionLoss`. + + Mirrors the upstream ``AuxEntropyModel`` in + ``candidate/CCA/models/aux_em.py``: for slice ``i`` the support is + ``cat(latent_means, *y_hat_slices[: max(i - 1, 0)])`` (i.e., skip the + *most recent* decoded slice). This is wired inline (ELIC-style) on a + private side-parameter channel-groups codec with matching per-slice + ``support_count`` to size the channel-context heads. + + Although upstream only *uses* the LRP path on the first ``num_slices - + 2`` slices, the published checkpoints carry LRP weights for *all* + slices. To strict-load those checkpoints every leaf is a Gaussian codec + with local LRP refinement; the LRP applied to the trailing two slices is + benign (those slices' ``y_hat`` is excluded from every later slice's + skip-most-recent support selection, so it never feeds back into the + likelihoods). 
+ """ + + def __init__( + self, + latent_channels: int, + slice_sizes: Sequence[int], + hidden_channels: int, + num_layers: int, + ) -> None: + super().__init__() + self.latent_channels = int(latent_channels) + self.slice_sizes: List[int] = list(map(int, slice_sizes)) + self.num_slices = len(self.slice_sizes) + self.hidden_channels = int(hidden_channels) + self.num_layers = int(num_layers) + + M = self.latent_channels + slice_sizes = self.slice_sizes + em_hidden_channels = self.hidden_channels + em_num_layers = self.num_layers + cumulative = list(accumulate(slice_sizes, initial=0)) + widths = (em_hidden_channels, 128) + + # Skip-most-recent support: slice k sees max(k - 1, 0) prior slices. + def support_count(k: int) -> int: + return max(k - 1, 0) + + def mean_support_ch(k: int) -> int: + return M + cumulative[support_count(k)] + + support_slices = [list(range(support_count(k))) for k in range(self.num_slices)] + + def naf_factory(c_in: int, c_out: int) -> nn.Module: + return _NAFTransform(c_in, c_out, em_hidden_channels, em_num_layers) + + # Side-parameter channel-groups wiring, inlined ELIC-style. Differs + # from the main CCA stack only in the skip-most-recent support count. 
+ channel_context = { + f"y{k}": MeanScaleContextHead( + mean_cc=make_entropy_transform( + mean_support_ch(k), slice_sizes[k], widths=widths + ), + scale_cc=make_entropy_transform( + mean_support_ch(k), slice_sizes[k], widths=widths + ), + mean_support_transform=naf_factory( + mean_support_ch(k), mean_support_ch(k) + ), + scale_support_transform=naf_factory( + mean_support_ch(k), mean_support_ch(k) + ), + side_split=M, + emit_mean_support="post", + ) + for k in range(self.num_slices) + } + y_latent_codec = { + f"y{k}": _LRPGaussianLatentCodec( + lrp_transform=make_entropy_transform( + mean_support_ch(k) + slice_sizes[k], slice_sizes[k], widths=widths + ), + mean_support_trail_channels=mean_support_ch(k), + quantizer="ste", + ) + for k in range(self.num_slices) + } + + self.y_entropy_bottleneck = EntropyBottleneck(M) + self.inner_codec = _SideContextChannelGroupsLatentCodec( + groups=list(slice_sizes), + channel_context=channel_context, + latent_codec=y_latent_codec, + support_slices=support_slices, + ) + + def forward( + self, + y: Tensor, + latent_means: Tensor, + latent_scales: Tensor, + ) -> Dict[str, Tensor]: + _, y_aux_likelihoods = self.y_entropy_bottleneck(y) + side_params = torch.cat([latent_means, latent_scales], dim=1) + inner_out = self.inner_codec(y, side_params) + return { + "y_aux": y_aux_likelihoods, + "y_cca": inner_out["likelihoods"]["y"], + } + + +# ---------------------------------------------------------------------------- +# Top-level CCAModel. +# ---------------------------------------------------------------------------- + + +@register_model("cca") +class CCAModel(CompressionModel): + r"""Causal Context Adjustment standalone autoencoder. + + Mirrors the upstream ``LICAutoencoder`` from M. Han et al., NeurIPS 2024 + (`Causal Context Adjustment Loss for Learned Image Compression + `_). 
+ + The entropy stack is a :class:`HyperpriorLatentCodec` with variable-length + channel groups (``slice_proportions``), per-slice :class:`_NAFTransform` + support transforms, and a STE-quantised ``z`` leaf. When + ``cca_training=True`` an auxiliary + :class:`_CCAAuxEntropyModel` branch is added that produces ``y_aux`` / + ``y_cca`` likelihoods consumed by + :class:`compressai.losses.CCARateDistortionLoss`. + + Args: + latent_channels: Number of channels in the latent (``M``). + hyper_channels: Number of channels in the hyper-latent (``N_z``). + slice_proportions: Per-slice channel proportions; the actual slice + channel widths are computed as + ``floor(latent_channels * p / sum(p))`` with the residual added + to the last slice. Pass ``[1] * num_slices`` for equal-sized + slices; pass ``[8, 28, 56, 92, 136]`` to reproduce the upstream + published M=320 layout. + encoder_dims: Per-stage feature widths for the analysis transform + (3 stages by default). + encoder_layers: Per-stage NAFBlock counts for the analysis transform. + em_hidden_channels: Hidden width inside the per-slice NAFTransforms + and channel-context heads. + em_num_layers: NAFBlock count inside each per-slice NAFTransform. + cca_training: When ``True``, allocate the auxiliary CCA entropy + branch so that ``forward`` populates ``aux_likelihoods``. 
+ """ + + def __init__( + self, + latent_channels: int = 320, + hyper_channels: int = 192, + slice_proportions: Sequence[int] = (8, 28, 56, 92, 136), + encoder_dims: Sequence[int] = (192, 224, 256), + encoder_layers: Sequence[int] = (4, 4, 4), + em_hidden_channels: int = 224, + em_num_layers: int = 4, + cca_training: bool = False, + **kwargs, + ) -> None: + super().__init__(**kwargs) + encoder_dims = tuple(encoder_dims) + encoder_layers = tuple(encoder_layers) + slice_proportions = tuple(int(value) for value in slice_proportions) + + self.M = int(latent_channels) + self.N = int(hyper_channels) + self.encoder_dims = encoder_dims + self.encoder_layers = encoder_layers + self.slice_proportions = slice_proportions + self.em_hidden_channels = int(em_hidden_channels) + self.em_num_layers = int(em_num_layers) + self.cca_training = bool(cca_training) + + self.slice_sizes: List[int] = _resolve_slice_sizes(self.M, slice_proportions) + self.num_slices = len(self.slice_sizes) + + self.g_a = _CCAEncoder(3, self.M, encoder_dims, encoder_layers) + self.g_s = _CCADecoder(3, self.M, encoder_dims, encoder_layers) + + last_encoder_dim = encoder_dims[-1] + h_a = nn.Sequential( + conv(self.M, last_encoder_dim, kernel_size=3, stride=1), + nn.GELU(), + conv(last_encoder_dim, last_encoder_dim, kernel_size=5, stride=2), + nn.GELU(), + conv(last_encoder_dim, self.N, kernel_size=5, stride=2), + ) + h_mean_s = nn.Sequential( + deconv(self.N, last_encoder_dim, kernel_size=5, stride=2), + nn.GELU(), + deconv(last_encoder_dim, last_encoder_dim, kernel_size=5, stride=2), + nn.GELU(), + deconv(last_encoder_dim, self.M, kernel_size=3, stride=1), + ) + h_scale_s = nn.Sequential( + deconv(self.N, last_encoder_dim, kernel_size=5, stride=2), + nn.GELU(), + deconv(last_encoder_dim, last_encoder_dim, kernel_size=5, stride=2), + nn.GELU(), + deconv(last_encoder_dim, self.M, kernel_size=3, stride=1), + ) + + # Main entropy stack, wired inline (ELIC-style). Distinctive choices + # vs. 
STF/WACNN/TCM: + # + # - ``groups`` is the variable-length ``slice_sizes`` (resolved from + # ``slice_proportions``); STF / WACNN / TCM use uniform ``[M//K]*K``. + # - the per-slice mean / scale support transforms are _NAFTransform + # instances (vs. STF identity / TCM SWAtten), with + # ``emit_mean_support="post"`` so the LRP head receives the + # *post*-NAFTransform mean support — replicating the upstream LIC + # LRP layout for byte-for-byte weight transfer. + # - the ``z`` leaf uses ``EntropyBottleneckLatentCodec(quantizer="ste")`` + # to recover upstream's ``quantize_ste(z - z_offset) + z_offset`` + # behaviour without a model-side hack. + M = self.M + slice_sizes = self.slice_sizes + cumulative = list(accumulate(slice_sizes, initial=0)) + widths = (self.em_hidden_channels, 128) + + # use-all-prior support; matches upstream LIC main path (no skip). + def mean_support_ch(k: int) -> int: + # cat(latent_means(M), *prev_y_hat(sum(slice_sizes[:k]))). + return M + cumulative[k] + + def naf_factory(c_in: int, c_out: int) -> nn.Module: + return _NAFTransform( + c_in, c_out, self.em_hidden_channels, self.em_num_layers + ) + + channel_context = { + f"y{k}": MeanScaleContextHead( + mean_cc=make_entropy_transform( + mean_support_ch(k), slice_sizes[k], widths=widths + ), + scale_cc=make_entropy_transform( + mean_support_ch(k), slice_sizes[k], widths=widths + ), + mean_support_transform=naf_factory( + mean_support_ch(k), mean_support_ch(k) + ), + scale_support_transform=naf_factory( + mean_support_ch(k), mean_support_ch(k) + ), + side_split=M, + emit_mean_support="post", + ) + for k in range(self.num_slices) + } + y_latent_codec = { + f"y{k}": _LRPGaussianLatentCodec( + lrp_transform=make_entropy_transform( + mean_support_ch(k) + slice_sizes[k], slice_sizes[k], widths=widths + ), + mean_support_trail_channels=mean_support_ch(k), + quantizer="ste", + ) + for k in range(self.num_slices) + } + + self.latent_codec = HyperpriorLatentCodec( + h_a=h_a, + 
h_s=_DualHyperSynthesis(h_mean_s, h_scale_s), + latent_codec={ + "z": EntropyBottleneckLatentCodec( + entropy_bottleneck=EntropyBottleneck(self.N), + quantizer="ste", + ), + "y": _SideContextChannelGroupsLatentCodec( + groups=list(slice_sizes), + channel_context=channel_context, + latent_codec=y_latent_codec, + ), + }, + ) + + if self.cca_training: + self.aux_entropy_model = _CCAAuxEntropyModel( + self.M, + self.slice_sizes, + self.em_hidden_channels, + self.em_num_layers, + ) + + def forward(self, x: Tensor) -> Dict[str, object]: + y = self.g_a(x) + y_out = self.latent_codec(y) + result: Dict[str, object] = { + "y": y, + "x_hat": self.g_s(y_out["y_hat"]), + "likelihoods": y_out["likelihoods"], + } + if self.cca_training: + # ``self.latent_codec.h_s`` concatenates the dual hyper-synthesis heads; its + # output is ``cat(latent_means, latent_scales)`` of width 2*M. + # Recover them from the inner ``z`` round-trip so the aux + # branch sees the same hyperprior context as the main path. + z_out = self.latent_codec.latent_codec["z"](self.latent_codec.h_a(y)) + side_params = self.latent_codec.h_s(z_out["y_hat"]) + latent_means, latent_scales = torch.split(side_params, self.M, dim=1) + result["aux_likelihoods"] = self.aux_entropy_model( + y, latent_means, latent_scales + ) + else: + result["aux_likelihoods"] = None + return result + + def compress(self, x: Tensor) -> Dict[str, object]: + y = self.g_a(x) + y_out = self.latent_codec.compress(y) + return {"strings": y_out["strings"], "shape": y_out["shape"]} + + def decompress( + self, + strings: Sequence[Sequence[bytes]], + shape: Dict[str, Tuple[int, ...]], + ) -> Dict[str, Tensor]: + y_out = self.latent_codec.decompress(strings, shape) + return {"x_hat": self.g_s(y_out["y_hat"]).clamp_(0, 1)} + + def update( + self, scale_table: Optional[Tensor] = None, force: bool = False, **kwargs + ) -> bool: + if scale_table is None: + scale_table = get_scale_table() + return super().update(scale_table=scale_table, force=force, 
**kwargs) + + @classmethod + def from_state_dict(cls, state_dict: Dict[str, Tensor]) -> "CCAModel": + cfg = _infer_config_from_state_dict(state_dict) + net = cls(**cfg) + net.load_state_dict(state_dict) + return net + + +# ---------------------------------------------------------------------------- +# Architecture inference helpers (state_dict -> hyperparameters). +# ---------------------------------------------------------------------------- + + +def _infer_config_from_state_dict(state_dict: Dict[str, Tensor]) -> Dict[str, object]: + """Recover constructor kwargs from a compressai-layout CCA state dict.""" + encoder_dims = ( + state_dict["g_a.down.0.weight"].size(0), + state_dict["g_a.down.1.weight"].size(0), + state_dict["g_a.down.2.weight"].size(0), + ) + latent_channels = state_dict["g_a.down.3.weight"].size(0) + hyper_channels = state_dict["latent_codec.h_a.4.weight"].size(0) + + encoder_layers: List[int] = [] + for stage in range(3): + index = 0 + while f"g_a.blocks.{stage}.{index}.beta" in state_dict or _has_resblock( + state_dict, stage, index + ): + index += 1 + encoder_layers.append(index - 3) + + cc_keys = [ + key + for key in state_dict + if key.startswith("latent_codec.y.channel_context.y") + and key.endswith(".mean_cc.4.weight") + ] + cc_keys.sort(key=lambda key: int(key.split(".")[3][1:])) # ".y{k}." 
-> k + if not cc_keys: + raise RuntimeError("state dict does not contain channel-context mean_cc heads") + slice_sizes = [int(state_dict[key].size(0)) for key in cc_keys] + + em_hidden_channels = int( + state_dict[ + "latent_codec.y.channel_context.y0.mean_support_transform.input_projection.weight" + ].size(0) + ) + + em_num_layers = 0 + while ( + f"latent_codec.y.channel_context.y0.mean_support_transform.blocks.{em_num_layers}.beta" + in state_dict + ): + em_num_layers += 1 + + cca_training = any(key.startswith("aux_entropy_model.") for key in state_dict) + + return { + "latent_channels": int(latent_channels), + "hyper_channels": int(hyper_channels), + "slice_proportions": tuple(slice_sizes), + "encoder_dims": tuple(int(value) for value in encoder_dims), + "encoder_layers": tuple(int(value) for value in encoder_layers), + "em_hidden_channels": em_hidden_channels, + "em_num_layers": em_num_layers, + "cca_training": cca_training, + } + + +def _has_resblock(state_dict: Dict[str, Tensor], stage: int, sub_index: int) -> bool: + return f"g_a.blocks.{stage}.{sub_index}.conv2.weight" in state_dict and ( + f"g_a.blocks.{stage}.{sub_index}.beta" not in state_dict + ) diff --git a/compressai/models/pointcloud/sfu_pointnet.py b/compressai/models/pointcloud/sfu_pointnet.py index a1270d63..0512b766 100644 --- a/compressai/models/pointcloud/sfu_pointnet.py +++ b/compressai/models/pointcloud/sfu_pointnet.py @@ -123,11 +123,12 @@ def compress(self, input): y = self.g_a(x_t) y_out = self.latent_codec.compress(y) [y_strings] = y_out["strings"] - return {"strings": [y_strings], "shape": (1,)} + return {"strings": [y_strings], "shape": y_out["shape"]} def decompress(self, strings, shape): assert isinstance(strings, list) and len(strings) == 1 [y_strings] = strings - y_hat = self.latent_codec.decompress([y_strings], shape) + y_out = self.latent_codec.decompress([y_strings], shape) + y_hat = y_out["y_hat"] x_hat = self.g_s(y_hat) return {"x_hat": x_hat} diff --git 
a/compressai/models/sensetime.py b/compressai/models/sensetime.py index 2596f012..ae54d15b 100644 --- a/compressai/models/sensetime.py +++ b/compressai/models/sensetime.py @@ -27,8 +27,6 @@ # OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF # ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -import types - import torch import torch.nn as nn @@ -455,6 +453,10 @@ def __init__(self, N=192, M=320, groups=None, **kwargs): ) for k in range(1, len(self.groups)) } + support_slices = [ + [] if k == 0 else [0] if k == 1 else [0, k - 1] + for k in range(len(self.groups)) + ] # In [He2022], this is labeled "g_sp^(k)". spatial_context = [ @@ -514,29 +516,11 @@ def __init__(self, N=192, M=320, groups=None, **kwargs): groups=self.groups, channel_context=channel_context, latent_codec=scctx_latent_codec, + support_slices=support_slices, ), }, ) - self._monkey_patch() - - def _monkey_patch(self): - """Monkey-patch to use only first group and most recent group.""" - - def merge_y(self: ChannelGroupsLatentCodec, *args): - if len(args) == 0: - return Tensor() - if len(args) == 1: - return args[0] - if len(args) < len(self.groups): - return torch.cat([args[0], args[-1]], dim=1) - return torch.cat(args, dim=1) - - assert isinstance(self.latent_codec, HyperpriorLatentCodec) - obj = self.latent_codec.y - assert isinstance(obj, ChannelGroupsLatentCodec) - obj.merge_y = types.MethodType(merge_y, obj) - @classmethod def from_state_dict(cls, state_dict): """Return a new model instance from `state_dict`.""" diff --git a/compressai/models/stf.py b/compressai/models/stf.py index 5074d5b6..e01ac586 100644 --- a/compressai/models/stf.py +++ b/compressai/models/stf.py @@ -37,36 +37,130 @@ import math -from typing import Dict, Optional, Sequence, Tuple, Type +from typing import Any, Dict, List, Optional, Sequence, Tuple, Type import torch import torch.nn as nn -from timm.layers import DropPath, Mlp from timm.models.swin_transformer import SwinTransformerBlock as _TimmSwinBlock from 
torch import Tensor -from compressai.layers import GDN, conv1x1, conv3x3, subpel_conv3x3 +from compressai.entropy_models import EntropyBottleneck +from compressai.latent_codecs import ( + ChannelGroupsLatentCodec, + EntropyBottleneckLatentCodec, + GaussianConditionalLatentCodec, + HyperpriorLatentCodec, +) +from compressai.layers import GDN, conv3x3, subpel_conv3x3 from compressai.layers.attn import ( PatchMerging, PatchSplit, WinNoShiftAttention, ) -from compressai.models._bases import ( - SliceEntropyCompressionModel, +from compressai.models._helpers.channel_context import MeanScaleContextHead +from compressai.models._helpers.slice_helpers import ( infer_max_support_slices, infer_num_slices, + make_entropy_transform, ) +from compressai.models.base import CompressionModel, SimpleVAECompressionModel from compressai.models.utils import conv, deconv from compressai.registry import register_model __all__ = [ "SymmetricalTransFormer", "WACNN", - "convert_upstream_stf_state_dict", ] +class _DualHyperSynthesis(nn.Module): + h_mean_s: nn.Module + h_scale_s: nn.Module + + def __init__(self, h_mean_s: nn.Module, h_scale_s: nn.Module) -> None: + super().__init__() + self.h_mean_s = h_mean_s + self.h_scale_s = h_scale_s + + def forward(self, z_hat: Tensor) -> Tensor: + return torch.cat([self.h_mean_s(z_hat), self.h_scale_s(z_hat)], dim=1) + + +class _LRPGaussianLatentCodec(GaussianConditionalLatentCodec): + lrp_transform: nn.Module + + def __init__( + self, + lrp_transform: nn.Module, + *, + lrp_scale: float = 0.5, + mean_support_trail_channels: int = 0, + **gc_kwargs: Any, + ) -> None: + super().__init__(**gc_kwargs) + self.lrp_transform = lrp_transform + self.lrp_scale = float(lrp_scale) + self.mean_support_trail_channels = int(mean_support_trail_channels) + + def _split_ctx_params(self, ctx_params: Tensor) -> Tuple[Tensor, Tensor]: + if self.mean_support_trail_channels <= 0: + return ctx_params, ctx_params + trail = self.mean_support_trail_channels + gaussian_params = 
ctx_params[:, :-trail] + mean_support = ctx_params[:, -trail:] + return gaussian_params, mean_support + + def _apply_lrp(self, mean_support: Tensor, y_hat: Tensor) -> Tensor: + lrp = self.lrp_scale * torch.tanh( + self.lrp_transform(torch.cat([mean_support, y_hat], dim=1)) + ) + return y_hat + lrp + + def forward(self, y: Tensor, ctx_params: Tensor) -> Dict[str, Any]: + gaussian_params, mean_support = self._split_ctx_params(ctx_params) + out = super().forward(y, gaussian_params) + out["y_hat"] = self._apply_lrp(mean_support, out["y_hat"]) + return out + + def compress(self, y: Tensor, ctx_params: Tensor) -> Dict[str, Any]: + gaussian_params, mean_support = self._split_ctx_params(ctx_params) + out = super().compress(y, gaussian_params) + out["y_hat"] = self._apply_lrp(mean_support, out["y_hat"]) + return out + + def decompress( + self, + strings: List[List[bytes]], + shape: Tuple[int, ...], + ctx_params: Tensor, + **kwargs: Any, + ) -> Dict[str, Any]: + gaussian_params, mean_support = self._split_ctx_params(ctx_params) + out = super().decompress(strings, shape, gaussian_params, **kwargs) + out["y_hat"] = self._apply_lrp(mean_support, out["y_hat"]) + return out + + +class _SideContextChannelGroupsLatentCodec(ChannelGroupsLatentCodec): + def __init__(self, *args, **kwargs) -> None: + super().__init__(*args, **kwargs) + if "y0" not in self.channel_context: + raise ValueError("side-parameter channel groups require channel_context.y0") + + def _get_ctx_params( + self, k: int, side_params: Tensor, y_hat_: List[Tensor] + ) -> Tensor: + if k == 0: + return self.channel_context["y0"](side_params) + support = [y_hat_[i] for i in self.support_slices[k]] + if not support: + return self.channel_context[f"y{k}"](side_params) + return self.channel_context[f"y{k}"]( + self.merge_params(side_params, self.merge_y(*support)) + ) + + # ---------------------------------------------------------------------------- # STF building blocks # (formerly compressai/layers/lic/stf.py; private to 
the WACNN / SymmetricalTransFormer models) @@ -202,61 +296,8 @@ def forward(self, input_tensor: Tensor) -> Tensor: # ---------------------------------------------------------------------------- -_UPSTREAM_LATENT_CODEC_PREFIXES = ( - "cc_mean_transforms", - "cc_scale_transforms", - "lrp_transforms", - "gaussian_conditional", -) - - -def convert_upstream_stf_state_dict(state_dict: Dict[str, Tensor]) -> Dict[str, Tensor]: - """Translate a candidate ``STF`` / ``WACNN`` state dict into compressai layout. - - Upstream checkpoints (``stf__best.pth.tar`` / ``cnn__best.pth.tar`` - from `Zou et al. 2022 `_) are saved from a - ``DataParallel``-wrapped module and place the channel-conditional entropy - transforms at the model root. compressai houses those transforms (plus the - Gaussian conditional) under ``latent_codec.*``. This helper: - - - strips the leading ``module.`` prefix added by ``DataParallel``; - - re-roots ``cc_mean_transforms`` / ``cc_scale_transforms`` / - ``lrp_transforms`` / ``gaussian_conditional`` under ``latent_codec.``; - - leaves ``g_a`` / ``g_s`` / ``patch_embed`` / ``layers`` / ``syn_layers`` - / ``end_conv`` / ``h_a`` / ``h_mean_s`` / ``h_scale_s`` / - ``entropy_bottleneck`` keys unchanged. - - The returned dict can be loaded by :meth:`WACNN.from_state_dict` or - :meth:`SymmetricalTransFormer.from_state_dict`. Both ``from_state_dict`` - entry points auto-detect the upstream layout and call this helper, so - direct invocation is only needed when persisting the converted dict. - """ - converted: Dict[str, Tensor] = {} - for key, value in state_dict.items(): - new_key = key[len("module.") :] if key.startswith("module.") else key - head = new_key.split(".", 1)[0] - if head in _UPSTREAM_LATENT_CODEC_PREFIXES: - new_key = "latent_codec." 
+ new_key - converted[new_key] = value - return converted - - -def _is_upstream_stf_state_dict(state_dict: Dict[str, Tensor]) -> bool: - """Heuristic: upstream checkpoints either carry a ``module.`` prefix or - place ``cc_mean_transforms`` at the root instead of under ``latent_codec``. - """ - for key in state_dict: - if key.startswith("module."): - return True - if key.startswith("cc_mean_transforms.") or key.startswith( - "gaussian_conditional." - ): - return True - return False - - @register_model("stf-wacnn") -class WACNN(SliceEntropyCompressionModel): +class WACNN(SimpleVAECompressionModel): r"""WACNN model from R. Zou, C. Song, Z. Zhang: `"The Devil Is in the Details: Window-based Attention for Image Compression" `_, IEEE/CVF Conf. on Computer Vision @@ -267,6 +308,14 @@ class WACNN(SliceEntropyCompressionModel): ``output_proj=False``) inside the analysis/synthesis transforms, paired with a Minnen2020-style channel-wise autoregressive entropy model. + The entropy stack is a fully containerised + :class:`HyperpriorLatentCodec` that owns ``h_a``, ``h_s``, the ``z`` + bottleneck and the per-slice side-conditioned channel-groups path. The + codec is wired inline in ``__init__`` (ELIC-style) rather than behind a factory: the + ``channel_context`` heads are :class:`MeanScaleContextHead` instances + (split mean / scale, ``emit_mean_support=True``) and the per-slice + leaves are STE-quantised Gaussian codecs with local LRP refinement. + Args: N (int): Number of channels in the hyperprior backbone. M (int): Number of channels in the latent representation. 
@@ -282,6 +331,10 @@ def __init__( **kwargs, ) -> None: super().__init__(**kwargs) + if M % num_slices != 0: + raise ValueError("M must be divisible by num_slices") + slice_ch = M // num_slices + self.g_a = nn.Sequential( conv(3, N, kernel_size=5, stride=2), GDN(N), @@ -312,7 +365,8 @@ def __init__( GDN(N, inverse=True), deconv(N, 3, kernel_size=5, stride=2), ) - self.h_a = nn.Sequential( + + h_a = nn.Sequential( conv3x3(M, M), nn.GELU(), conv3x3(M, 288), @@ -323,7 +377,7 @@ def __init__( nn.GELU(), conv3x3(224, N, stride=2), ) - self.h_mean_s = nn.Sequential( + h_mean_s = nn.Sequential( conv3x3(N, N), nn.GELU(), subpel_conv3x3(N, 224, 2), @@ -334,7 +388,7 @@ def __init__( nn.GELU(), conv3x3(288, M), ) - self.h_scale_s = nn.Sequential( + h_scale_s = nn.Sequential( conv3x3(N, N), nn.GELU(), subpel_conv3x3(N, 224, 2), @@ -345,33 +399,66 @@ def __init__( nn.GELU(), conv3x3(288, M), ) - self._init_slice_entropy( - M, - N, - num_slices, - max_support_slices, - ) - def forward(self, x: Tensor) -> Dict[str, Dict[str, Tensor] | Tensor]: - y = self.g_a(x) - latent_output = self._forward_latent_output(y) - return { - "x_hat": self.g_s(latent_output["y_hat"]), - "likelihoods": latent_output["likelihoods"], + widths = (224, 176, 128, 64) + groups = [slice_ch] * num_slices + + def support_count(k: int) -> int: + return k if max_support_slices < 0 else min(k, max_support_slices) + + # mean_support = cat(latent_means(M), *prev_y_hat(slice_ch * support_count)). + def mean_support_ch(k: int) -> int: + return M + slice_ch * support_count(k) + + support_slices = [list(range(support_count(k))) for k in range(num_slices)] + + # Each head sees cat(side_params(2M), *prev_y_hat) and emits + # cat(scale, mean, mean_support) for the LRP-aware leaf to consume. 
+ channel_context = { + f"y{k}": MeanScaleContextHead( + mean_cc=make_entropy_transform( + mean_support_ch(k), slice_ch, widths=widths + ), + scale_cc=make_entropy_transform( + mean_support_ch(k), slice_ch, widths=widths + ), + side_split=M, + emit_mean_support=True, + ) + for k in range(num_slices) + } + # Per-slice leaves: LRP transform reads cat(mean_support, y_hat) off + # the trailing block of ctx_params; upstream lrp_transforms.{k} + # weights transfer byte-for-byte after key remapping. + y_latent_codec = { + f"y{k}": _LRPGaussianLatentCodec( + lrp_transform=make_entropy_transform( + mean_support_ch(k) + slice_ch, slice_ch, widths=widths + ), + mean_support_trail_channels=mean_support_ch(k), + quantizer="ste", + ) + for k in range(num_slices) } - def compress(self, x: Tensor) -> Dict[str, object]: - return self._compress_latent(self.g_a(x)) - - def decompress( - self, strings: Sequence[Sequence[bytes]], shape: Tuple[int, int] - ) -> Dict[str, Tensor]: - return {"x_hat": self.g_s(self._decompress_latent(strings, shape)).clamp_(0, 1)} + self.latent_codec = HyperpriorLatentCodec( + h_a=h_a, + h_s=_DualHyperSynthesis(h_mean_s, h_scale_s), + latent_codec={ + "z": EntropyBottleneckLatentCodec( + entropy_bottleneck=EntropyBottleneck(N), quantizer="ste" + ), + "y": _SideContextChannelGroupsLatentCodec( + groups=groups, + channel_context=channel_context, + latent_codec=y_latent_codec, + support_slices=support_slices, + ), + }, + ) @classmethod def from_state_dict(cls, state_dict: Dict[str, Tensor]) -> "WACNN": - if _is_upstream_stf_state_dict(state_dict): - state_dict = convert_upstream_stf_state_dict(state_dict) N = state_dict["g_a.0.weight"].size(0) M = state_dict["g_a.7.weight"].size(0) num_slices = infer_num_slices(state_dict) or 10 @@ -387,7 +474,7 @@ def from_state_dict(cls, state_dict: Dict[str, Tensor]) -> "WACNN": @register_model("stf") -class SymmetricalTransFormer(SliceEntropyCompressionModel): +class 
SymmetricalTransFormer(CompressionModel): r"""Symmetrical Transformer model (STF) from R. Zou, C. Song, Z. Zhang: `"The Devil Is in the Details: Window-based Attention for Image Compression" `_, IEEE/CVF Conf. on @@ -395,7 +482,10 @@ class SymmetricalTransFormer(SliceEntropyCompressionModel): Transformer-based companion of :class:`WACNN` that builds the analysis/synthesis transforms with stacked Swin-style basic layers and a - channel-wise autoregressive entropy model. + channel-wise autoregressive entropy model. The entropy stack mirrors + :class:`WACNN`'s containerised side-conditioned + :class:`HyperpriorLatentCodec`, with widths derived from the transformer's + stage channel counts. Args: embed_dim (int): Patch-embedding dimension. @@ -502,7 +592,14 @@ def __init__( latent_channels = int(embed_dim * 2 ** (self.num_layers - 1)) bottleneck_channels = latent_channels // 2 - self.h_a = nn.Sequential( + if latent_channels % num_slices != 0: + raise ValueError("latent_channels must be divisible by num_slices") + slice_ch = latent_channels // num_slices + resolved_max_support = ( + num_slices // 2 if max_support_slices is None else max_support_slices + ) + + h_a = nn.Sequential( conv3x3(latent_channels, latent_channels), nn.GELU(), conv3x3(latent_channels, latent_channels - embed_dim), @@ -515,7 +612,7 @@ def __init__( nn.GELU(), conv3x3(latent_channels - 3 * embed_dim, bottleneck_channels, stride=2), ) - self.h_mean_s = nn.Sequential( + h_mean_s = nn.Sequential( conv3x3(bottleneck_channels, latent_channels - 3 * embed_dim), nn.GELU(), subpel_conv3x3( @@ -528,7 +625,7 @@ def __init__( nn.GELU(), conv3x3(latent_channels, latent_channels), ) - self.h_scale_s = nn.Sequential( + h_scale_s = nn.Sequential( conv3x3(bottleneck_channels, latent_channels - 3 * embed_dim), nn.GELU(), subpel_conv3x3( @@ -541,11 +638,60 @@ def __init__( nn.GELU(), conv3x3(latent_channels, latent_channels), ) - self._init_slice_entropy( - latent_channels, - bottleneck_channels, - num_slices, - 
num_slices // 2 if max_support_slices is None else max_support_slices, + + N = bottleneck_channels + M = latent_channels + widths = (224, 176, 128, 64) + groups = [slice_ch] * num_slices + + def support_count(k: int) -> int: + return k if resolved_max_support < 0 else min(k, resolved_max_support) + + def mean_support_ch(k: int) -> int: + return M + slice_ch * support_count(k) + + support_slices = [list(range(support_count(k))) for k in range(num_slices)] + + # Side-parameter channel-groups wiring, inlined ELIC-style (see WACNN + # for the per-key shape rationale). + channel_context = { + f"y{k}": MeanScaleContextHead( + mean_cc=make_entropy_transform( + mean_support_ch(k), slice_ch, widths=widths + ), + scale_cc=make_entropy_transform( + mean_support_ch(k), slice_ch, widths=widths + ), + side_split=M, + emit_mean_support=True, + ) + for k in range(num_slices) + } + y_latent_codec = { + f"y{k}": _LRPGaussianLatentCodec( + lrp_transform=make_entropy_transform( + mean_support_ch(k) + slice_ch, slice_ch, widths=widths + ), + mean_support_trail_channels=mean_support_ch(k), + quantizer="ste", + ) + for k in range(num_slices) + } + + self.latent_codec = HyperpriorLatentCodec( + h_a=h_a, + h_s=_DualHyperSynthesis(h_mean_s, h_scale_s), + latent_codec={ + "z": EntropyBottleneckLatentCodec( + entropy_bottleneck=EntropyBottleneck(N), quantizer="ste" + ), + "y": _SideContextChannelGroupsLatentCodec( + groups=groups, + channel_context=channel_context, + latent_codec=y_latent_codec, + support_slices=support_slices, + ), + }, ) def _analysis_transform(self, x: Tensor) -> Tuple[Tensor, int, int]: @@ -576,27 +722,29 @@ def _synthesis_transform(self, y_hat: Tensor, height: int, width: int) -> Tensor def forward(self, x: Tensor) -> Dict[str, Dict[str, Tensor] | Tensor]: y, height, width = self._analysis_transform(x) - latent_output = self._forward_latent_output(y) + y_out = self.latent_codec(y) return { - "x_hat": self._synthesis_transform(latent_output["y_hat"], height, width), - 
"likelihoods": latent_output["likelihoods"], + "x_hat": self._synthesis_transform(y_out["y_hat"], height, width), + "likelihoods": y_out["likelihoods"], } def compress(self, x: Tensor) -> Dict[str, object]: y, _, _ = self._analysis_transform(x) - return self._compress_latent(y) + y_out = self.latent_codec.compress(y) + return {"strings": y_out["strings"], "shape": y_out["shape"]} def decompress( - self, strings: Sequence[Sequence[bytes]], shape: Tuple[int, int] + self, + strings: Sequence[Sequence[bytes]], + shape: Dict[str, Tuple[int, ...]] | Tuple[int, int], ) -> Dict[str, Tensor]: - y_hat = self._decompress_latent(strings, shape) + y_out = self.latent_codec.decompress(strings, shape) + y_hat = y_out["y_hat"] height, width = y_hat.shape[2:] return {"x_hat": self._synthesis_transform(y_hat, height, width).clamp_(0, 1)} @classmethod def from_state_dict(cls, state_dict: Dict[str, Tensor]) -> "SymmetricalTransFormer": - if _is_upstream_stf_state_dict(state_dict): - state_dict = convert_upstream_stf_state_dict(state_dict) patch_size = state_dict["patch_embed.proj.weight"].size(2) embed_dim = state_dict["patch_embed.proj.weight"].size(0) layer_indices = sorted( diff --git a/compressai/models/tcm.py b/compressai/models/tcm.py new file mode 100644 index 00000000..7a80609c --- /dev/null +++ b/compressai/models/tcm.py @@ -0,0 +1,530 @@ +# Copyright (c) 2021-2025, InterDigital Communications, Inc +# All rights reserved. +# +# This file adapts code from https://github.com/jmliu206/LIC_TCM +# (originally distributed under the MIT License). The upstream copyright +# notice is preserved in that repository; modifications by InterDigital +# Communications, Inc. are released under the BSD 3-Clause Clear License +# terms below. 
+ +# Redistribution and use in source and binary forms, with or without +# modification, are permitted (subject to the limitations in the disclaimer +# below) provided that the following conditions are met: + +# * Redistributions of source code must retain the above copyright notice, +# this list of conditions and the following disclaimer. +# * Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# * Neither the name of InterDigital Communications, Inc nor the names of its +# contributors may be used to endorse or promote products derived from this +# software without specific prior written permission. + +# NO EXPRESS OR IMPLIED LICENSES TO ANY PARTY'S PATENT RIGHTS ARE GRANTED BY +# THIS LICENSE. THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND +# CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT +# NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A +# PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR +# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, +# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, +# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; +# OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, +# WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR +# OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF +# ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ +from __future__ import annotations + +from typing import Any, Dict, Iterable, List, Optional, Sequence, Tuple + +import torch +import torch.nn as nn + +from torch import Tensor + +from compressai.entropy_models import EntropyBottleneck +from compressai.latent_codecs import ( + ChannelGroupsLatentCodec, + EntropyBottleneckLatentCodec, + GaussianConditionalLatentCodec, + HyperpriorLatentCodec, +) +from compressai.layers import ( + ResidualBlockUpsample, + ResidualBlockWithStride, + conv3x3, + subpel_conv3x3, +) +from compressai.layers.attn import ConvTransBlock, SWAtten +from compressai.models._helpers.channel_context import MeanScaleContextHead +from compressai.models._helpers.slice_helpers import ( + infer_max_support_slices, + infer_num_slices, + make_entropy_transform, +) +from compressai.models.base import SimpleVAECompressionModel +from compressai.registry import register_model + +__all__ = [ + "TCM", +] + + +class _DualHyperSynthesis(nn.Module): + h_mean_s: nn.Module + h_scale_s: nn.Module + + def __init__(self, h_mean_s: nn.Module, h_scale_s: nn.Module) -> None: + super().__init__() + self.h_mean_s = h_mean_s + self.h_scale_s = h_scale_s + + def forward(self, z_hat: Tensor) -> Tensor: + return torch.cat([self.h_mean_s(z_hat), self.h_scale_s(z_hat)], dim=1) + + +class _LRPGaussianLatentCodec(GaussianConditionalLatentCodec): + lrp_transform: nn.Module + + def __init__( + self, + lrp_transform: nn.Module, + *, + lrp_scale: float = 0.5, + mean_support_trail_channels: int = 0, + **gc_kwargs: Any, + ) -> None: + super().__init__(**gc_kwargs) + self.lrp_transform = lrp_transform + self.lrp_scale = float(lrp_scale) + self.mean_support_trail_channels = int(mean_support_trail_channels) + + def _split_ctx_params(self, ctx_params: Tensor) -> Tuple[Tensor, Tensor]: + if self.mean_support_trail_channels <= 0: + return ctx_params, ctx_params + trail = self.mean_support_trail_channels + gaussian_params = ctx_params[:, :-trail] + mean_support = ctx_params[:, -trail:] + 
return gaussian_params, mean_support + + def _apply_lrp(self, mean_support: Tensor, y_hat: Tensor) -> Tensor: + lrp = self.lrp_scale * torch.tanh( + self.lrp_transform(torch.cat([mean_support, y_hat], dim=1)) + ) + return y_hat + lrp + + def forward(self, y: Tensor, ctx_params: Tensor) -> Dict[str, Any]: + gaussian_params, mean_support = self._split_ctx_params(ctx_params) + out = super().forward(y, gaussian_params) + out["y_hat"] = self._apply_lrp(mean_support, out["y_hat"]) + return out + + def compress(self, y: Tensor, ctx_params: Tensor) -> Dict[str, Any]: + gaussian_params, mean_support = self._split_ctx_params(ctx_params) + out = super().compress(y, gaussian_params) + out["y_hat"] = self._apply_lrp(mean_support, out["y_hat"]) + return out + + def decompress( + self, + strings: List[List[bytes]], + shape: Tuple[int, ...], + ctx_params: Tensor, + **kwargs: Any, + ) -> Dict[str, Any]: + gaussian_params, mean_support = self._split_ctx_params(ctx_params) + out = super().decompress(strings, shape, gaussian_params, **kwargs) + out["y_hat"] = self._apply_lrp(mean_support, out["y_hat"]) + return out + + +class _SideContextChannelGroupsLatentCodec(ChannelGroupsLatentCodec): + def __init__(self, *args, **kwargs) -> None: + super().__init__(*args, **kwargs) + if "y0" not in self.channel_context: + raise ValueError("side-parameter channel groups require channel_context.y0") + + def _get_ctx_params( + self, k: int, side_params: Tensor, y_hat_: List[Tensor] + ) -> Tensor: + if k == 0: + return self.channel_context["y0"](side_params) + support = [y_hat_[i] for i in self.support_slices[k]] + if not support: + return self.channel_context[f"y{k}"](side_params) + return self.channel_context[f"y{k}"]( + self.merge_params(side_params, self.merge_y(*support)) + ) + + +# ---------------------------------------------------------------------------- +# Architecture inference helpers (state_dict -> hyperparameters) +# 
---------------------------------------------------------------------------- + + +def _group_consecutive(indices: Iterable[int]) -> List[List[int]]: + grouped: List[List[int]] = [] + for index in sorted(indices): + if not grouped or index != grouped[-1][-1] + 1: + grouped.append([index]) + continue + grouped[-1].append(index) + return grouped + + +def _infer_stage_groups(state_dict: Dict[str, Tensor], prefix: str) -> List[List[int]]: + indices = { + int(key.split(".")[1]) + for key in state_dict + if key.startswith(f"{prefix}.") and ".conv1_1.weight" in key + } + return _group_consecutive(indices) + + +def _infer_stage_depths(state_dict: Dict[str, Tensor]) -> Optional[List[int]]: + g_a_groups = _infer_stage_groups(state_dict, "g_a") + g_s_groups = _infer_stage_groups(state_dict, "g_s") + if len(g_a_groups) != 3 or len(g_s_groups) != 3: + return None + return [len(group) for group in g_a_groups + g_s_groups] + + +def _infer_head_dims(state_dict: Dict[str, Tensor], N: int) -> Optional[List[int]]: + head_dims: List[int] = [] + for prefix in ("g_a", "g_s"): + for group in _infer_stage_groups(state_dict, prefix): + if not group: + continue + table_key = ( + f"{prefix}.{group[0]}.trans_block.msa.attn.relative_position_bias_table" + ) + if table_key not in state_dict: + return None + num_heads = state_dict[table_key].size(1) + head_dims.append(N // num_heads) + return head_dims if len(head_dims) == 6 else None + + +def _infer_hyper_head_dim(state_dict: Dict[str, Tensor], N: int, default: int) -> int: + for key in ( + "h_a.1.trans_block.msa.attn.relative_position_bias_table", + "h_mean_s.1.trans_block.msa.attn.relative_position_bias_table", + ): + if key in state_dict: + return N // state_dict[key].size(1) + return default + + +# ---------------------------------------------------------------------------- +# Architecture building blocks +# ---------------------------------------------------------------------------- + + +def _make_mixed_stage( + depth: int, + 
branch_channels: int, + head_dim: int, + window_size: int, + drop_paths: Sequence[float], + tail: nn.Module, +) -> List[nn.Module]: + if len(drop_paths) != depth: + raise ValueError("drop_paths must match stage depth") + blocks = [ + ConvTransBlock( + branch_channels, + branch_channels, + head_dim, + window_size, + drop_paths[index], + type="W" if index % 2 == 0 else "SW", + ) + for index in range(depth) + ] + return [*blocks, tail] + + +# ---------------------------------------------------------------------------- +# TCM model +# ---------------------------------------------------------------------------- + + +@register_model("lic-tcm") +@register_model("tcm") +class TCM(SimpleVAECompressionModel): + r"""TCM model from J. Liu, H. Sun, J. Katto: `"Learned Image Compression + with Mixed Transformer-CNN Architectures" + `_, IEEE/CVF Conf. on Computer Vision + and Pattern Recognition (CVPR), 2023 (Highlight). + + Stacks parallel Transformer-CNN Mixture (TCM) blocks for the + analysis/synthesis transforms and uses a channel-wise autoregressive + entropy model with parameter-efficient swin-transformer attention + (``SWAtten``). + + The entropy stack is a fully containerised + :class:`HyperpriorLatentCodec` that owns ``h_a``, ``h_s``, the ``z`` + bottleneck and the per-slice side-conditioned channel-groups path. The + channel-context heads route per-slice ``mean_in`` / ``scale_in`` through + independent SWAtten instances before the 3-conv ``mean_cc`` / + ``scale_cc`` stacks (TCM's distinctive widths ``(224, 128)``). + + Args: + N (int): Channel width of the analysis/synthesis transform branches. + M (int): Channels in the latent representation ``y``. + hyper_channels (int): Channels in the hyperprior backbone ``z``. + num_slices (int): Number of channel slices for the entropy model. + max_support_slices (int): Per-slice context cap. 
+ """ + + def __init__( + self, + config: Optional[Sequence[int]] = None, + head_dim: Optional[Sequence[int]] = None, + drop_path_rate: float = 0.0, + N: int = 128, + M: int = 320, + hyper_channels: int = 192, + num_slices: int = 5, + max_support_slices: int = 5, + window_size: int = 8, + hyper_window_size: int = 4, + hyper_head_dim: int = 32, + **kwargs, + ) -> None: + super().__init__(**kwargs) + config = tuple(int(value) for value in (config or (2, 2, 2, 2, 2, 2))) + head_dim = tuple(int(value) for value in (head_dim or (8, 16, 32, 32, 16, 8))) + if len(config) != 6: + raise ValueError("config must provide six stage depths") + if len(head_dim) != 6: + raise ValueError("head_dim must provide six stage head dimensions") + if any(value < 0 for value in config): + raise ValueError("config values must be non-negative") + if M % num_slices != 0: + raise ValueError("M must be divisible by num_slices") + if any(N % value != 0 for value in head_dim): + raise ValueError("Each head_dim must divide N") + if N % hyper_head_dim != 0: + raise ValueError("hyper_head_dim must divide N") + + self.config = config + self.head_dim = head_dim + self.window_size = int(window_size) + self.hyper_window_size = int(hyper_window_size) + self.hyper_head_dim = int(hyper_head_dim) + self.N = int(N) + self.M = int(M) + self.hyper_channels = int(hyper_channels) + self.num_slices = int(num_slices) + self.max_support_slices = int(max_support_slices) + + drop_paths = torch.linspace(0, drop_path_rate, sum(config)).tolist() + offset = 0 + + def stage_drop_paths(depth: int) -> List[float]: + nonlocal offset + values = [float(value) for value in drop_paths[offset : offset + depth]] + offset += depth + return values + + self.g_a = nn.Sequential( + ResidualBlockWithStride(3, 2 * N, stride=2), + *_make_mixed_stage( + config[0], + N, + head_dim[0], + self.window_size, + stage_drop_paths(config[0]), + ResidualBlockWithStride(2 * N, 2 * N, stride=2), + ), + *_make_mixed_stage( + config[1], + N, + 
head_dim[1], + self.window_size, + stage_drop_paths(config[1]), + ResidualBlockWithStride(2 * N, 2 * N, stride=2), + ), + *_make_mixed_stage( + config[2], + N, + head_dim[2], + self.window_size, + stage_drop_paths(config[2]), + conv3x3(2 * N, M, stride=2), + ), + ) + self.g_s = nn.Sequential( + ResidualBlockUpsample(M, 2 * N, 2), + *_make_mixed_stage( + config[3], + N, + head_dim[3], + self.window_size, + stage_drop_paths(config[3]), + ResidualBlockUpsample(2 * N, 2 * N, 2), + ), + *_make_mixed_stage( + config[4], + N, + head_dim[4], + self.window_size, + stage_drop_paths(config[4]), + ResidualBlockUpsample(2 * N, 2 * N, 2), + ), + *_make_mixed_stage( + config[5], + N, + head_dim[5], + self.window_size, + stage_drop_paths(config[5]), + subpel_conv3x3(2 * N, 3, 2), + ), + ) + + h_a = nn.Sequential( + ResidualBlockWithStride(M, 2 * N, 2), + *_make_mixed_stage( + config[0], + N, + self.hyper_head_dim, + self.hyper_window_size, + [0.0] * config[0], + conv3x3(2 * N, hyper_channels, stride=2), + ), + ) + h_mean_s = nn.Sequential( + ResidualBlockUpsample(hyper_channels, 2 * N, 2), + *_make_mixed_stage( + config[3], + N, + self.hyper_head_dim, + self.hyper_window_size, + [0.0] * config[3], + subpel_conv3x3(2 * N, M, 2), + ), + ) + h_scale_s = nn.Sequential( + ResidualBlockUpsample(hyper_channels, 2 * N, 2), + *_make_mixed_stage( + config[3], + N, + self.hyper_head_dim, + self.hyper_window_size, + [0.0] * config[3], + subpel_conv3x3(2 * N, M, 2), + ), + ) + + slice_ch = M // num_slices + widths = (224, 128) + groups = [slice_ch] * num_slices + window_size = self.window_size + + def support_count(k: int) -> int: + return k if max_support_slices < 0 else min(k, max_support_slices) + + def mean_support_ch(k: int) -> int: + return M + slice_ch * support_count(k) + + support_slices = [list(range(support_count(k))) for k in range(num_slices)] + + def swatten_factory(c_in: int, c_out: int) -> nn.Module: + # Independent SWAtten per mean / scale path, mirroring upstream + # 
atten_mean[k] / atten_scale[k]. + return SWAtten( + input_dim=c_in, + output_dim=c_out, + head_dim=16, + window_size=window_size, + drop_path=0.0, + inter_dim=128, + ) + + # Side-parameter channel-groups wiring, inlined ELIC-style. Differs from + # WACNN/STF only in widths=(224, 128) and the per-slice SWAtten + # support transforms wrapping mean_in / scale_in. + channel_context = { + f"y{k}": MeanScaleContextHead( + mean_cc=make_entropy_transform( + mean_support_ch(k), slice_ch, widths=widths + ), + scale_cc=make_entropy_transform( + mean_support_ch(k), slice_ch, widths=widths + ), + mean_support_transform=swatten_factory( + mean_support_ch(k), mean_support_ch(k) + ), + scale_support_transform=swatten_factory( + mean_support_ch(k), mean_support_ch(k) + ), + side_split=M, + emit_mean_support=True, + ) + for k in range(num_slices) + } + y_latent_codec = { + f"y{k}": _LRPGaussianLatentCodec( + lrp_transform=make_entropy_transform( + mean_support_ch(k) + slice_ch, slice_ch, widths=widths + ), + mean_support_trail_channels=mean_support_ch(k), + quantizer="ste", + ) + for k in range(num_slices) + } + + self.latent_codec = HyperpriorLatentCodec( + h_a=h_a, + h_s=_DualHyperSynthesis(h_mean_s, h_scale_s), + latent_codec={ + "z": EntropyBottleneckLatentCodec( + entropy_bottleneck=EntropyBottleneck(hyper_channels), + quantizer="ste", + ), + "y": _SideContextChannelGroupsLatentCodec( + groups=groups, + channel_context=channel_context, + latent_codec=y_latent_codec, + support_slices=support_slices, + ), + }, + ) + + @classmethod + def from_state_dict(cls, state_dict: Dict[str, Tensor]) -> "TCM": + N = state_dict["g_a.0.conv1.weight"].size(0) // 2 + M = state_dict["latent_codec.h_a.0.conv1.weight"].size(1) + config = _infer_stage_depths(state_dict) or [2, 2, 2, 2, 2, 2] + head_dim = _infer_head_dims(state_dict, N) or [8, 16, 32, 32, 16, 8] + hyper_channels = state_dict["latent_codec.z.entropy_bottleneck.quantiles"].size( + 0 + ) + num_slices = infer_num_slices(state_dict) or 5 + 
max_support_slices = infer_max_support_slices(state_dict, M, num_slices) + net = cls( + config=config, + head_dim=head_dim, + N=N, + M=M, + hyper_channels=hyper_channels, + num_slices=num_slices, + max_support_slices=max_support_slices, + hyper_head_dim=_infer_hyper_head_dim(state_dict, N, 32), + ) + # ConvTransBlock's WindowAttention registers + # ``relative_position_index`` as a non-persistent buffer, so it is + # absent from saved state dicts. Tolerate the missing keys. + incompatible_keys = net.load_state_dict(state_dict, strict=False) + allowed_missing = { + key for key in net.state_dict() if key.endswith("relative_position_index") + } + missing_keys = set(incompatible_keys.missing_keys) - allowed_missing + if missing_keys or incompatible_keys.unexpected_keys: + raise RuntimeError( + "Unexpected incompatibility while loading TCM state_dict: " + f"missing={sorted(missing_keys)}, " + f"unexpected={sorted(incompatible_keys.unexpected_keys)}" + ) + return net diff --git a/compressai/zoo/__init__.py b/compressai/zoo/__init__.py index acebc705..e3e75863 100644 --- a/compressai/zoo/__init__.py +++ b/compressai/zoo/__init__.py @@ -31,12 +31,14 @@ bmshj2018_factorized, bmshj2018_factorized_relu, bmshj2018_hyperprior, + cca, cheng2020_anchor, cheng2020_attn, mbt2018, mbt2018_mean, stf, stf_wacnn, + tcm, ) from .image_vbr import bmshj2018_hyperprior_vbr, mbt2018_mean_vbr, mbt2018_vbr from .pretrained import load_pretrained as load_state_dict @@ -52,6 +54,8 @@ "cheng2020-attn": cheng2020_attn, "stf": stf, "stf-wacnn": stf_wacnn, + "tcm": tcm, + "cca": cca, "bmshj2018-hyperprior-vbr": bmshj2018_hyperprior_vbr, "mbt2018-mean-vbr": mbt2018_mean_vbr, "mbt2018-vbr": mbt2018_vbr, diff --git a/compressai/zoo/image.py b/compressai/zoo/image.py index f506d6bf..69085976 100644 --- a/compressai/zoo/image.py +++ b/compressai/zoo/image.py @@ -84,6 +84,8 @@ def __getattr__(self, item): "cheng2020_attn", "stf", "stf_wacnn", + "tcm", + "cca", ] model_architectures = { @@ -97,6 +99,8 @@ 
def __getattr__(self, item): # Resolved lazily so `compressai.zoo` is importable without `timm`. "stf": _LazyImport("compressai.models.stf", "SymmetricalTransFormer"), "stf-wacnn": _LazyImport("compressai.models.stf", "WACNN"), + "tcm": _LazyImport("compressai.models.tcm", "TCM"), + "cca": _LazyImport("compressai.models.cca", "CCAModel"), } root_url = "https://compressai.s3.amazonaws.com/models/v1" @@ -525,3 +529,43 @@ def stf_wacnn(pretrained: bool = False, progress: bool = True, **kwargs): from compressai.models.stf import WACNN return WACNN(**kwargs) + + +def tcm(pretrained: bool = False, progress: bool = True, **kwargs): + r"""TCM (Transformer-CNN Mixture) model from J. Liu, H. Sun, J. Katto: + `"Learned Image Compression with Mixed Transformer-CNN Architectures" + `_, IEEE/CVF Conf. on Computer Vision + and Pattern Recognition (CVPR), 2023. + + Args: + pretrained (bool): If True, returns a pre-trained model. Currently + unavailable; raises ``RuntimeError``. + progress (bool): If True, displays a progress bar of the download to + stderr. + """ + del progress + if pretrained: + raise RuntimeError("Pre-trained TCM weights are not yet hosted on S3.") + from compressai.models.tcm import TCM + + return TCM(**kwargs) + + +def cca(pretrained: bool = False, progress: bool = True, **kwargs): + r"""CCA (Causal Context Adjustment) model from M. Han, S. Jiang, S. Li, + X. Deng, M. Xu, C. Zhu, S. Liu: `"Causal Context Adjustment Loss for + Learned Image Compression" `_, NeurIPS + 2024. + + Args: + pretrained (bool): If True, returns a pre-trained model. Currently + unavailable; raises ``RuntimeError``. + progress (bool): If True, displays a progress bar of the download to + stderr. 
+ """ + del progress + if pretrained: + raise RuntimeError("Pre-trained CCA weights are not yet hosted on S3.") + from compressai.models.cca import CCAModel + + return CCAModel(**kwargs) diff --git a/examples/convert_cca_checkpoint.py b/examples/convert_cca_checkpoint.py new file mode 100644 index 00000000..727df227 --- /dev/null +++ b/examples/convert_cca_checkpoint.py @@ -0,0 +1,474 @@ +"""Convert an upstream CCA checkpoint to compressai layout. + +Loads the published candidate weight file (e.g. +``checkpoint_lambda_0.3.pth.tar`` from M. Han et al., +https://github.com/CVL-UESTC/CCA, NeurIPS 2024), translates it to +compressai's module layout, and writes a state dict that +``compressai.models.cca.CCAModel.from_state_dict`` can load directly. +Optionally reports forward-pass sanity numbers (PSNR / bpp) on a +synthetic input. + +The upstream-vs-compressai key differences (NAFBlock interior renames +``dwconv`` / ``sca`` / ``FFN`` / ``conv1``, NAFTransform ``in_conv`` / +``out_conv``, ``mean_NAF_transforms.{k}`` -> +``channel_context.y{k}.mean_support_transform``, ``mean_cc_transforms.{k}`` +-> ``channel_context.y{k}.mean_cc``, ``lrp_transforms.{k}`` -> +``latent_codec.y{k}.lrp_transform``, ``aux_entropymodel.*`` -> +``aux_entropy_model.inner_codec.*``, the ``gaussian_conditional`` +replication across slices, the H+G containerised re-rooting under +``latent_codec.*``, etc.) are all handled inside +``convert_upstream_cca_state_dict``; this script is a thin CLI around it. 
+ +Example:: + + python examples/convert_cca_checkpoint.py \\ + --src candidate/CCA/checkpoint_lambda_0.3.pth.tar \\ + --dst /tmp/cca_compressai.pth \\ + --smoke +""" + +from __future__ import annotations + +import argparse + +from pathlib import Path +from typing import Dict, List, Optional, Sequence + +import torch + +from torch import Tensor + +from compressai.models.cca import CCAModel + +# ---------------------------------------------------------------------------- +# Upstream → compressai state-dict conversion. +# +# Lives here (not in compressai/models/cca.py) so the model module stays a +# clean compressai-native definition — ``CCAModel.from_state_dict`` only +# loads already-converted state dicts. Run this script once to translate a +# published upstream checkpoint into compressai layout, then load the result +# via ``from_state_dict``. +# ---------------------------------------------------------------------------- + + +# NAFBlock interior renames (upstream -> compressai). These are scoped to +# detected NAFBlock prefixes so they don't accidentally rewrite ``conv1`` in +# unrelated modules (e.g. ResidualBottleneckBlock has its own ``conv1``). +_NAF_BLOCK_RENAMES = { + "dwconv.": "pointwise_depthwise.", + "sca.": "channel_attention.", + "FFN.": "feed_forward.", + "conv1.": "project.", +} +# NAFTransform interior renames. +_NAF_TRANSFORM_RENAMES = { + "in_conv.": "input_projection.", + "out_conv.": "output_projection.", +} +# Top-level rename map applied AFTER NAFBlock / NAFTransform interior renames +# and BEFORE per-slice rerooting. Used for hyperprior backbone and aux module. 
+_TOPLEVEL_RENAMES: Dict[str, str] = { + "aux_entropymodel.": "aux_entropy_model.", + "h_a.": "latent_codec.h_a.", + "h_mean_s.": "latent_codec.h_s.h_mean_s.", + "h_scale_s.": "latent_codec.h_s.h_scale_s.", + "z_entropy_bottleneck.": "latent_codec.z.entropy_bottleneck.", +} +# Upstream uses ``mean_NAF_transforms`` / ``scale_NAF_transforms``; compressai +# stores them at ``{mean,scale}_support_transform`` inside the channel-context +# head (singular per slice). Aliasing here keeps the per-slice rerooting pass +# uniform across main and aux branches. +_NAMED_PART_RENAMES: Dict[str, str] = { + "mean_NAF_transforms.": "mean_support_transforms.", + "scale_NAF_transforms.": "scale_support_transforms.", +} + + +def _is_upstream_cca_state_dict(state_dict: Dict[str, Tensor]) -> bool: + """Heuristic detector for upstream ``LICAutoencoder`` checkpoints.""" + for key in state_dict: + if ( + key.startswith("mean_NAF_transforms.") + or key.startswith("scale_NAF_transforms.") + or key.startswith("aux_entropymodel.") + or key.startswith("z_entropy_bottleneck.") + or key.startswith("mean_cc_transforms.") + or key.startswith("scale_cc_transforms.") + or key.startswith("lrp_transforms.") + ): + return True + return False + + +def _find_naf_block_prefixes(state_dict: Dict[str, Tensor]) -> List[str]: + """Locate every NAFBlock instance by matching the ``.beta`` / ``.gamma`` + / ``.dwconv.0.weight`` / ``.FFN.0.weight`` 4-tuple at the same scope. + """ + suffix = ".beta" + out: List[str] = [] + for key in state_dict: + if not key.endswith(suffix): + continue + base = key[: -len(suffix)] + if ( + f"{base}.gamma" in state_dict + and f"{base}.dwconv.0.weight" in state_dict + and f"{base}.FFN.0.weight" in state_dict + ): + out.append(base) + return out + + +def _find_naf_transform_prefixes(state_dict: Dict[str, Tensor]) -> List[str]: + """Locate every NAFTransform instance by matching the ``.in_conv.weight`` + / ``.out_conv.weight`` / ``.blocks.0.beta`` triple at the same scope. 
+ """ + suffix = ".in_conv.weight" + out: List[str] = [] + for key in state_dict: + if not key.endswith(suffix): + continue + base = key[: -len(suffix)] + if ( + f"{base}.out_conv.weight" in state_dict + and f"{base}.blocks.0.beta" in state_dict + ): + out.append(base) + return out + + +def _strip_prefix(key: str, prefix: str) -> Optional[str]: + return key[len(prefix) :] if key.startswith(prefix) else None + + +def _rename_with_table( + key: str, + base_prefixes: Sequence[str], + rename_map: Dict[str, str], +) -> str: + for base in base_prefixes: + head = base + "." + rest = _strip_prefix(key, head) + if rest is None: + continue + for old, new in rename_map.items(): + inner = _strip_prefix(rest, old) + if inner is not None: + return head + new + inner + return key + return key + + +def _reroot_per_slice_keys( + cleaned: Dict[str, Tensor], + converted: Dict[str, Tensor], + *, + legacy_prefix: str, + container_prefix: str, + sub_name: str, + num_slices: int, + consume: List[str], +) -> None: + """Move ``legacy_prefix.{k}.<...>`` keys to + ``container_prefix.y{k}.sub_name.<...>``. + + Keys that match are removed from ``cleaned`` (recorded in ``consume`` + for a later bulk drop) and inserted into ``converted`` under the new + path. 
+ """ + for key in list(cleaned.keys()): + rest = _strip_prefix(key, legacy_prefix + ".") + if rest is None: + continue + idx_str, _, tail = rest.partition(".") + try: + idx = int(idx_str) + except ValueError: + continue + if idx >= num_slices: + continue + new_key = ( + f"{container_prefix}.y{idx}.{sub_name}.{tail}" + if tail + else f"{container_prefix}.y{idx}.{sub_name}" + ) + converted[new_key] = cleaned[key] + consume.append(key) + + +def _replicate_gaussian_conditional( + cleaned: Dict[str, Tensor], + converted: Dict[str, Tensor], + *, + legacy_prefix: str, + new_prefix: str, + num_slices: int, + consume: List[str], +) -> None: + """Copy a single shared ``gaussian_conditional.<...>`` buffer set under + every per-slice leaf so the per-slice + :class:`GaussianConditionalLatentCodec` copies all strict-load. + """ + for key in list(cleaned.keys()): + tail = _strip_prefix(key, legacy_prefix + ".") + if tail is None: + continue + for k in range(num_slices): + new_key = f"{new_prefix}.y{k}.gaussian_conditional.{tail}" + converted[new_key] = cleaned[key] + consume.append(key) + + +def convert_upstream_cca_state_dict( + state_dict: Dict[str, Tensor], +) -> Dict[str, Tensor]: + """Translate an upstream CCA ``LICAutoencoder`` state dict to the + compressai layout produced by :class:`CCAModel`. + + Conversion runs three logical passes: + + 1. Interior renames: ``NAFBlock`` (``dwconv`` → ``pointwise_depthwise``, + etc.) and ``NAFTransform`` (``in_conv`` → ``input_projection``, + etc.). Detection is by structural fingerprint + (:func:`_find_naf_block_prefixes`) so the renames apply uniformly to + NAFBlocks anywhere in the state dict (``g_a`` / ``g_s`` / per-slice + support transforms / aux module). + 2. 
Top-level renames: ``aux_entropymodel`` → ``aux_entropy_model``, + hyperprior backbone (``h_a`` / ``h_mean_s`` / ``h_scale_s``) and + ``z_entropy_bottleneck`` are moved under ``latent_codec.*``; + ``mean_NAF_transforms`` / ``scale_NAF_transforms`` are aliased to + the singular ``{mean,scale}_support_transforms`` form so the + per-slice rerooting in pass 3 only handles one name. + 3. Per-slice rerooting: ``mean_cc_transforms.{k}`` / + ``scale_cc_transforms.{k}`` move to + ``latent_codec.y.channel_context.y{k}.{mean,scale}_cc.*``; + ``mean_support_transforms.{k}`` / ``scale_support_transforms.{k}`` + move to + ``latent_codec.y.channel_context.y{k}.{mean,scale}_support_transform.*``; + ``lrp_transforms.{k}`` moves to + ``latent_codec.y.latent_codec.y{k}.lrp_transform.*``; the single + shared ``gaussian_conditional.*`` buffer set is replicated under + every per-slice leaf + (``latent_codec.y.latent_codec.y{k}.gaussian_conditional.*``). The + same rerooting is applied to ``aux_entropy_model.*`` (after the + top-level rename) under ``aux_entropy_model.inner_codec.*``. + + The returned dict can be loaded by :meth:`CCAModel.from_state_dict`. + """ + naf_blocks = _find_naf_block_prefixes(state_dict) + naf_transforms = _find_naf_transform_prefixes(state_dict) + + # Pass 1+2: interior + top-level renames. + cleaned: Dict[str, Tensor] = {} + for key, value in state_dict.items(): + new_key = _rename_with_table(key, naf_blocks, _NAF_BLOCK_RENAMES) + new_key = _rename_with_table(new_key, naf_transforms, _NAF_TRANSFORM_RENAMES) + for old, new in _NAMED_PART_RENAMES.items(): + new_key = new_key.replace(old, new) + for old, new in _TOPLEVEL_RENAMES.items(): + if new_key.startswith(old): + new_key = new + new_key[len(old) :] + break + cleaned[new_key] = value + + # Pass 3a: per-slice rerooting for the main entropy stack. Discover + # ``num_slices`` from ``mean_cc_transforms`` first, then drive the rest. 
+ main_indices = sorted( + { + int(key[len("mean_cc_transforms.") :].split(".", 1)[0]) + for key in cleaned + if key.startswith("mean_cc_transforms.") + } + ) + num_slices_main = len(main_indices) + + converted: Dict[str, Tensor] = {} + consumed: List[str] = [] + + if num_slices_main: + for legacy, container, sub in ( + ("mean_cc_transforms", "latent_codec.y.channel_context", "mean_cc"), + ("scale_cc_transforms", "latent_codec.y.channel_context", "scale_cc"), + ( + "mean_support_transforms", + "latent_codec.y.channel_context", + "mean_support_transform", + ), + ( + "scale_support_transforms", + "latent_codec.y.channel_context", + "scale_support_transform", + ), + ("lrp_transforms", "latent_codec.y.latent_codec", "lrp_transform"), + ): + _reroot_per_slice_keys( + cleaned, + converted, + legacy_prefix=legacy, + container_prefix=container, + sub_name=sub, + num_slices=num_slices_main, + consume=consumed, + ) + _replicate_gaussian_conditional( + cleaned, + converted, + legacy_prefix="gaussian_conditional", + new_prefix="latent_codec.y.latent_codec", + num_slices=num_slices_main, + consume=consumed, + ) + + # Pass 3b: per-slice rerooting inside the aux entropy module. Discover + # ``num_slices_aux`` from ``aux_entropy_model.mean_cc_transforms``. 
+ aux_indices = sorted( + { + int(key[len("aux_entropy_model.mean_cc_transforms.") :].split(".", 1)[0]) + for key in cleaned + if key.startswith("aux_entropy_model.mean_cc_transforms.") + } + ) + num_slices_aux = len(aux_indices) + if num_slices_aux: + for legacy, container, sub in ( + ( + "aux_entropy_model.mean_cc_transforms", + "aux_entropy_model.inner_codec.channel_context", + "mean_cc", + ), + ( + "aux_entropy_model.scale_cc_transforms", + "aux_entropy_model.inner_codec.channel_context", + "scale_cc", + ), + ( + "aux_entropy_model.mean_support_transforms", + "aux_entropy_model.inner_codec.channel_context", + "mean_support_transform", + ), + ( + "aux_entropy_model.scale_support_transforms", + "aux_entropy_model.inner_codec.channel_context", + "scale_support_transform", + ), + ( + "aux_entropy_model.lrp_transforms", + "aux_entropy_model.inner_codec.latent_codec", + "lrp_transform", + ), + ): + _reroot_per_slice_keys( + cleaned, + converted, + legacy_prefix=legacy, + container_prefix=container, + sub_name=sub, + num_slices=num_slices_aux, + consume=consumed, + ) + _replicate_gaussian_conditional( + cleaned, + converted, + legacy_prefix="aux_entropy_model.gaussian_conditional", + new_prefix="aux_entropy_model.inner_codec.latent_codec", + num_slices=num_slices_aux, + consume=consumed, + ) + + for key in consumed: + cleaned.pop(key, None) + # Remaining keys (g_a / g_s / latent_codec.* hyperprior backbone / + # aux_entropy_model.y_entropy_bottleneck / etc.) pass through unchanged. + converted.update(cleaned) + return converted + + +def parse_args() -> argparse.Namespace: + parser = argparse.ArgumentParser(description=__doc__.splitlines()[0]) + parser.add_argument( + "--src", + type=Path, + required=True, + help="Path to the upstream CCA checkpoint (e.g. checkpoint_lambda_0.3.pth.tar).", + ) + parser.add_argument( + "--dst", + type=Path, + default=None, + help=( + "Optional output path for the converted state dict. 
If omitted, " + "the script only verifies that the checkpoint loads cleanly." + ), + ) + parser.add_argument( + "--smoke", + action="store_true", + help="Run a forward smoke test on a synthetic 256x256 image.", + ) + return parser.parse_args() + + +def main() -> None: + args = parse_args() + if not args.src.exists(): + raise SystemExit(f"checkpoint not found: {args.src}") + + upstream = torch.load(args.src, map_location="cpu", weights_only=False) + upstream = ( + upstream.get("state_dict", upstream) if isinstance(upstream, dict) else upstream + ) + if _is_upstream_cca_state_dict(upstream): + converted = convert_upstream_cca_state_dict(upstream) + else: + converted = upstream + print(f"loaded {len(upstream)} upstream keys → {len(converted)} compressai keys") + + net = CCAModel.from_state_dict(converted) + net.eval() + print( + "variant: " + f"M={net.M}, N={net.N}, slice_sizes={tuple(net.slice_sizes)}, " + f"em_hidden={net.em_hidden_channels}, em_layers={net.em_num_layers}, " + f"cca_training={net.cca_training}" + ) + print(f"parameters: {sum(p.numel() for p in net.parameters()):,}") + + if args.dst is not None: + args.dst.parent.mkdir(parents=True, exist_ok=True) + torch.save(net.state_dict(), args.dst) + print(f"wrote converted state dict → {args.dst}") + + if args.smoke: + height = width = 256 + ys, xs = torch.meshgrid( + torch.linspace(0, 1, height), + torch.linspace(0, 1, width), + indexing="ij", + ) + img = ( + torch.stack( + [ + 0.5 + 0.3 * torch.sin(8 * xs), + 0.5 + 0.3 * torch.sin(8 * ys), + 0.5 + 0.3 * torch.cos(8 * (xs + ys)), + ], + dim=0, + ) + .unsqueeze(0) + .clamp(0, 1) + ) + + with torch.no_grad(): + out = net(img) + n_pix = height * width + psnr = -10 * torch.log10(((out["x_hat"].clamp(0, 1) - img) ** 2).mean()).item() + y_bpp = -torch.log2(out["likelihoods"]["y"]).sum().item() / n_pix + z_bpp = -torch.log2(out["likelihoods"]["z"]).sum().item() / n_pix + print( + f"smoke: PSNR={psnr:.2f}dB y_bpp={y_bpp:.4f} z_bpp={z_bpp:.4f} " + f"total_bpp={y_bpp + 
z_bpp:.4f}" + ) + + +if __name__ == "__main__": + main() diff --git a/examples/convert_stf_checkpoint.py b/examples/convert_stf_checkpoint.py index 914857be..e770dc76 100644 --- a/examples/convert_stf_checkpoint.py +++ b/examples/convert_stf_checkpoint.py @@ -24,20 +24,188 @@ from __future__ import annotations import argparse +import re from pathlib import Path +from typing import Dict import torch -from compressai.models.stf import ( - WACNN, - SymmetricalTransFormer, - convert_upstream_stf_state_dict, -) +from torch import Tensor + +from compressai.models.stf import WACNN, SymmetricalTransFormer _ARCHES = {"stf": SymmetricalTransFormer, "wacnn": WACNN} +# --------------------------------------------------------------------------- +# Upstream STF / WACNN checkpoint conversion +# +# Lives here (not in compressai/models/stf.py) so the model module stays a +# clean compressai-native definition — `WACNN` / `SymmetricalTransFormer` +# `.from_state_dict` only load already-converted state dicts. Run this +# script once to translate a published upstream checkpoint into compressai +# layout, then load the result via `from_state_dict`. +# --------------------------------------------------------------------------- + +_UPSTREAM_LATENT_CODEC_PREFIXES = ( + "cc_mean_transforms", + "cc_scale_transforms", + "lrp_transforms", + "gaussian_conditional", +) + +# Top-level rename map applied AFTER per-slice cc_/lrp_/gaussian_conditional +# rerooting. Keys are matched as exact prefixes (with the trailing dot). +_UPSTREAM_TOP_LEVEL_RENAMES: Dict[str, str] = { + "h_a.": "latent_codec.h_a.", + "h_mean_s.": "latent_codec.h_s.h_mean_s.", + "h_scale_s.": "latent_codec.h_s.h_scale_s.", + "entropy_bottleneck.": "latent_codec.z.entropy_bottleneck.", +} + +# Upstream STF places the WindowAttention parameters directly under +# ``conv_b..attn.{qkv,proj,relative_position_*}``. 
CompressAI wraps the +# WindowAttention inside a :class:`compressai.layers.attn.swin.WMSA` shim, so +# the live model keeps ``WMSA.attn = WindowAttention(...)`` and the +# parameters land at ``conv_b..attn.attn.*``. This regex inserts the extra +# ``.attn`` so renamed upstream keys round-trip into the WMSA wrapper without +# changing the model topology. +_WMSA_NEST_PATTERN = re.compile( + r"(\.conv_b\.\d+\.attn)\.(qkv\.|proj\.|relative_position_)" +) + + +def _nest_winmsa_keys(key: str) -> str: + """Insert the WMSA wrapper level (``.attn``) into upstream + ``conv_b.*.attn.{qkv,proj,relative_position_*}`` keys.""" + return _WMSA_NEST_PATTERN.sub(r"\1.attn.\2", key) + + +def _is_upstream_stf_state_dict(state_dict: Dict[str, Tensor]) -> bool: + """Heuristic: upstream checkpoints either carry a ``module.`` prefix or + place ``cc_mean_transforms`` at the root instead of under ``latent_codec``. + """ + for key in state_dict: + if key.startswith("module."): + return True + head = key.split(".", 1)[0] + if head in _UPSTREAM_LATENT_CODEC_PREFIXES or head in { + "h_a", + "h_mean_s", + "h_scale_s", + "entropy_bottleneck", + }: + return True + return False + + +def convert_upstream_stf_state_dict( + state_dict: Dict[str, Tensor], +) -> Dict[str, Tensor]: + """Translate a candidate ``STF`` / ``WACNN`` state dict into compressai layout. + + Upstream checkpoints (``stf__best.pth.tar`` / ``cnn__best.pth.tar`` + from `Zou et al. 2022 `_) are saved from a + ``DataParallel``-wrapped module and place the channel-conditional entropy + transforms at the model root. After the H+G containerised refactor + compressai houses those transforms (plus the Gaussian conditional and + the hyperprior backbone) inside ``latent_codec.*``. 
This helper: + + - strips the leading ``module.`` prefix added by ``DataParallel``; + - re-roots ``cc_mean_transforms.{k}`` / ``cc_scale_transforms.{k}`` / + ``lrp_transforms.{k}`` under + ``latent_codec.y.channel_context.y{k}.{mean_cc,scale_cc}.*`` / + ``latent_codec.y.latent_codec.y{k}.lrp_transform.*``; + - replicates the single shared ``gaussian_conditional.*`` buffer set + under each per-slice leaf (``latent_codec.y.latent_codec.y{k}.gaussian_conditional.*``); + - moves ``entropy_bottleneck.*`` / ``h_a.*`` / ``h_mean_s.*`` / + ``h_scale_s.*`` under ``latent_codec.*`` per the new layout; + - leaves ``g_a`` / ``g_s`` / ``patch_embed`` / ``layers`` / + ``syn_layers`` / ``end_conv`` keys unchanged. + + The wiring sets ``emit_mean_support=True`` on the ``MeanScaleContextHead`` + so the upstream LRP layout (``cat(latent_means, *prev_y_hat, y_hat)``) is + recoverable inside the leaf — upstream ``lrp_transforms.{k}`` weights + transfer byte-for-byte. The model's ``WinNoShiftAttention`` consumers wrap + their windowed-attention layers in a :class:`WMSA` shim, so the conversion + also nests upstream ``conv_b.{i}.attn.{qkv,proj,relative_position_*}`` keys + under the extra ``.attn`` level (see :func:`_nest_winmsa_keys`). + + The returned dict can be loaded directly by ``WACNN.from_state_dict`` / + ``SymmetricalTransFormer.from_state_dict``. + """ + converted: Dict[str, Tensor] = {} + + _LEGACY_ROOT_HEADS = set(_UPSTREAM_LATENT_CODEC_PREFIXES) | { + "h_a", + "h_mean_s", + "h_scale_s", + "entropy_bottleneck", + } + + # Pass 1: strip module. prefix, fold the upstream single-``attn`` window + # attention path back into compressai's WMSA wrapper layout, and inventory + # which keys exist. 
+ cleaned: Dict[str, Tensor] = {} + has_legacy_root_keys = False + for key, value in state_dict.items(): + new_key = key[len("module.") :] if key.startswith("module.") else key + new_key = _nest_winmsa_keys(new_key) + cleaned[new_key] = value + if new_key.split(".", 1)[0] in _LEGACY_ROOT_HEADS: + has_legacy_root_keys = True + + if not has_legacy_root_keys: + # Already in (or near) the new layout — return cleaned dict as-is. + return cleaned + + # Pass 2: discover slice indices to drive gaussian_conditional replication + # and per-slice rerooting. + slice_indices = sorted( + { + int(key.split(".")[1]) + for key in cleaned + if key.startswith("cc_mean_transforms.") + } + ) + num_slices = len(slice_indices) + + for key, value in cleaned.items(): + head = key.split(".", 1)[0] + if head == "cc_mean_transforms": + _, k, *rest = key.split(".") + new_key = f"latent_codec.y.channel_context.y{k}.mean_cc." + ".".join(rest) + converted[new_key] = value + elif head == "cc_scale_transforms": + _, k, *rest = key.split(".") + new_key = f"latent_codec.y.channel_context.y{k}.scale_cc." + ".".join(rest) + converted[new_key] = value + elif head == "lrp_transforms": + _, k, *rest = key.split(".") + new_key = f"latent_codec.y.latent_codec.y{k}.lrp_transform." + ".".join( + rest + ) + converted[new_key] = value + elif head == "gaussian_conditional": + # Replicate the single shared instance to per-slice leaves. 
+ tail = key[len("gaussian_conditional.") :] + for k in range(num_slices): + new_key = ( + f"latent_codec.y.latent_codec.y{k}" f".gaussian_conditional.{tail}" + ) + converted[new_key] = value + else: + renamed = key + for prefix, replacement in _UPSTREAM_TOP_LEVEL_RENAMES.items(): + if key.startswith(prefix): + renamed = replacement + key[len(prefix) :] + break + converted[renamed] = value + + return converted + + def _detect_arch(state_dict: dict) -> str: keys = state_dict.keys() if any("patch_embed" in k for k in keys): @@ -92,7 +260,7 @@ def main() -> None: arch = args.arch or _detect_arch(upstream) cls = _ARCHES[arch] - net = cls.from_state_dict(upstream) + net = cls.from_state_dict(converted) net.eval() print(f"loaded {arch.upper()}: {sum(p.numel() for p in net.parameters()):,} params") diff --git a/examples/convert_tcm_checkpoint.py b/examples/convert_tcm_checkpoint.py new file mode 100644 index 00000000..d8ad9771 --- /dev/null +++ b/examples/convert_tcm_checkpoint.py @@ -0,0 +1,379 @@ +"""Convert an upstream LIC-TCM checkpoint to compressai layout. + +Loads the published candidate weight file (e.g. ``0.05.pth.tar`` or +``mse_lambda_0.05.pth.tar`` from the LIC_TCM repo, +https://github.com/jmliu206/LIC_TCM), translates it to compressai's module +layout, and writes a state dict that ``compressai.models.tcm.TCM.from_state_dict`` +can load directly. Optionally reports forward-pass sanity numbers +(PSNR / bpp) on a synthetic input. + +The upstream-vs-compressai key differences (``module.`` ``DataParallel`` +prefix, the ``nn.Sequential`` wrapper around each ``SWAtten``, +``atten_mean`` -> ``latent_codec.y.channel_context.y{k}.mean_support_transform``, +ConvTransBlock MSA buffer layouts, layer-norm names, the H+G containerised +re-rooting under ``latent_codec.*``, etc.) are all handled inside +``convert_upstream_tcm_state_dict``; this script is a thin CLI around it. 
+ +Example:: + + python examples/convert_tcm_checkpoint.py \\ + --src candidate/TCM/0.05.pth.tar \\ + --dst /tmp/tcm_compressai.pth \\ + --smoke +""" + +from __future__ import annotations + +import argparse +import re + +from pathlib import Path +from typing import Dict, Tuple + +import torch + +from torch import Tensor + +from compressai.models.tcm import TCM + +# ---------------------------------------------------------------------------- +# Upstream LIC_TCM checkpoint conversion +# ---------------------------------------------------------------------------- + + +# Heads from upstream LIC_TCM (Liu et al. 2023) checkpoints that move under +# ``latent_codec.*`` after the H+G containerised refactor. +_UPSTREAM_LATENT_CODEC_PREFIXES = ( + "cc_mean_transforms", + "cc_scale_transforms", + "lrp_transforms", + "atten_mean", + "atten_scale", + "mean_support_transforms", + "scale_support_transforms", + "gaussian_conditional", +) + +# Top-level rename map applied AFTER per-slice rerooting. Keys are matched as +# exact prefixes (with the trailing dot). +_UPSTREAM_TOP_LEVEL_RENAMES: Dict[str, str] = { + "h_a.": "latent_codec.h_a.", + "h_mean_s.": "latent_codec.h_s.h_mean_s.", + "h_scale_s.": "latent_codec.h_s.h_scale_s.", + "entropy_bottleneck.": "latent_codec.z.entropy_bottleneck.", +} + +# Upstream LIC_TCM wraps each ``SWAtten`` in an ``nn.Sequential`` and stores +# parameters at ``atten_mean.{k}.0.<...>``. Compressai's :class:`SWAtten` +# lives directly at ``mean_support_transform.<...>`` after rerooting, so the +# leading ``.0`` wrapper level is stripped. +_UPSTREAM_SWATTEN_WRAPPER = re.compile( + r"^(atten_mean|atten_scale|mean_support_transforms|scale_support_transforms)\.(\d+)\.0\." +) + + +def _rename_msa_keys(key: str, value: Tensor) -> Tuple[str, Tensor]: + """Translate upstream LIC_TCM ConvTransBlock-internal MSA layout to + compressai's :class:`WMSA` wrapper layout. 
+ + Three kinds of upstream keys appear inside ``g_a`` / ``g_s`` / ``h_a`` / + ``h_mean_s`` / ``h_scale_s`` blocks: + + - ``.msa.relative_position_params`` is a ``(num_heads, 2*win-1, 2*win-1)`` + buffer; compressai's ``WindowAttention`` registers it as a flat + ``(N, num_heads)`` ``relative_position_bias_table``. The value is + permuted and reshaped accordingly. + - ``.msa.embedding_layer`` is upstream's name for the fused ``qkv`` + linear; compressai exposes it as ``.msa.attn.qkv.<...>``. + - ``.msa.linear`` is upstream's optional output projection; compressai + renames it to ``.msa.output_proj`` and keeps ``.msa.attn.proj`` as identity — + see :func:`_ensure_identity_attention_projection` for the identity + injection that keeps strict ``load_state_dict`` round-trips clean. + """ + if ".msa.relative_position_params" in key: + new_key = key.replace( + ".msa.relative_position_params", + ".msa.attn.relative_position_bias_table", + ) + new_value = value.permute(1, 2, 0).reshape(-1, value.size(0)).contiguous() + return new_key, new_value + if ".msa.embedding_layer." in key: + return key.replace(".msa.embedding_layer.", ".msa.attn.qkv."), value + if ".msa.linear." in key: + return key.replace(".msa.linear.", ".msa.output_proj."), value + return key, value + + +def _ensure_identity_attention_projection( + state_dict: Dict[str, Tensor], + output_proj_key: str, + output_proj_value: Tensor, +) -> None: + """Inject an identity ``WindowAttention.proj`` for upstream blocks whose + output projection sits outside the attention module (``.msa.linear`` → + ``.msa.output_proj``). The model has both ``.msa.attn.proj`` (inside + WindowAttention, identity-initialised here) and ``.msa.output_proj`` + (the actual learned projection) so strict ``load_state_dict`` succeeds. 
+ """ + prefix, suffix = output_proj_key.rsplit(".msa.output_proj.", 1) + attn_proj_key = f"{prefix}.msa.attn.proj.{suffix}" + if attn_proj_key in state_dict: + return + if suffix == "weight": + dimension = output_proj_value.size(0) + state_dict[attn_proj_key] = torch.eye( + dimension, + dtype=output_proj_value.dtype, + device=output_proj_value.device, + ) + return + if suffix == "bias": + state_dict[attn_proj_key] = torch.zeros_like(output_proj_value) + + +def _is_upstream_tcm_state_dict(state_dict: Dict[str, Tensor]) -> bool: + """Heuristic: upstream LIC_TCM checkpoints either carry the ``module.`` + prefix from ``DataParallel`` saving, the ``.msa.relative_position_params`` + buffer, or the per-slice entropy heads (``cc_mean_transforms`` / + ``atten_mean`` / ``lrp_transforms`` / ``gaussian_conditional`` / ``h_a`` + / ``h_mean_s`` / ``h_scale_s`` / ``entropy_bottleneck``) at the model + root rather than under ``latent_codec.*``. + """ + legacy_roots = set(_UPSTREAM_LATENT_CODEC_PREFIXES) | { + "h_a", + "h_mean_s", + "h_scale_s", + "entropy_bottleneck", + } + for key in state_dict: + if key.startswith("module."): + return True + if ( + ".msa.relative_position_params" in key + or ".msa.embedding_layer." in key + or ".msa.linear." in key + ): + return True + if key.split(".", 1)[0] in legacy_roots: + return True + return False + + +def convert_upstream_tcm_state_dict( + state_dict: Dict[str, Tensor], +) -> Dict[str, Tensor]: + """Translate an upstream LIC_TCM state dict into compressai layout. + + Upstream checkpoints (e.g. ``0.013.pth.tar`` from + `Liu et al. 2023 `_, + https://github.com/jmliu206/LIC_TCM) place the channel-conditional entropy + transforms and the hyperprior backbone at the model root. After the H+G + containerised refactor compressai houses those transforms (plus the + Gaussian conditional and the ``z`` bottleneck) inside ``latent_codec.*``. 
+ This helper: + + - strips the leading ``module.`` prefix added by ``DataParallel``; + - rewrites ConvTransBlock attention buffers via :func:`_rename_msa_keys` + (``.msa.relative_position_params`` / ``.msa.embedding_layer`` / + ``.msa.linear``) and standard layer-name renames (``ln1`` → ``norm1``, + ``mlp.0`` / ``mlp.2`` → ``mlp.fc1`` / ``mlp.fc2``); + - unwraps the upstream ``nn.Sequential`` wrapper around each ``SWAtten`` + (``atten_mean.{k}.0.<...>`` → ``atten_mean.{k}.<...>``); + - re-roots ``cc_mean_transforms.{k}`` / ``cc_scale_transforms.{k}`` / + ``lrp_transforms.{k}`` under + ``latent_codec.y.channel_context.y{k}.{mean_cc,scale_cc}.*`` / + ``latent_codec.y.latent_codec.y{k}.lrp_transform.*``; + - re-roots ``atten_mean.{k}`` / ``atten_scale.{k}`` (or their + ``mean_support_transforms`` / ``scale_support_transforms`` aliases) + under ``latent_codec.y.channel_context.y{k}.{mean,scale}_support_transform.*``; + - replicates the single shared ``gaussian_conditional.*`` buffer set + under each per-slice leaf + (``latent_codec.y.latent_codec.y{k}.gaussian_conditional.*``); + - moves ``entropy_bottleneck.*`` / ``h_a.*`` / ``h_mean_s.*`` / + ``h_scale_s.*`` under ``latent_codec.*`` per the new layout; + - leaves ``g_a`` / ``g_s`` keys (other than the MSA renames inside their + ConvTransBlocks) untouched. + + The wiring sets ``emit_mean_support=True`` on the + :class:`MeanScaleContextHead`, so the upstream LRP layout + (``cat(latent_means, *prev_y_hat, y_hat)``) is recoverable inside the + leaf — upstream ``lrp_transforms.{k}`` weights therefore transfer + byte-for-byte. + + The returned dict can be loaded directly by :meth:`TCM.from_state_dict`; + run this conversion before loading an upstream checkpoint, since the + helper lives in this script rather than in ``compressai.models.tcm``. 
+ """ + # Pass 1: strip ``module.`` prefix; rewrite ConvTransBlock attention + # buffers and layer names; unwrap the SWAtten ``nn.Sequential`` wrapper; + # alias ``atten_mean`` / ``atten_scale`` to the canonical + # ``mean_support_transforms`` / ``scale_support_transforms`` names so the + # per-slice rerooting in Pass 2 only has to handle one form. + cleaned: Dict[str, Tensor] = {} + for key, value in state_dict.items(): + new_key = key[len("module.") :] if key.startswith("module.") else key + new_key, value = _rename_msa_keys(new_key, value) + wrapper = _UPSTREAM_SWATTEN_WRAPPER.match(new_key) + if wrapper: + new_key = ( + f"{wrapper.group(1)}.{wrapper.group(2)}." + new_key[wrapper.end() :] + ) + if new_key.startswith("atten_mean."): + new_key = "mean_support_transforms." + new_key[len("atten_mean.") :] + elif new_key.startswith("atten_scale."): + new_key = "scale_support_transforms." + new_key[len("atten_scale.") :] + new_key = new_key.replace(".ln1.", ".norm1.") + new_key = new_key.replace(".ln2.", ".norm2.") + new_key = new_key.replace(".mlp.0.", ".mlp.fc1.") + new_key = new_key.replace(".mlp.2.", ".mlp.fc2.") + if ".msa.output_proj." in new_key: + _ensure_identity_attention_projection(cleaned, new_key, value) + cleaned[new_key] = value + + # Pass 2: discover slice indices to drive ``gaussian_conditional`` + # replication, then reroot per-slice and top-level keys. + converted: Dict[str, Tensor] = {} + slice_indices = sorted( + { + int(key.split(".")[1]) + for key in cleaned + if key.startswith("cc_mean_transforms.") + } + ) + num_slices = len(slice_indices) + + for key, value in cleaned.items(): + head = key.split(".", 1)[0] + if head == "cc_mean_transforms": + _, k, *rest = key.split(".") + new_key = f"latent_codec.y.channel_context.y{k}.mean_cc." + ".".join(rest) + converted[new_key] = value + elif head == "cc_scale_transforms": + _, k, *rest = key.split(".") + new_key = f"latent_codec.y.channel_context.y{k}.scale_cc." 
+ ".".join(rest) + converted[new_key] = value + elif head == "mean_support_transforms": + _, k, *rest = key.split(".") + new_key = ( + f"latent_codec.y.channel_context.y{k}.mean_support_transform." + + ".".join(rest) + ) + converted[new_key] = value + elif head == "scale_support_transforms": + _, k, *rest = key.split(".") + new_key = ( + f"latent_codec.y.channel_context.y{k}.scale_support_transform." + + ".".join(rest) + ) + converted[new_key] = value + elif head == "lrp_transforms": + _, k, *rest = key.split(".") + new_key = f"latent_codec.y.latent_codec.y{k}.lrp_transform." + ".".join( + rest + ) + converted[new_key] = value + elif head == "gaussian_conditional": + tail = key[len("gaussian_conditional.") :] + for k in range(num_slices): + new_key = ( + f"latent_codec.y.latent_codec.y{k}.gaussian_conditional.{tail}" + ) + converted[new_key] = value + else: + renamed = key + for prefix, replacement in _UPSTREAM_TOP_LEVEL_RENAMES.items(): + if key.startswith(prefix): + renamed = replacement + key[len(prefix) :] + break + converted[renamed] = value + + return converted + + +def parse_args() -> argparse.Namespace: + parser = argparse.ArgumentParser(description=__doc__.splitlines()[0]) + parser.add_argument( + "--src", + type=Path, + required=True, + help="Path to the upstream LIC-TCM checkpoint (e.g. 0.05.pth.tar).", + ) + parser.add_argument( + "--dst", + type=Path, + default=None, + help=( + "Optional output path for the converted state dict. If omitted, " + "the script only verifies that the checkpoint loads cleanly." 
+ ), + ) + parser.add_argument( + "--smoke", + action="store_true", + help="Run a forward smoke test on a synthetic 256x256 image.", + ) + return parser.parse_args() + + +def main() -> None: + args = parse_args() + if not args.src.exists(): + raise SystemExit(f"checkpoint not found: {args.src}") + + upstream = torch.load(args.src, map_location="cpu", weights_only=False) + upstream = ( + upstream.get("state_dict", upstream) if isinstance(upstream, dict) else upstream + ) + converted = convert_upstream_tcm_state_dict(upstream) + print(f"loaded {len(upstream)} upstream keys → {len(converted)} compressai keys") + + net = TCM.from_state_dict(converted) + net.eval() + print( + "variant: " + f"N={net.N}, M={net.M}, num_slices={net.num_slices}, " + f"config={tuple(net.config)}, head_dim={tuple(net.head_dim)}, " + f"hyper_channels={net.hyper_channels}, " + f"max_support_slices={net.max_support_slices}" + ) + print(f"parameters: {sum(p.numel() for p in net.parameters()):,}") + + if args.dst is not None: + args.dst.parent.mkdir(parents=True, exist_ok=True) + torch.save(net.state_dict(), args.dst) + print(f"wrote converted state dict → {args.dst}") + + if args.smoke: + height = width = 256 + ys, xs = torch.meshgrid( + torch.linspace(0, 1, height), + torch.linspace(0, 1, width), + indexing="ij", + ) + img = ( + torch.stack( + [ + 0.5 + 0.3 * torch.sin(8 * xs), + 0.5 + 0.3 * torch.sin(8 * ys), + 0.5 + 0.3 * torch.cos(8 * (xs + ys)), + ], + dim=0, + ) + .unsqueeze(0) + .clamp(0, 1) + ) + + with torch.no_grad(): + out = net(img) + n_pix = height * width + psnr = -10 * torch.log10(((out["x_hat"].clamp(0, 1) - img) ** 2).mean()).item() + y_bpp = -torch.log2(out["likelihoods"]["y"]).sum().item() / n_pix + z_bpp = -torch.log2(out["likelihoods"]["z"]).sum().item() / n_pix + print( + f"smoke: PSNR={psnr:.2f}dB y_bpp={y_bpp:.4f} z_bpp={z_bpp:.4f} " + f"total_bpp={y_bpp + z_bpp:.4f}" + ) + + +if __name__ == "__main__": + main() diff --git a/tests/test_latent_codecs.py 
b/tests/test_latent_codecs.py new file mode 100644 index 00000000..811be35f --- /dev/null +++ b/tests/test_latent_codecs.py @@ -0,0 +1,153 @@ +# Copyright (c) 2021-2025, InterDigital Communications, Inc +# All rights reserved. + +# Redistribution and use in source and binary forms, with or without +# modification, are permitted (subject to the limitations in the disclaimer +# below) provided that the following conditions are met: + +# * Redistributions of source code must retain the above copyright notice, +# this list of conditions and the following disclaimer. +# * Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# * Neither the name of InterDigital Communications, Inc nor the names of its +# contributors may be used to endorse or promote products derived from this +# software without specific prior written permission. + +# NO EXPRESS OR IMPLIED LICENSES TO ANY PARTY'S PATENT RIGHTS ARE GRANTED BY +# THIS LICENSE. THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND +# CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT +# NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A +# PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR +# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, +# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, +# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; +# OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, +# WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR +# OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF +# ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ +import pytest +import torch +import torch.nn as nn + +from compressai.latent_codecs import ( + ChannelGroupsLatentCodec, + GaussianConditionalLatentCodec, +) + + +class TestChannelGroupsLatentCodecExtensions: + def _make_codec( + self, + groups=(4, 4, 4), + side_ch=8, + support_slices=None, + ): + K = len(groups) + effective_support = ( + [list(range(k)) for k in range(K)] + if support_slices is None + else support_slices + ) + channel_context = {f"y{k}": nn.Identity() for k in range(1, K)} + + def _ctx_in(k): + if k == 0: + return side_ch + return side_ch + sum(groups[j] for j in effective_support[k]) + + # Each leaf needs an entropy_parameters MLP sized to its own ctx input. + latent_codec = { + f"y{k}": GaussianConditionalLatentCodec( + entropy_parameters=nn.Conv2d(_ctx_in(k), 2 * groups[k], 1), + ) + for k in range(K) + } + return ChannelGroupsLatentCodec( + latent_codec=latent_codec, + channel_context=channel_context, + groups=list(groups), + support_slices=support_slices, + ) + + def test_default_support_slices_uses_all_prior(self): + codec = self._make_codec() + assert codec.support_slices == [(), (0,), (0, 1)] + + def test_explicit_support_slices_are_preserved(self): + support_slices = [[], [0], [0], [0, 2]] + codec = self._make_codec( + groups=(4, 4, 4, 4), + support_slices=support_slices, + ) + assert codec.support_slices == [(), (0,), (0,), (0, 2)] + + def test_support_slices_reject_current_or_future_groups(self): + with pytest.raises(AssertionError): + self._make_codec(groups=(4, 4), support_slices=[[], [1]]) + + def test_default_forward_matches_pre_extension_behaviour(self): + # With defaults the new constructor should be drop-in for ELIC-style use. 
+ torch.manual_seed(7) + codec = self._make_codec() + groups = codec.groups + M = sum(groups) + y = torch.randn(2, M, 8, 8) + side_params = torch.randn(2, 8, 8, 8) + out = codec(y, side_params) + assert out["y_hat"].shape == (2, M, 8, 8) + assert out["likelihoods"]["y"].shape == (2, M, 8, 8) + + def test_explicit_support_slices_runs_forward(self): + torch.manual_seed(3) + codec = self._make_codec( + groups=(4, 4, 4, 4), + support_slices=[[], [0], [0], [0, 2]], + ) + y = torch.randn(2, 16, 8, 8) + side_params = torch.randn(2, 8, 8, 8) + out = codec(y, side_params) + assert out["y_hat"].shape == (2, 16, 8, 8) + + +class TestChannelGroupsDecompressShape: + """Coverage for the unified channel-first latent-codec shape convention.""" + + class _LeafMock(nn.Module): + def __init__(self, slice_ch): + super().__init__() + self.slice_ch = slice_ch + + def compress(self, y, ctx_params): + n = y.shape[0] + return { + "strings": [[b"" for _ in range(n)]], + "shape": tuple(y.shape[1:]), + "y_hat": torch.zeros_like(y), + } + + def decompress(self, strings, shape, ctx_params, **kwargs): + n = len(strings[0]) + c, h, w = shape + assert c == self.slice_ch + return {"y_hat": torch.zeros((n, c, h, w))} + + def _make_codec(self, groups=(4, 4, 4)): + K = len(groups) + return ChannelGroupsLatentCodec( + latent_codec={f"y{k}": self._LeafMock(groups[k]) for k in range(K)}, + channel_context={f"y{k}": nn.Identity() for k in range(1, K)}, + groups=list(groups), + ) + + def test_decompress_passes_per_group_channel_shape(self): + groups = [4, 4, 4] + codec = self._make_codec(groups=groups) + h, w = 6, 5 + y = torch.randn(1, sum(groups), h, w) + side_params = torch.zeros(1, 8, h, w) + out_enc = codec.compress(y, side_params) + assert out_enc["shape"] == y.shape[1:] + out_dec = codec.decompress(out_enc["strings"], out_enc["shape"], side_params) + assert out_dec["y_hat"].shape == (1, sum(groups), h, w) diff --git a/tests/test_models.py b/tests/test_models.py index c23b1865..ff9ef958 100644 --- 
a/tests/test_models.py +++ b/tests/test_models.py @@ -27,6 +27,10 @@ # OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF # ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +import importlib.util + +from pathlib import Path + import pytest import torch import torch.nn as nn @@ -51,6 +55,24 @@ from compressai.models.vbr import ScaleHyperpriorVbr from compressai.models.video.google import ScaleSpaceFlow +_EXAMPLES_DIR = Path(__file__).resolve().parent.parent / "examples" + + +def _load_convert_fn(script_name: str, fn_name: str): + """Load a ``convert_upstream_*_state_dict`` function from an + ``examples/convert_*_checkpoint.py`` script. + + The upstream-checkpoint conversion helpers live in the example CLI + scripts (not in ``compressai.models.*``) so the model modules stay + clean compressai-native definitions. ``examples/`` is not an importable + package, so we load the module by file path. + """ + path = _EXAMPLES_DIR / script_name + spec = importlib.util.spec_from_file_location(path.stem, path) + module = importlib.util.module_from_spec(spec) + spec.loader.exec_module(module) + return getattr(module, fn_name) + class DummyCompressionModel(CompressionModel): def __init__(self, entropy_bottleneck_channels): @@ -290,6 +312,27 @@ def test_wacnn_forward_and_state_dict_round_trip(self): assert "y" in out["likelihoods"] assert "z" in out["likelihoods"] + # Containerised state-dict layout self-check. + sd_keys = set(model.state_dict().keys()) + assert "latent_codec.h_a.0.weight" in sd_keys + assert "latent_codec.h_s.h_mean_s.0.weight" in sd_keys + assert "latent_codec.h_s.h_scale_s.0.weight" in sd_keys + assert "latent_codec.z.entropy_bottleneck.quantiles" in sd_keys + # Side-parameter channel-context covers y0..y(K-1). 
+ assert "latent_codec.y.channel_context.y0.mean_cc.0.weight" in sd_keys + assert "latent_codec.y.channel_context.y1.mean_cc.0.weight" in sd_keys + assert "latent_codec.y.channel_context.y0.scale_cc.0.weight" in sd_keys + # Per-slice leaves (LRP + per-slice GaussianConditional copy). + assert "latent_codec.y.latent_codec.y0.lrp_transform.0.weight" in sd_keys + assert ( + "latent_codec.y.latent_codec.y0.gaussian_conditional.scale_table" in sd_keys + ) + # Old monolithic paths should be gone. + assert not any( + k.startswith("latent_codec.cc_mean_transforms.") for k in sd_keys + ) + assert "h_a.0.weight" not in sd_keys # moved under latent_codec. + loaded = WACNN.from_state_dict(model.state_dict()).eval() with torch.no_grad(): out_loaded = loaded(x) @@ -312,27 +355,469 @@ def test_symmetrical_transformer_forward_and_state_dict_round_trip(self): assert "y" in out["likelihoods"] assert "z" in out["likelihoods"] + sd_keys = set(model.state_dict().keys()) + assert "latent_codec.h_a.0.weight" in sd_keys + assert "latent_codec.h_s.h_mean_s.0.weight" in sd_keys + assert "latent_codec.z.entropy_bottleneck.quantiles" in sd_keys + assert "latent_codec.y.channel_context.y0.mean_cc.0.weight" in sd_keys + assert "latent_codec.y.latent_codec.y0.lrp_transform.0.weight" in sd_keys + loaded = SymmetricalTransFormer.from_state_dict(model.state_dict()).eval() with torch.no_grad(): out_loaded = loaded(x) assert torch.allclose(out["x_hat"], out_loaded["x_hat"]) def test_stf_upstream_state_dict_conversion(self): - from compressai.models.stf import ( - convert_upstream_stf_state_dict, + convert_upstream_stf_state_dict = _load_convert_fn( + "convert_stf_checkpoint.py", "convert_upstream_stf_state_dict" ) upstream = { "module.g_a.0.weight": torch.zeros(2), "module.cc_mean_transforms.0.0.weight": torch.zeros(2), + "module.cc_mean_transforms.1.0.weight": torch.zeros(2), + "module.cc_scale_transforms.0.0.weight": torch.zeros(2), + "module.lrp_transforms.0.0.weight": torch.zeros(2), 
"module.gaussian_conditional.scale_table": torch.zeros(2), "module.h_a.0.weight": torch.zeros(2), + "module.h_mean_s.0.weight": torch.zeros(2), + "module.h_scale_s.0.weight": torch.zeros(2), + "module.entropy_bottleneck.quantiles": torch.zeros(2), } converted = convert_upstream_stf_state_dict(upstream) + # g_a passes through unchanged. assert "g_a.0.weight" in converted - assert "latent_codec.cc_mean_transforms.0.0.weight" in converted - assert "latent_codec.gaussian_conditional.scale_table" in converted - assert "h_a.0.weight" in converted + # Hyperprior backbone moves under latent_codec. + assert "latent_codec.h_a.0.weight" in converted + assert "latent_codec.h_s.h_mean_s.0.weight" in converted + assert "latent_codec.h_s.h_scale_s.0.weight" in converted + assert "latent_codec.z.entropy_bottleneck.quantiles" in converted + # cc_mean / cc_scale re-rooted per slice. + assert "latent_codec.y.channel_context.y0.mean_cc.0.weight" in converted + assert "latent_codec.y.channel_context.y1.mean_cc.0.weight" in converted + assert "latent_codec.y.channel_context.y0.scale_cc.0.weight" in converted + # gaussian_conditional replicated to every slice (driven by mean_cc count). + assert ( + "latent_codec.y.latent_codec.y0.gaussian_conditional.scale_table" + in converted + ) + assert ( + "latent_codec.y.latent_codec.y1.gaussian_conditional.scale_table" + in converted + ) + # LRP weights are now retained: emit_mean_support=True on the head + # makes the leaf consume cat(latent_means, *prev_y_hat) as the LRP + # input, matching upstream's M + slice_ch*(support+1) input width. + assert "latent_codec.y.latent_codec.y0.lrp_transform.0.weight" in converted + # Old root-level paths should be gone after conversion. 
+ assert "h_a.0.weight" not in converted + assert "cc_mean_transforms.0.0.weight" not in converted + assert "lrp_transforms.0.0.weight" not in converted + + +class TestTcm: + def test_tcm_forward_and_state_dict_round_trip(self): + from compressai.models.tcm import TCM + + model = TCM( + N=32, + M=64, + hyper_channels=48, + num_slices=4, + max_support_slices=2, + ).eval() + x = torch.rand(1, 3, 64, 64) + with torch.no_grad(): + out = model(x) + assert out["x_hat"].shape == x.shape + assert "y" in out["likelihoods"] + assert "z" in out["likelihoods"] + + # Containerised state-dict layout self-check. + sd_keys = set(model.state_dict().keys()) + # Hyperprior backbone moved under latent_codec.* (TCM's h_a / h_*_s + # use ResidualBlockWithStride / ResidualBlockUpsample, so the first + # learnable weight is conv1 / conv). + assert "latent_codec.h_a.0.conv1.weight" in sd_keys + assert "latent_codec.h_s.h_mean_s.0.conv.weight" in sd_keys + assert "latent_codec.h_s.h_scale_s.0.conv.weight" in sd_keys + assert "latent_codec.z.entropy_bottleneck.quantiles" in sd_keys + # Side-parameter channel-context covers y0..y(K-1). + assert "latent_codec.y.channel_context.y0.mean_cc.0.weight" in sd_keys + assert "latent_codec.y.channel_context.y1.mean_cc.0.weight" in sd_keys + assert "latent_codec.y.channel_context.y0.scale_cc.0.weight" in sd_keys + # SWAtten support transforms (TCM-specific; absent on STF/WACNN). + assert ( + "latent_codec.y.channel_context.y0.mean_support_transform.in_conv.weight" + in sd_keys + ) + assert ( + "latent_codec.y.channel_context.y0.scale_support_transform.in_conv.weight" + in sd_keys + ) + # Per-slice leaves (LRP + per-slice GaussianConditional copy). + assert "latent_codec.y.latent_codec.y0.lrp_transform.0.weight" in sd_keys + assert ( + "latent_codec.y.latent_codec.y0.gaussian_conditional.scale_table" in sd_keys + ) + # Old monolithic / pr-stf-wacnn paths should be gone. 
+ assert not any( + k.startswith("latent_codec.cc_mean_transforms.") for k in sd_keys + ) + assert not any(k.startswith("latent_codec.atten_mean.") for k in sd_keys) + assert "h_a.0.conv1.weight" not in sd_keys # moved under latent_codec. + + loaded = TCM.from_state_dict(model.state_dict()).eval() + with torch.no_grad(): + out_loaded = loaded(x) + assert torch.allclose(out["x_hat"], out_loaded["x_hat"]) + assert loaded.N == 32 + assert loaded.M == 64 + assert loaded.hyper_channels == 48 + assert loaded.num_slices == 4 + assert loaded.max_support_slices == 2 + + def test_tcm_upstream_state_dict_conversion(self): + convert_upstream_tcm_state_dict = _load_convert_fn( + "convert_tcm_checkpoint.py", "convert_upstream_tcm_state_dict" + ) + + # Synthetic upstream LIC_TCM-style state_dict: DataParallel ``module.`` + # prefix, raw entropy heads at the root, the SWAtten ``nn.Sequential`` + # wrapper level (``atten_mean.{k}.0.``), and a ConvTransBlock attention + # buffer in upstream layout (``.msa.relative_position_params``). 
+ upstream = { + "module.g_a.0.conv1.weight": torch.zeros(2), + "module.g_a.1.trans_block.msa.relative_position_params": torch.zeros( + 4, 15, 15 + ), + "module.g_a.1.trans_block.msa.embedding_layer.weight": torch.zeros(2), + "module.g_a.1.trans_block.ln1.weight": torch.zeros(2), + "module.g_a.1.trans_block.mlp.0.weight": torch.zeros(2), + "module.g_a.1.trans_block.mlp.2.weight": torch.zeros(2), + "module.cc_mean_transforms.0.0.weight": torch.zeros(2), + "module.cc_mean_transforms.1.0.weight": torch.zeros(2), + "module.cc_scale_transforms.0.0.weight": torch.zeros(2), + "module.atten_mean.0.0.in_conv.weight": torch.zeros(2), + "module.atten_scale.0.0.in_conv.weight": torch.zeros(2), + "module.lrp_transforms.0.0.weight": torch.zeros(2), + "module.gaussian_conditional.scale_table": torch.zeros(2), + "module.h_a.0.conv1.weight": torch.zeros(2), + "module.h_mean_s.0.conv.weight": torch.zeros(2), + "module.h_scale_s.0.conv.weight": torch.zeros(2), + "module.entropy_bottleneck.quantiles": torch.zeros(2), + } + converted = convert_upstream_tcm_state_dict(upstream) + + # ``module.`` prefix gone; g_a / ConvTransBlock pass through with the + # MSA / layer-name renames applied. + assert "g_a.0.conv1.weight" in converted + # ``relative_position_params`` -> ``relative_position_bias_table`` + # with shape permuted from (2*win-1, 2*win-1, num_heads) = + # (15, 15, 4) into the flat (225, 4) layout. + assert "g_a.1.trans_block.msa.attn.relative_position_bias_table" in converted + assert converted[ + "g_a.1.trans_block.msa.attn.relative_position_bias_table" + ].shape == (15 * 15, 4) + # ``embedding_layer`` -> ``attn.qkv``. + assert "g_a.1.trans_block.msa.attn.qkv.weight" in converted + # ``ln1`` -> ``norm1``; ``mlp.0`` / ``mlp.2`` -> ``mlp.fc1`` / ``fc2``. + assert "g_a.1.trans_block.norm1.weight" in converted + assert "g_a.1.trans_block.mlp.fc1.weight" in converted + assert "g_a.1.trans_block.mlp.fc2.weight" in converted + + # Hyperprior backbone moves under latent_codec. 
+ assert "latent_codec.h_a.0.conv1.weight" in converted + assert "latent_codec.h_s.h_mean_s.0.conv.weight" in converted + assert "latent_codec.h_s.h_scale_s.0.conv.weight" in converted + assert "latent_codec.z.entropy_bottleneck.quantiles" in converted + + # cc_mean / cc_scale re-rooted per slice. + assert "latent_codec.y.channel_context.y0.mean_cc.0.weight" in converted + assert "latent_codec.y.channel_context.y1.mean_cc.0.weight" in converted + assert "latent_codec.y.channel_context.y0.scale_cc.0.weight" in converted + + # SWAtten wrapper unwrapped: ``atten_mean.0.0.<...>`` -> + # ``...mean_support_transform.<...>`` (no extra ``.0`` level). + assert ( + "latent_codec.y.channel_context.y0.mean_support_transform.in_conv.weight" + in converted + ) + assert ( + "latent_codec.y.channel_context.y0.scale_support_transform.in_conv.weight" + in converted + ) + + # gaussian_conditional replicated per slice (driven by mean_cc count). + assert ( + "latent_codec.y.latent_codec.y0.gaussian_conditional.scale_table" + in converted + ) + assert ( + "latent_codec.y.latent_codec.y1.gaussian_conditional.scale_table" + in converted + ) + + # LRP weights retained byte-for-byte (emit_mean_support=True path). + assert "latent_codec.y.latent_codec.y0.lrp_transform.0.weight" in converted + + # Old root-level paths should be gone after conversion. + assert "h_a.0.conv1.weight" not in converted + assert "cc_mean_transforms.0.0.weight" not in converted + assert "atten_mean.0.0.in_conv.weight" not in converted + assert "lrp_transforms.0.0.weight" not in converted + assert "module.g_a.0.conv1.weight" not in converted + + +class TestCca: + def test_cca_forward_and_state_dict_round_trip(self): + from compressai.models.cca import CCAModel + + # Tiny variant — variable-length slices, smaller dims keep the + # NAFTransform stack cheap. Slice proportions reproduce the + # upstream layout (8/28/56/92/136 over M=320) at scale. 
+ model = CCAModel( + latent_channels=64, + hyper_channels=48, + slice_proportions=(2, 6, 12, 18, 26), + encoder_dims=(48, 56, 64), + encoder_layers=(1, 1, 1), + em_hidden_channels=56, + em_num_layers=1, + ).eval() + x = torch.rand(1, 3, 128, 128) + with torch.no_grad(): + out = model(x) + assert out["x_hat"].shape == x.shape + assert "y" in out["likelihoods"] + assert "z" in out["likelihoods"] + # cca_training=False -> no aux likelihoods exposed. + assert out["aux_likelihoods"] is None + + # Containerised state-dict layout self-check. + sd_keys = set(model.state_dict().keys()) + # Hyperprior backbone moved under latent_codec.* (CCA's h_a / + # h_*_s use plain Sequentials of conv / GELU; first weight is at + # `.0.weight` rather than `.0.conv1.weight`). + assert "latent_codec.h_a.0.weight" in sd_keys + assert "latent_codec.h_s.h_mean_s.0.weight" in sd_keys + assert "latent_codec.h_s.h_scale_s.0.weight" in sd_keys + # STF/WACNN/TCM/CCA use STE on z; the + # entropy_bottleneck still owns the parametric prior. + assert "latent_codec.z.entropy_bottleneck.quantiles" in sd_keys + # Side-parameter channel-context covers y0..y(K-1). + assert "latent_codec.y.channel_context.y0.mean_cc.0.weight" in sd_keys + assert "latent_codec.y.channel_context.y4.mean_cc.0.weight" in sd_keys + assert "latent_codec.y.channel_context.y0.scale_cc.0.weight" in sd_keys + # NAFTransform support transforms (CCA-specific; absent on STF/WACNN). + assert ( + "latent_codec.y.channel_context.y0.mean_support_transform.input_projection.weight" + in sd_keys + ) + assert ( + "latent_codec.y.channel_context.y0.scale_support_transform.input_projection.weight" + in sd_keys + ) + # Per-slice leaves (LRP + per-slice GaussianConditional copy). + assert "latent_codec.y.latent_codec.y0.lrp_transform.0.weight" in sd_keys + assert ( + "latent_codec.y.latent_codec.y0.gaussian_conditional.scale_table" in sd_keys + ) + # Old monolithic / pre-refactor paths should be gone. 
+ assert "h_a.0.weight" not in sd_keys + assert "z_entropy_bottleneck.quantiles" not in sd_keys + assert not any(k.startswith("mean_cc_transforms.") for k in sd_keys) + assert not any(k.startswith("aux_entropy_model.") for k in sd_keys) + + loaded = CCAModel.from_state_dict(model.state_dict()).eval() + with torch.no_grad(): + out_loaded = loaded(x) + assert torch.allclose(out["x_hat"], out_loaded["x_hat"]) + assert torch.allclose(out["likelihoods"]["y"], out_loaded["likelihoods"]["y"]) + assert torch.allclose(out["likelihoods"]["z"], out_loaded["likelihoods"]["z"]) + assert loaded.M == 64 + assert loaded.N == 48 + assert tuple(loaded.slice_sizes) == (2, 6, 12, 18, 26) + assert loaded.em_hidden_channels == 56 + assert loaded.em_num_layers == 1 + assert loaded.cca_training is False + + def test_cca_training_branch_forward_and_round_trip(self): + from compressai.models.cca import CCAModel + + model = CCAModel( + latent_channels=64, + hyper_channels=48, + slice_proportions=(2, 6, 12, 18, 26), + encoder_dims=(48, 56, 64), + encoder_layers=(1, 1, 1), + em_hidden_channels=56, + em_num_layers=1, + cca_training=True, + ).eval() + x = torch.rand(1, 3, 128, 128) + with torch.no_grad(): + out = model(x) + # Aux branch populates y_aux (factorised) and y_cca (Gaussian). + assert isinstance(out["aux_likelihoods"], dict) + assert set(out["aux_likelihoods"].keys()) == {"y_aux", "y_cca"} + assert out["aux_likelihoods"]["y_aux"].shape == out["likelihoods"]["y"].shape + assert out["aux_likelihoods"]["y_cca"].shape == out["likelihoods"]["y"].shape + + # Aux state-dict paths (skip-most-recent inner ChannelGroupsLatentCodec). 
+ sd_keys = set(model.state_dict().keys()) + assert "aux_entropy_model.y_entropy_bottleneck.quantiles" in sd_keys + assert ( + "aux_entropy_model.inner_codec.channel_context.y0.mean_cc.0.weight" + in sd_keys + ) + assert ( + "aux_entropy_model.inner_codec.channel_context.y0.mean_support_transform.input_projection.weight" + in sd_keys + ) + assert ( + "aux_entropy_model.inner_codec.latent_codec.y0.lrp_transform.0.weight" + in sd_keys + ) + + loaded = CCAModel.from_state_dict(model.state_dict()).eval() + with torch.no_grad(): + out_loaded = loaded(x) + assert loaded.cca_training is True + assert torch.allclose( + out["aux_likelihoods"]["y_aux"], out_loaded["aux_likelihoods"]["y_aux"] + ) + assert torch.allclose( + out["aux_likelihoods"]["y_cca"], out_loaded["aux_likelihoods"]["y_cca"] + ) + + def test_cca_upstream_state_dict_conversion(self): + convert_upstream_cca_state_dict = _load_convert_fn( + "convert_cca_checkpoint.py", "convert_upstream_cca_state_dict" + ) + _is_upstream_cca_state_dict = _load_convert_fn( + "convert_cca_checkpoint.py", "_is_upstream_cca_state_dict" + ) + + # Synthetic upstream LICAutoencoder-style state_dict with one slice + # per branch covering the full path: NAFBlock interior renames, + # NAFTransform interior renames, named-part NAF -> support_transforms + # alias, top-level hyperprior + aux module rerooting, per-slice + # rerooting under channel_context / latent_codec, and the + # gaussian_conditional replication. + upstream = { + # ResidualBottleneckBlock inside g_a (conv1 should NOT be renamed + # since it's not inside a NAFBlock — checked via the NAFBlock + # detector which requires the .beta/.gamma/.dwconv.0 triple). + "g_a.blocks.0.0.conv1.weight": torch.zeros(2), + # NAFBlock inside g_a (full triple present -> dwconv/sca/FFN/conv1 + # interior renames apply to this scope only). 
+ "g_a.blocks.0.3.beta": torch.zeros(2), + "g_a.blocks.0.3.gamma": torch.zeros(2), + "g_a.blocks.0.3.dwconv.0.weight": torch.zeros(2), + "g_a.blocks.0.3.sca.1.weight": torch.zeros(2), + "g_a.blocks.0.3.FFN.0.weight": torch.zeros(2), + "g_a.blocks.0.3.conv1.weight": torch.zeros(2), + # Per-slice main entropy heads (one slice for compactness). + "mean_cc_transforms.0.0.weight": torch.zeros(2), + "scale_cc_transforms.0.0.weight": torch.zeros(2), + "lrp_transforms.0.0.weight": torch.zeros(2), + # NAFTransform interior (in_conv/out_conv -> input_projection/...). + # Triple required for the detector: .in_conv.weight, + # .out_conv.weight, .blocks.0.beta. + "mean_NAF_transforms.0.in_conv.weight": torch.zeros(2), + "mean_NAF_transforms.0.out_conv.weight": torch.zeros(2), + "mean_NAF_transforms.0.blocks.0.beta": torch.zeros(2), + "scale_NAF_transforms.0.in_conv.weight": torch.zeros(2), + "scale_NAF_transforms.0.out_conv.weight": torch.zeros(2), + "scale_NAF_transforms.0.blocks.0.beta": torch.zeros(2), + "gaussian_conditional.scale_table": torch.zeros(2), + # Hyperprior backbone (root-level -> latent_codec.*). + "h_a.0.weight": torch.zeros(2), + "h_mean_s.0.weight": torch.zeros(2), + "h_scale_s.0.weight": torch.zeros(2), + "z_entropy_bottleneck.quantiles": torch.zeros(2), + # Aux entropy module (aux_entropymodel -> aux_entropy_model, then + # the same per-slice rerooting as the main path). 
+ "aux_entropymodel.mean_cc_transforms.0.0.weight": torch.zeros(2), + "aux_entropymodel.scale_cc_transforms.0.0.weight": torch.zeros(2), + "aux_entropymodel.lrp_transforms.0.0.weight": torch.zeros(2), + "aux_entropymodel.mean_NAF_transforms.0.in_conv.weight": torch.zeros(2), + "aux_entropymodel.mean_NAF_transforms.0.out_conv.weight": torch.zeros(2), + "aux_entropymodel.mean_NAF_transforms.0.blocks.0.beta": torch.zeros(2), + "aux_entropymodel.scale_NAF_transforms.0.in_conv.weight": torch.zeros(2), + "aux_entropymodel.scale_NAF_transforms.0.out_conv.weight": torch.zeros(2), + "aux_entropymodel.scale_NAF_transforms.0.blocks.0.beta": torch.zeros(2), + "aux_entropymodel.gaussian_conditional.scale_table": torch.zeros(2), + "aux_entropymodel.y_entropy_bottleneck.quantiles": torch.zeros(2), + } + assert _is_upstream_cca_state_dict(upstream) + + converted = convert_upstream_cca_state_dict(upstream) + + # ResidualBottleneckBlock conv1 NOT renamed (not inside NAFBlock). + assert "g_a.blocks.0.0.conv1.weight" in converted + # NAFBlock interior renames applied at the NAFBlock scope only. + assert "g_a.blocks.0.3.beta" in converted + assert "g_a.blocks.0.3.pointwise_depthwise.0.weight" in converted + assert "g_a.blocks.0.3.channel_attention.1.weight" in converted + assert "g_a.blocks.0.3.feed_forward.0.weight" in converted + assert "g_a.blocks.0.3.project.weight" in converted + + # Hyperprior backbone moves under latent_codec. + assert "latent_codec.h_a.0.weight" in converted + assert "latent_codec.h_s.h_mean_s.0.weight" in converted + assert "latent_codec.h_s.h_scale_s.0.weight" in converted + assert "latent_codec.z.entropy_bottleneck.quantiles" in converted + + # Per-slice main rerooting. + assert "latent_codec.y.channel_context.y0.mean_cc.0.weight" in converted + assert "latent_codec.y.channel_context.y0.scale_cc.0.weight" in converted + # NAFTransform: in_conv -> input_projection; mean_NAF_transforms -> + # channel_context.y{k}.mean_support_transform. 
+ assert ( + "latent_codec.y.channel_context.y0.mean_support_transform.input_projection.weight" + in converted + ) + assert ( + "latent_codec.y.channel_context.y0.scale_support_transform.input_projection.weight" + in converted + ) + # gaussian_conditional replicated under each per-slice leaf. + assert ( + "latent_codec.y.latent_codec.y0.gaussian_conditional.scale_table" + in converted + ) + # LRP weights byte-for-byte under per-slice leaf. + assert "latent_codec.y.latent_codec.y0.lrp_transform.0.weight" in converted + + # Aux entropy module rerooting (aux_entropymodel -> aux_entropy_model; + # per-slice contents land under inner_codec.*). + assert ( + "aux_entropy_model.inner_codec.channel_context.y0.mean_cc.0.weight" + in converted + ) + assert ( + "aux_entropy_model.inner_codec.channel_context.y0.mean_support_transform.input_projection.weight" + in converted + ) + assert ( + "aux_entropy_model.inner_codec.latent_codec.y0.lrp_transform.0.weight" + in converted + ) + assert ( + "aux_entropy_model.inner_codec.latent_codec.y0.gaussian_conditional.scale_table" + in converted + ) + assert "aux_entropy_model.y_entropy_bottleneck.quantiles" in converted + + # Old paths should be gone after conversion. + assert "h_a.0.weight" not in converted + assert "z_entropy_bottleneck.quantiles" not in converted + assert "mean_cc_transforms.0.0.weight" not in converted + assert "mean_NAF_transforms.0.in_conv.weight" not in converted + assert "lrp_transforms.0.0.weight" not in converted + assert "aux_entropymodel.mean_cc_transforms.0.0.weight" not in converted def test_scale_table_default(): diff --git a/tests/test_models_helpers.py b/tests/test_models_helpers.py new file mode 100644 index 00000000..2b54f0bc --- /dev/null +++ b/tests/test_models_helpers.py @@ -0,0 +1,191 @@ +# Copyright (c) 2021-2025, InterDigital Communications, Inc +# All rights reserved. 
+ +# Redistribution and use in source and binary forms, with or without +# modification, are permitted (subject to the limitations in the disclaimer +# below) provided that the following conditions are met: + +# * Redistributions of source code must retain the above copyright notice, +# this list of conditions and the following disclaimer. +# * Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# * Neither the name of InterDigital Communications, Inc nor the names of its +# contributors may be used to endorse or promote products derived from this +# software without specific prior written permission. + +# NO EXPRESS OR IMPLIED LICENSES TO ANY PARTY'S PATENT RIGHTS ARE GRANTED BY +# THIS LICENSE. THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND +# CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT +# NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A +# PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR +# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, +# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, +# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; +# OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, +# WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR +# OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF +# ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ +import torch +import torch.nn as nn + +from compressai.models._helpers.channel_context import MeanScaleContextHead +from compressai.models._helpers.slice_helpers import ( + infer_max_support_slices, + infer_num_slices, + lrp_support_channels, + make_entropy_transform, + slice_support_channels, +) + + +class TestMeanScaleContextHead: + def test_forward_shape_concatenates_mean_and_scale(self): + slice_ch, support_ch = 4, 12 + head = MeanScaleContextHead( + mean_cc=make_entropy_transform(support_ch, slice_ch, widths=(8, 8)), + scale_cc=make_entropy_transform(support_ch, slice_ch, widths=(8, 8)), + ) + x = torch.randn(2, support_ch, 4, 4) + out = head(x) + assert out.shape == (2, 2 * slice_ch, 4, 4) + + def test_state_dict_paths_split_mean_and_scale(self): + head = MeanScaleContextHead( + mean_cc=make_entropy_transform(12, 4, widths=(8, 8)), + scale_cc=make_entropy_transform(12, 4, widths=(8, 8)), + ) + keys = set(head.state_dict().keys()) + assert any(k.startswith("mean_cc.") for k in keys) + assert any(k.startswith("scale_cc.") for k in keys) + # No support transforms by default -> no associated state. + assert not any(k.startswith("mean_support_transform.") for k in keys) + assert not any(k.startswith("scale_support_transform.") for k in keys) + + def test_support_transforms_wrap_inputs(self): + # Use 1x1 conv that preserves channel count. + head = MeanScaleContextHead( + mean_cc=make_entropy_transform(12, 4, widths=(8, 8)), + scale_cc=make_entropy_transform(12, 4, widths=(8, 8)), + mean_support_transform=nn.Conv2d(12, 12, 1), + scale_support_transform=nn.Conv2d(12, 12, 1), + ) + keys = set(head.state_dict().keys()) + assert any(k.startswith("mean_support_transform.") for k in keys) + assert any(k.startswith("scale_support_transform.") for k in keys) + # mean and scale support transforms are independent instances (not shared). 
+ assert head.mean_support_transform is not head.scale_support_transform + + def test_direct_construction_round_trip(self): + torch.manual_seed(0) + mean_cc = nn.Conv2d(12, 4, 1) + scale_cc = nn.Conv2d(12, 4, 1) + head = MeanScaleContextHead(mean_cc=mean_cc, scale_cc=scale_cc) + rebuilt = MeanScaleContextHead( + mean_cc=nn.Conv2d(12, 4, 1), scale_cc=nn.Conv2d(12, 4, 1) + ) + rebuilt.load_state_dict(head.state_dict()) + x = torch.randn(2, 12, 4, 4) + with torch.no_grad(): + assert torch.allclose(head(x), rebuilt(x)) + + def test_side_split_routes_means_to_mean_cc_and_scales_to_scale_cc(self): + # side_split=8 means input is cat(latent_means(8), latent_scales(8), prev_y_hat(4)); + # mean_cc should see cat(latent_means(8), prev_y_hat(4)) = 12 channels; + # scale_cc same width but reading latent_scales instead of latent_means. + torch.manual_seed(0) + head = MeanScaleContextHead( + mean_cc=make_entropy_transform(12, 4, widths=(8,)), + scale_cc=make_entropy_transform(12, 4, widths=(8,)), + side_split=8, + ) + # Sub-network input width = support_ch - side_split = 12. + first_mean_conv = next(m for m in head.mean_cc if isinstance(m, nn.Conv2d)) + first_scale_conv = next(m for m in head.scale_cc if isinstance(m, nn.Conv2d)) + assert first_mean_conv.in_channels == 12 + assert first_scale_conv.in_channels == 12 + + latent_means = torch.randn(2, 8, 4, 4) + latent_scales = torch.randn(2, 8, 4, 4) + prev_y_hat = torch.randn(2, 4, 4, 4) + x = torch.cat([latent_means, latent_scales, prev_y_hat], dim=1) + with torch.no_grad(): + head_out = head(x) + assert head_out.shape == (2, 8, 4, 4) + # Verify routing: mean_cc(cat(latent_means, prev_y_hat)) appears as + # the second half of head_out (chunks=("scales","means")). 
+ with torch.no_grad(): + expected_mean = head.mean_cc(torch.cat([latent_means, prev_y_hat], dim=1)) + expected_scale = head.scale_cc( + torch.cat([latent_scales, prev_y_hat], dim=1) + ) + scale_out, mean_out = head_out.chunk(2, dim=1) + assert torch.allclose(scale_out, expected_scale) + assert torch.allclose(mean_out, expected_mean) + + +class TestSliceHelpers: + def test_slice_support_channels_default_use_all(self): + # With max_support_slices = -1 the helper returns the full latent + k slices. + assert slice_support_channels(64, 8, 0, -1) == 64 + assert slice_support_channels(64, 8, 5, -1) == 64 + 8 * 5 + + def test_slice_support_channels_clamps(self): + assert slice_support_channels(64, 8, 5, 3) == 64 + 8 * 3 + assert slice_support_channels(64, 8, 1, 3) == 64 + 8 * 1 + + def test_lrp_support_channels(self): + assert lrp_support_channels(64, 8, 0, -1) == 64 + 8 + assert lrp_support_channels(64, 8, 5, 3) == 64 + 8 * 4 + + def test_make_entropy_transform_default_widths(self): + net = make_entropy_transform(40, 8) + # Default widths (224, 128): conv-gelu-conv-gelu-conv -> 5 modules. + assert len(net) == 5 + x = torch.randn(2, 40, 8, 8) + y = net(x) + assert y.shape == (2, 8, 8, 8) + + def test_make_entropy_transform_custom_widths(self): + net = make_entropy_transform(40, 8, widths=(64, 32)) + x = torch.randn(2, 40, 8, 8) + y = net(x) + assert y.shape == (2, 8, 8, 8) + + def test_infer_num_slices_new_path(self): + # New state-dict layout (ELIC default): channel_context entries exist + # for k >= 1. For 4 slices total, we expect 3 mean_cc keys -> infer + # returns 4 (helper adds 1 because y0 is missing). + sd = { + f"latent_codec.y.channel_context.y{k}.mean_cc.0.weight": (torch.zeros(8, 4)) + for k in range(1, 4) + } + assert infer_num_slices(sd) == 4 + + def test_infer_num_slices_with_y0_context(self): + # Side-parameter channel-context layout: channel_context covers every + # slice (y0..yK-1). Helper auto-detects via the presence of y0 and + # does NOT add 1. 
+ sd = { + f"latent_codec.y.channel_context.y{k}.mean_cc.0.weight": (torch.zeros(8, 4)) + for k in range(0, 4) + } + assert infer_num_slices(sd) == 4 + + def test_infer_num_slices_empty(self): + assert infer_num_slices({}) == 0 + + def test_infer_max_support_slices_new_path(self): + # mean_cc.0 takes (latent_means + slice_channels * support) input channels. + # With M=64, num_slices=8, slice_channels=8, support=2 -> input ch = 64 + 16 = 80. + sd = { + "latent_codec.y.channel_context.y2.mean_cc.0.weight": ( + torch.zeros(64, 80, 3, 3) + ), + "latent_codec.y.channel_context.y3.mean_cc.0.weight": ( + torch.zeros(64, 80, 3, 3) + ), + } + # extra_factor=1 is the default single latent_means concat. + assert infer_max_support_slices(sd, latent_channels=64, num_slices=8) == 2