From 290c6c5624806694a8e5e502f62901f8f5b589fb Mon Sep 17 00:00:00 2001 From: Max Luebbering <2804731+le1nux@users.noreply.github.com> Date: Wed, 27 May 2026 17:01:15 +0200 Subject: [PATCH 1/5] feat: implemented selective component loading in AppState and AppStateFactory --- .../checkpointing/stateful/app_state.py | 30 ++++++++++++++----- .../stateful/app_state_factory.py | 23 ++++++++++---- src/modalities/config/config.py | 2 ++ 3 files changed, 43 insertions(+), 12 deletions(-) diff --git a/src/modalities/checkpointing/stateful/app_state.py b/src/modalities/checkpointing/stateful/app_state.py index 2da3ab236..97b65a569 100644 --- a/src/modalities/checkpointing/stateful/app_state.py +++ b/src/modalities/checkpointing/stateful/app_state.py @@ -37,7 +37,11 @@ class AppState(Stateful): """ def __init__( - self, model: nn.Module | list[nn.Module], optimizer: Optimizer, lr_scheduler: Optional[LRScheduler] = None + self, + model: nn.Module | list[nn.Module], + optimizer: Optimizer, + lr_scheduler: Optional[LRScheduler] = None, + components_to_load: list[StatefulComponents] | None = None, ): """Initializes the AppState object. @@ -46,12 +50,22 @@ def __init__( a non-sharded model, FSDP1 or FSDP2 model. optimizer (Optimizer): The optimizer can be either a non-sharded optimizer, FSDP1 or FSDP2 optimizer. lr_scheduler (Optional[LRScheduler], optional): The lr scheduler used during training. Defaults to None. + components_to_load (list[StatefulComponents] | None, optional): The list of components to load from the + checkpoint. If None, all components are loaded. Defaults to None. """ self._model_parts = list(model) if isinstance(model, list) else [model] self._optimizer = optimizer self._lr_scheduler = lr_scheduler self._is_loaded = False + # policy for which components to load from the checkpoint. If None, defaults to loading all components. 
+ if components_to_load is None: + self._components_to_load = [StatefulComponents.MODEL, StatefulComponents.OPTIMIZER] + if lr_scheduler is not None: + self._components_to_load.append(StatefulComponents.LR_SCHEDULER) + else: + self._components_to_load = components_to_load + @property def is_loaded(self) -> bool: """Returns whether the state dict has been loaded. @@ -106,12 +120,14 @@ def load_state_dict(self, state_dict: dict[str, Any]) -> None: "Cannot call load_state_dict twice on the same AppState object. " "State dict has already been loaded." ) - ModelStateRetriever.load_state_dict_(app_state=self, state_dict=state_dict[StatefulComponents.MODEL.value]) - OptimizerStateRetriever.load_state_dict_( - app_state=self, - state_dict=state_dict[StatefulComponents.OPTIMIZER.value], - ) - if self._lr_scheduler is not None: + if StatefulComponents.MODEL in self._components_to_load: + ModelStateRetriever.load_state_dict_(app_state=self, state_dict=state_dict[StatefulComponents.MODEL.value]) + if StatefulComponents.OPTIMIZER in self._components_to_load: + OptimizerStateRetriever.load_state_dict_( + app_state=self, + state_dict=state_dict[StatefulComponents.OPTIMIZER.value], + ) + if self._lr_scheduler is not None and StatefulComponents.LR_SCHEDULER in self._components_to_load: LRSchedulerStateRetriever.load_state_dict_( app_state=self, state_dict=state_dict[StatefulComponents.LR_SCHEDULER.value] ) diff --git a/src/modalities/checkpointing/stateful/app_state_factory.py b/src/modalities/checkpointing/stateful/app_state_factory.py index 8f6e63d8a..8c33f73d2 100644 --- a/src/modalities/checkpointing/stateful/app_state_factory.py +++ b/src/modalities/checkpointing/stateful/app_state_factory.py @@ -7,7 +7,7 @@ from torch.optim.lr_scheduler import LRScheduler from modalities.checkpointing.fsdp.fsdp_checkpoint_loading import DCPCheckpointLoading -from modalities.checkpointing.stateful.app_state import AppState +from modalities.checkpointing.stateful.app_state import AppState, 
StatefulComponents class AppStateFactory: @@ -15,7 +15,10 @@ class AppStateFactory: @staticmethod def get_raw_app_state( - model: nn.Module | list[nn.Module], optimizer: Optimizer, lr_scheduler: Optional[LRScheduler] = None + model: nn.Module | list[nn.Module], + optimizer: Optimizer, + lr_scheduler: Optional[LRScheduler] = None, + components_to_load: list[StatefulComponents] | None = None, ) -> AppState: """Creates a new (non-checkpoint loaded) AppState object from an instantiated model, optimizer, and optional learning rate scheduler. @@ -25,11 +28,19 @@ def get_raw_app_state( a non-sharded model, FSDP1 or FSDP2 model. optimizer (Optimizer): The optimizer can be either a non-sharded optimizer, FSDP1 or FSDP2 optimizer. lr_scheduler (Optional[LRScheduler], optional): Lr scheduler used during training. Defaults to None. + components_to_load (list[StatefulComponents] | None, optional): Subset of components that should + be restored from a checkpoint when ``load_state_dict`` is later invoked. If None, all + available components are loaded. Defaults to None. Returns: AppState: The AppState object. """ - app_state = AppState(model=model, optimizer=optimizer, lr_scheduler=lr_scheduler) + app_state = AppState( + model=model, + optimizer=optimizer, + lr_scheduler=lr_scheduler, + components_to_load=components_to_load, + ) return app_state @staticmethod @@ -41,7 +52,8 @@ def get_dcp_checkpointed_app_state_( (i.e., non-checkpoint loaded AppState) in-place. Args: - raw_app_state (AppState): The raw AppState object. + raw_app_state (AppState): The raw AppState object. Its ``components_to_load`` policy + determines which components are restored. checkpoint_dir_path (Path): The path to the checkpoint directory. Raises: @@ -52,8 +64,9 @@ def get_dcp_checkpointed_app_state_( """ if raw_app_state.is_loaded: raise RuntimeError( - "Cannot call load_state_dict twice on the same AppState object. " "State dict has already been loaded." 
+ "Cannot call load_state_dict twice on the same AppState object. State dict has already been loaded." ) + cp_loading = DCPCheckpointLoading(global_rank=dist.get_rank()) cp_loading.load_checkpoint_(app_state=raw_app_state, checkpoint_dir_path=checkpoint_dir_path) return raw_app_state diff --git a/src/modalities/config/config.py b/src/modalities/config/config.py index 42a19b99a..bb2a98dc0 100644 --- a/src/modalities/config/config.py +++ b/src/modalities/config/config.py @@ -11,6 +11,7 @@ from transformers import LlamaTokenizer as LlamaTokenizerFast from typing_extensions import deprecated +from modalities.checkpointing.stateful.app_state import StatefulComponents from modalities.config.lookup_enum import LookupEnum from modalities.config.pydantic_if_types import ( PydanticAppStateType, @@ -382,6 +383,7 @@ class RawAppStateConfig(BaseModel): model: PydanticPytorchModuleOrListType optimizer: PydanticOptimizerIFType lr_scheduler: Optional[PydanticLRSchedulerIFType] = None + components_to_load: Optional[list[StatefulComponents]] = None class DCPAppStateConfig(BaseModel): From f08c93750cd4e8f1281c6b07a23091960ce399e3 Mon Sep 17 00:00:00 2001 From: Max Luebbering <2804731+le1nux@users.noreply.github.com> Date: Wed, 27 May 2026 17:03:02 +0200 Subject: [PATCH 2/5] test: added test for partial checkpoint loading --- .../test_app_state_components_to_load.py | 137 ++++++++++++++++++ 1 file changed, 137 insertions(+) create mode 100644 tests/checkpointing/test_app_state_components_to_load.py diff --git a/tests/checkpointing/test_app_state_components_to_load.py b/tests/checkpointing/test_app_state_components_to_load.py new file mode 100644 index 000000000..acd842d86 --- /dev/null +++ b/tests/checkpointing/test_app_state_components_to_load.py @@ -0,0 +1,137 @@ +from unittest.mock import MagicMock + +import pytest +import torch.nn as nn +from torch.optim import SGD +from torch.optim.lr_scheduler import StepLR + +from modalities.checkpointing.stateful import app_state as 
app_state_module +from modalities.checkpointing.stateful.app_state import AppState, StatefulComponents + + +@pytest.fixture +def model() -> nn.Module: + return nn.Linear(4, 2) + + +@pytest.fixture +def optimizer(model: nn.Module) -> SGD: + return SGD(model.parameters(), lr=0.1) + + +@pytest.fixture +def lr_scheduler(optimizer: SGD) -> StepLR: + return StepLR(optimizer, step_size=1) + + +@pytest.fixture +def patched_retrievers(monkeypatch: pytest.MonkeyPatch) -> dict[StatefulComponents, MagicMock]: + """Replace each retriever's ``load_state_dict_`` with a mock so we can assert which ones were invoked.""" + mocks = { + StatefulComponents.MODEL: MagicMock(), + StatefulComponents.OPTIMIZER: MagicMock(), + StatefulComponents.LR_SCHEDULER: MagicMock(), + } + monkeypatch.setattr(app_state_module.ModelStateRetriever, "load_state_dict_", mocks[StatefulComponents.MODEL]) + monkeypatch.setattr( + app_state_module.OptimizerStateRetriever, "load_state_dict_", mocks[StatefulComponents.OPTIMIZER] + ) + monkeypatch.setattr( + app_state_module.LRSchedulerStateRetriever, "load_state_dict_", mocks[StatefulComponents.LR_SCHEDULER] + ) + return mocks + + +def _make_state_dict() -> dict: + return { + StatefulComponents.MODEL.value: {"model_payload": True}, + StatefulComponents.OPTIMIZER.value: {"optimizer_payload": True}, + StatefulComponents.LR_SCHEDULER.value: {"lr_scheduler_payload": True}, + } + + +class TestComponentsToLoad: + def test_default_without_lr_scheduler_loads_model_and_optimizer( + self, model: nn.Module, optimizer: SGD, patched_retrievers: dict[StatefulComponents, MagicMock] + ) -> None: + app_state = AppState(model=model, optimizer=optimizer) + + app_state.load_state_dict(_make_state_dict()) + + patched_retrievers[StatefulComponents.MODEL].assert_called_once() + patched_retrievers[StatefulComponents.OPTIMIZER].assert_called_once() + patched_retrievers[StatefulComponents.LR_SCHEDULER].assert_not_called() + assert app_state.is_loaded + + def 
test_default_with_lr_scheduler_loads_all_three( + self, + model: nn.Module, + optimizer: SGD, + lr_scheduler: StepLR, + patched_retrievers: dict[StatefulComponents, MagicMock], + ) -> None: + app_state = AppState(model=model, optimizer=optimizer, lr_scheduler=lr_scheduler) + + app_state.load_state_dict(_make_state_dict()) + + patched_retrievers[StatefulComponents.MODEL].assert_called_once() + patched_retrievers[StatefulComponents.OPTIMIZER].assert_called_once() + patched_retrievers[StatefulComponents.LR_SCHEDULER].assert_called_once() + + @pytest.mark.parametrize( + "selected", + [ + [StatefulComponents.MODEL], + [StatefulComponents.OPTIMIZER], + [StatefulComponents.LR_SCHEDULER], + [StatefulComponents.MODEL, StatefulComponents.OPTIMIZER], + [StatefulComponents.MODEL, StatefulComponents.LR_SCHEDULER], + [], + ], + ) + def test_explicit_selection_only_loads_chosen_components( + self, + model: nn.Module, + optimizer: SGD, + lr_scheduler: StepLR, + patched_retrievers: dict[StatefulComponents, MagicMock], + selected: list[StatefulComponents], + ) -> None: + app_state = AppState(model=model, optimizer=optimizer, lr_scheduler=lr_scheduler, components_to_load=selected) + + app_state.load_state_dict(_make_state_dict()) + + for component, mock in patched_retrievers.items(): + if component in selected: + mock.assert_called_once() + else: + mock.assert_not_called() + + def test_lr_scheduler_in_components_but_no_scheduler_attached_is_skipped( + self, model: nn.Module, optimizer: SGD, patched_retrievers: dict[StatefulComponents, MagicMock] + ) -> None: + # Guards against the lr_scheduler branch firing when no scheduler is attached — the + # state_dict won't carry a scheduler entry, so the retriever must not be called. 
+ app_state = AppState( + model=model, + optimizer=optimizer, + components_to_load=[StatefulComponents.MODEL, StatefulComponents.LR_SCHEDULER], + ) + + state_dict = _make_state_dict() + state_dict.pop(StatefulComponents.LR_SCHEDULER.value) + + app_state.load_state_dict(state_dict) + + patched_retrievers[StatefulComponents.MODEL].assert_called_once() + patched_retrievers[StatefulComponents.OPTIMIZER].assert_not_called() + patched_retrievers[StatefulComponents.LR_SCHEDULER].assert_not_called() + + def test_double_load_raises( + self, model: nn.Module, optimizer: SGD, patched_retrievers: dict[StatefulComponents, MagicMock] + ) -> None: + app_state = AppState(model=model, optimizer=optimizer) + app_state.load_state_dict(_make_state_dict()) + + with pytest.raises(RuntimeError): + app_state.load_state_dict(_make_state_dict()) From ec1ac4f75181d7a691364b251e560994060373da Mon Sep 17 00:00:00 2001 From: Max Luebbering <2804731+le1nux@users.noreply.github.com> Date: Thu, 28 May 2026 12:54:19 +0200 Subject: [PATCH 3/5] feat: added allow_partial_load option to DCP checkpoint loading --- .../checkpointing/fsdp/fsdp_checkpoint_loading.py | 6 ++++-- src/modalities/checkpointing/stateful/app_state_factory.py | 4 +++- src/modalities/config/config.py | 6 ++++-- src/modalities/registry/components.py | 5 ++--- 4 files changed, 13 insertions(+), 8 deletions(-) diff --git a/src/modalities/checkpointing/fsdp/fsdp_checkpoint_loading.py b/src/modalities/checkpointing/fsdp/fsdp_checkpoint_loading.py index 572168972..8f476530d 100644 --- a/src/modalities/checkpointing/fsdp/fsdp_checkpoint_loading.py +++ b/src/modalities/checkpointing/fsdp/fsdp_checkpoint_loading.py @@ -103,17 +103,18 @@ def load_optimizer_checkpoint_(self, optimizer: Optimizer, model: FSDP, file_pat class DCPCheckpointLoading(DistributedCheckpointLoadingIF): """Distributed checkpoint loading interface for loading PyTorch models and optimizer checkpoints.""" - def __init__(self, global_rank: int): + def __init__(self, 
global_rank: int, allow_partial_load: bool = False): """ Initializes the DCPCheckpointLoading object. Args: global_rank (int): The global rank of the process. - + allow_partial_load (bool, optional): Whether to allow partial loading of the checkpoint. Defaults to False. Returns: None """ self._global_rank = global_rank + self._allow_partial_load = allow_partial_load @torch.no_grad() def load_checkpoint_(self, app_state: AppState, checkpoint_dir_path: Path): @@ -129,5 +130,6 @@ def load_checkpoint_(self, app_state: AppState, checkpoint_dir_path: Path): dcp.load( state_dict={"app": app_state}, checkpoint_id=checkpoint_dir_path, + planner=dcp.DefaultLoadPlanner(allow_partial_load=self._allow_partial_load), ) get_logger().info(f"Distributed checkpoint loaded on rank {self._global_rank}.") diff --git a/src/modalities/checkpointing/stateful/app_state_factory.py b/src/modalities/checkpointing/stateful/app_state_factory.py index 8c33f73d2..618794ac5 100644 --- a/src/modalities/checkpointing/stateful/app_state_factory.py +++ b/src/modalities/checkpointing/stateful/app_state_factory.py @@ -47,6 +47,7 @@ def get_raw_app_state( def get_dcp_checkpointed_app_state_( raw_app_state: AppState, checkpoint_dir_path: Path, + allow_partial_load: bool = True, ) -> AppState: """Loads the checkpointed state dict into the raw AppState object (i.e., non-checkpoint loaded AppState) in-place. @@ -55,6 +56,7 @@ def get_dcp_checkpointed_app_state_( raw_app_state (AppState): The raw AppState object. Its ``components_to_load`` policy determines which components are restored. checkpoint_dir_path (Path): The path to the checkpoint directory. + allow_partial_load (bool, optional): Whether to allow partial loading of the checkpoint. Defaults to True. Raises: RuntimeError: Raises an error if the state dict has already been loaded. @@ -67,6 +69,6 @@ def get_dcp_checkpointed_app_state_( "Cannot call load_state_dict twice on the same AppState object. State dict has already been loaded." 
) - cp_loading = DCPCheckpointLoading(global_rank=dist.get_rank()) + cp_loading = DCPCheckpointLoading(global_rank=dist.get_rank(), allow_partial_load=allow_partial_load) cp_loading.load_checkpoint_(app_state=raw_app_state, checkpoint_dir_path=checkpoint_dir_path) return raw_app_state diff --git a/src/modalities/config/config.py b/src/modalities/config/config.py index bb2a98dc0..18d3629ca 100644 --- a/src/modalities/config/config.py +++ b/src/modalities/config/config.py @@ -125,8 +125,9 @@ def parse_sharding_strategy_by_name(cls, name: str) -> ShardingStrategy: return parse_enum_by_name(name=name, enum_type=ShardingStrategy) -class DCPCheckpointLoadingConfig(BaseModel): - global_rank: Annotated[int, Field(strict=True, ge=0)] +# class DCPCheckpointLoadingConfig(BaseModel): +# global_rank: Annotated[int, Field(strict=True, ge=0)] +# allow_partial_load: bool = True class FSDP1CheckpointSavingConfig(BaseModel): @@ -389,6 +390,7 @@ class RawAppStateConfig(BaseModel): class DCPAppStateConfig(BaseModel): raw_app_state: PydanticAppStateType checkpoint_dir_path: Path + allow_partial_load: bool = False class PreTrainedHFTokenizerConfig(BaseModel): diff --git a/src/modalities/registry/components.py b/src/modalities/registry/components.py index 26df9b432..5833ea728 100644 --- a/src/modalities/registry/components.py +++ b/src/modalities/registry/components.py @@ -13,7 +13,7 @@ SaveEveryKStepsCheckpointingStrategy, SaveKMostRecentCheckpointsStrategy, ) -from modalities.checkpointing.fsdp.fsdp_checkpoint_loading import DCPCheckpointLoading, FSDP1CheckpointLoading +from modalities.checkpointing.fsdp.fsdp_checkpoint_loading import FSDP1CheckpointLoading from modalities.checkpointing.fsdp.fsdp_checkpoint_saving import DCPCheckpointSaving, FSDP1CheckpointSaving from modalities.checkpointing.stateful.app_state_factory import AppStateFactory from modalities.checkpointing.torch.torch_checkpoint_loading import TorchCheckpointLoading @@ -29,7 +29,6 @@ ConstantLRSchedulerConfig, 
CosineAnnealingLRSchedulerConfig, DCPAppStateConfig, - DCPCheckpointLoadingConfig, DCPCheckpointSavingConfig, DebuggingEnrichedModelConfig, DistributedSamplerConfig, @@ -358,7 +357,7 @@ class ComponentEntity: ComponentEntity("checkpoint_saving_execution", "dcp", DCPCheckpointSaving, DCPCheckpointSavingConfig), # checkpoint loading ComponentEntity("checkpoint_loading", "fsdp1", FSDP1CheckpointLoading, FSDP1CheckpointLoadingConfig), - ComponentEntity("checkpoint_loading", "dcp", DCPCheckpointLoading, DCPCheckpointLoadingConfig), + # ComponentEntity("checkpoint_loading", "dcp", DCPCheckpointLoading, DCPCheckpointLoadingConfig), ComponentEntity("checkpoint_loading", "torch", TorchCheckpointLoading, TorchCheckpointLoadingConfig), # Progress subscriber ComponentEntity( From 33c55a43d3f5b5d6e0300334f636746ec2d69e8e Mon Sep 17 00:00:00 2001 From: Max Luebbering <2804731+le1nux@users.noreply.github.com> Date: Mon, 15 Jun 2026 11:50:30 +0200 Subject: [PATCH 4/5] feat: hardened weight tying against misconfigurations --- src/modalities/config/config.py | 8 +++ src/modalities/models/gpt2/gpt2_model.py | 6 ++ src/modalities/models/model.py | 5 ++ .../pipeline_parallelism_configs.py | 9 ++- src/modalities/models/weight_tying.py | 11 ++++ tests/test_weight_tying.py | 64 +++++++++++++++++++ 6 files changed, 102 insertions(+), 1 deletion(-) create mode 100644 src/modalities/models/weight_tying.py diff --git a/src/modalities/config/config.py b/src/modalities/config/config.py index 18d3629ca..e67694b90 100644 --- a/src/modalities/config/config.py +++ b/src/modalities/config/config.py @@ -34,6 +34,7 @@ PydanticTokenizerIFType, ) from modalities.config.utils import parse_torch_device +from modalities.models.weight_tying import has_tied_word_embeddings from modalities.running_env.env_utils import ( FSDP2MixedPrecisionSettings, MixedPrecisionSettings, @@ -342,6 +343,13 @@ def validate_tp_mesh_existence(self) -> "GPT2ModelTPConfig": raise ValueError("data_parallel_replicate_degree > 1 
cannot be used with Tensor Parallelism.") return self + @model_validator(mode="after") + def validate_untied_word_embeddings(self) -> "GPT2ModelTPConfig": + models = self.model if isinstance(self.model, list) else [self.model] + if any(has_tied_word_embeddings(model) for model in models): + raise ValueError("Tied word embeddings are not supported with Tensor Parallelism.") + return self + class CompiledModelConfig(BaseModel): model: PydanticPytorchModuleOrListType diff --git a/src/modalities/models/gpt2/gpt2_model.py b/src/modalities/models/gpt2/gpt2_model.py index 2da4979c0..eb8db53c2 100644 --- a/src/modalities/models/gpt2/gpt2_model.py +++ b/src/modalities/models/gpt2/gpt2_model.py @@ -938,6 +938,12 @@ def __init__( self.transformer.lm_head.weight ) # https://paperswithcode.com/method/weight-tying + @property + def has_tied_word_embeddings(self) -> bool: + token_embedding_weight = getattr(self.transformer.wte, "weight", None) + lm_head_weight = getattr(self.transformer.lm_head, "weight", None) + return token_embedding_weight is not None and token_embedding_weight is lm_head_weight + @overload def forward(self, inputs: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]: """ diff --git a/src/modalities/models/model.py b/src/modalities/models/model.py index f981f6117..e949de2d8 100644 --- a/src/modalities/models/model.py +++ b/src/modalities/models/model.py @@ -46,6 +46,11 @@ def weight_decay_groups(self) -> WeightDecayGroups: """ return self._weight_decay_groups + @property + def has_tied_word_embeddings(self) -> bool: + """Whether the model currently uses tied token embedding and output weights.""" + return False + @abstractmethod def forward(self, inputs: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]: """ diff --git a/src/modalities/models/parallelism/pipeline_parallelism_configs.py b/src/modalities/models/parallelism/pipeline_parallelism_configs.py index ec16cdac0..b8576c80d 100644 --- a/src/modalities/models/parallelism/pipeline_parallelism_configs.py 
+++ b/src/modalities/models/parallelism/pipeline_parallelism_configs.py @@ -1,6 +1,6 @@ from typing import Annotated -from pydantic import BaseModel, Field +from pydantic import BaseModel, Field, model_validator from modalities.config.pydantic_if_types import ( PydanticDeviceMeshIFType, @@ -11,6 +11,7 @@ PydanticStagesGeneratorType, ) from modalities.models.parallelism.pipeline_parallelism import PipelineSelectionTypes +from modalities.models.weight_tying import has_tied_word_embeddings from modalities.utils.deprecated_alias import add_deprecated_alias @@ -26,6 +27,12 @@ class StagedPipelineConfig(BaseModel): pp_schedule_name: str num_layers_per_stage: Annotated[int, Field(strict=True, ge=1)] + @model_validator(mode="after") + def validate_untied_word_embeddings(self) -> "StagedPipelineConfig": + if has_tied_word_embeddings(self.whole_model): + raise ValueError("Tied word embeddings are not supported with Pipeline Parallelism.") + return self + class ScheduledPipelineConfig(BaseModel): loss_fn: PydanticLossIFType diff --git a/src/modalities/models/weight_tying.py b/src/modalities/models/weight_tying.py new file mode 100644 index 000000000..8f7b453ce --- /dev/null +++ b/src/modalities/models/weight_tying.py @@ -0,0 +1,11 @@ +import torch.nn as nn + + +def has_tied_word_embeddings(model: nn.Module) -> bool: + model_has_tied_word_embeddings = getattr(model, "has_tied_word_embeddings", None) + if model_has_tied_word_embeddings is None: + raise TypeError( + f"{type(model).__name__} must define 'has_tied_word_embeddings' to be used with tied-embedding validation." 
+ ) + + return bool(model_has_tied_word_embeddings) diff --git a/tests/test_weight_tying.py b/tests/test_weight_tying.py index a9412cc4b..4eb81b1f3 100644 --- a/tests/test_weight_tying.py +++ b/tests/test_weight_tying.py @@ -1,6 +1,9 @@ import pytest import torch.nn as nn +from pydantic import ValidationError +from torch.distributed.device_mesh import DeviceMesh +from modalities.config.config import GPT2ModelTPConfig from modalities.models.components.layer_norms import LayerNormConfig from modalities.models.gpt2.gpt2_model import ( GPT2LLM, @@ -11,6 +14,10 @@ PositionTypes, ) from modalities.models.model import ActivationType +from modalities.models.parallelism.pipeline_parallelism_configs import StagedPipelineConfig +from modalities.models.parallelism.stages_generator import GPT2LLMStagesGenerator +from modalities.models.weight_tying import has_tied_word_embeddings +from modalities.running_env.fsdp.device_mesh import ParallelismDegrees VOCAB_SIZE = 1000 EMBEDDING_DIM = 64 @@ -79,9 +86,17 @@ def create_gpt2_model(use_weight_tying: bool) -> GPT2LLM: ) +def create_device_mesh_stub(*mesh_dim_names: str) -> DeviceMesh: + device_mesh = DeviceMesh.__new__(DeviceMesh) + device_mesh.mesh_dim_names = mesh_dim_names + return device_mesh + + @pytest.mark.parametrize("use_weight_tying", [True, False]) def test_weight_tying_behavior(use_weight_tying): model = create_gpt2_model(use_weight_tying) + assert model.has_tied_word_embeddings is use_weight_tying + if use_weight_tying: assert ( model.transformer.wte.weight is model.transformer.lm_head.weight @@ -118,3 +133,52 @@ def test_weight_tying_named_parameters(use_weight_tying): assert ( "transformer.lm_head.weight" in named_params ), "transformer.lm_head.weight should appear in named_parameters when weight tying is not used." 
+ + +def test_has_tied_word_embeddings_requires_model_capability(): + with pytest.raises(TypeError, match="must define 'has_tied_word_embeddings'"): + has_tied_word_embeddings(nn.Linear(1, 1)) + + +def test_tp_config_rejects_tied_word_embeddings(): + model = create_gpt2_model(use_weight_tying=True) + device_mesh = create_device_mesh_stub(ParallelismDegrees.TP.value) + + with pytest.raises(ValidationError, match="Tied word embeddings are not supported with Tensor Parallelism"): + GPT2ModelTPConfig(model=model, device_mesh=device_mesh) + + +def test_tp_config_allows_untied_word_embeddings(): + model = create_gpt2_model(use_weight_tying=False) + device_mesh = create_device_mesh_stub(ParallelismDegrees.TP.value) + + GPT2ModelTPConfig(model=model, device_mesh=device_mesh) + + +def test_pp_config_rejects_tied_word_embeddings(): + model = create_gpt2_model(use_weight_tying=True) + device_mesh = create_device_mesh_stub(ParallelismDegrees.PP.value) + + with pytest.raises(ValidationError, match="Tied word embeddings are not supported with Pipeline Parallelism"): + StagedPipelineConfig( + whole_model=model, + stages_generator=GPT2LLMStagesGenerator(num_model_layers=model.n_layer), + device_mesh=device_mesh, + local_rank=0, + pp_schedule_name="gpipe", + num_layers_per_stage=1, + ) + + +def test_pp_config_allows_untied_word_embeddings(): + model = create_gpt2_model(use_weight_tying=False) + device_mesh = create_device_mesh_stub(ParallelismDegrees.PP.value) + + StagedPipelineConfig( + whole_model=model, + stages_generator=GPT2LLMStagesGenerator(num_model_layers=model.n_layer), + device_mesh=device_mesh, + local_rank=0, + pp_schedule_name="gpipe", + num_layers_per_stage=1, + ) From c1a86ab4180843350d0dcc18239a86b0bb83b19b Mon Sep 17 00:00:00 2001 From: Max Luebbering <2804731+le1nux@users.noreply.github.com> Date: Wed, 17 Jun 2026 18:13:42 +0200 Subject: [PATCH 5/5] fix: in case of weight tying, we do not want to initialize the lm head weights separately from the input embedding 
weights, since they will be tied together and should share the same initialization. The lm head weights will be initialized as part of the input embedding weights initialization, so we can remove the separate initialization for the lm head weights when weight tying is enabled. --- .../models/gpt2/llama3_like_initialization.py | 23 ++++++++++--------- 1 file changed, 12 insertions(+), 11 deletions(-) diff --git a/src/modalities/models/gpt2/llama3_like_initialization.py b/src/modalities/models/gpt2/llama3_like_initialization.py index 5d1fa53af..5d9ce4d25 100644 --- a/src/modalities/models/gpt2/llama3_like_initialization.py +++ b/src/modalities/models/gpt2/llama3_like_initialization.py @@ -23,7 +23,7 @@ class Llama3Initializer(ModelInitializationIF): Follows weight initialization distributions and parameterization for Llama3 as described in TorchTitan. """ - def __init__(self, num_layers: int, n_embd: int, depth_init: bool) -> None: + def __init__(self, num_layers: int, n_embd: int, depth_init: bool, use_weight_tying: bool) -> None: """ Initializes the Llama3Initializer. 
Args: @@ -39,16 +39,6 @@ def __init__(self, num_layers: int, n_embd: int, depth_init: bool) -> None: self.regex_to_init = { # embedding weights r"transformer\.wte\.weight": (nn.init.normal_, {"mean": 0.0, "std": 1}), - # lm head weights - r"transformer\.lm_head\.weight": ( - trunc_normal_, - { - "mean": 0.0, - "std": 1 / math.sqrt(n_embd), - "a": -3 / math.sqrt(n_embd), - "b": 3 / math.sqrt(n_embd), - }, - ), # qkv projections r"transformer\.h\.\d+\.attn\.(q_attn|k_attn|v_attn)\.weight": ( trunc_normal_, @@ -97,6 +87,17 @@ def __init__(self, num_layers: int, n_embd: int, depth_init: bool) -> None: }, ), } + if not use_weight_tying: + # lm head weights + self.regex_to_init[r"transformer\.lm_head\.weight"] = ( + trunc_normal_, + { + "mean": 0.0, + "std": 1 / math.sqrt(n_embd), + "a": -3 / math.sqrt(n_embd), + "b": 3 / math.sqrt(n_embd), + }, + ) def initialize_in_place(self, model: nn.Module): self._init_by_fqn_regex(model, self.regex_to_init)