From c79e0e46517992bc544853b8fe2c2037585f3d68 Mon Sep 17 00:00:00 2001 From: Jvst Me Date: Mon, 29 Jun 2026 16:07:37 +0200 Subject: [PATCH] Mixing `CoreModel` variants in `diff_models()` Allow mixing `CoreModel.__request__` and `CoreModel.__response__` in the `diff_models()` utility function. Before: ```python >>> diff_models(M.__request__(a=1), M.__response__(a=2)) Traceback (most recent call last): File "", line 1, in File "/src/dstack/_internal/core/services/diff.py", line 37, in diff_models raise TypeError("Both instances must be of the same Pydantic model class.") TypeError: Both instances must be of the same Pydantic model class. ``` After: ```python >>> diff_models(M.__request__(a=1), M.__response__(a=2)) {'a': ModelFieldDiffRequest(old=1, new=2)} ``` --- src/dstack/_internal/core/services/diff.py | 9 +- .../_internal/core/services/test_diff.py | 124 +++++++++++++++++- 2 files changed, 131 insertions(+), 2 deletions(-) diff --git a/src/dstack/_internal/core/services/diff.py b/src/dstack/_internal/core/services/diff.py index 321d97f5d4..154dee1db8 100644 --- a/src/dstack/_internal/core/services/diff.py +++ b/src/dstack/_internal/core/services/diff.py @@ -33,7 +33,14 @@ def diff_models( A dict of changed fields in the form of `{: {"old": old_value, "new": new_value}}` """ - if type(old) is not type(new): + if not ( + type(old) is type(new) + or ( + isinstance(old, CoreModel) + and isinstance(new, CoreModel) + and type(old).__response__ is type(new).__response__ + ) + ): raise TypeError("Both instances must be of the same Pydantic model class.") if reset is not None: diff --git a/src/tests/_internal/core/services/test_diff.py b/src/tests/_internal/core/services/test_diff.py index 4e8d355c0d..0279f64534 100644 --- a/src/tests/_internal/core/services/test_diff.py +++ b/src/tests/_internal/core/services/test_diff.py @@ -1,6 +1,128 @@ import pytest +from pydantic import BaseModel -from dstack._internal.core.services.diff import ModelDiff, ModelFieldDiff, flatten_diff_fields +from dstack._internal.core.models.common import CoreModel +from dstack._internal.core.services.diff import ( + ModelDiff, + ModelFieldDiff, + diff_models, + flatten_diff_fields, +) + + +class TestDiffModels: + class _BaseModelA(BaseModel): + a: int + b: str + + class _BaseModelB(BaseModel): + c: int + + class _BaseModelAB(_BaseModelA, _BaseModelB): + pass + + class _CoreModelA(CoreModel): + a: int + b: str + + class _CoreModelB(CoreModel): + c: int + + class _CoreModelAB(_CoreModelA, _CoreModelB): + pass + + @pytest.mark.parametrize( + ("old", "new", "expected"), + [ + pytest.param( + _BaseModelA(a=1, b="x"), + _BaseModelA(a=1, b="y"), + {"b": ModelFieldDiff(old="x", new="y")}, + id="base-model", + ), + pytest.param( + _CoreModelA(a=1, b="x"), + _CoreModelA(a=1, b="y"), + {"b": ModelFieldDiff(old="x", new="y")}, + id="core-model", + ), + pytest.param( + _BaseModelA(a=1, b="x"), + _BaseModelA(a=1, b="x"), + {}, + id="base-model-no-diff", + ), + pytest.param( + _CoreModelA(a=1, b="x"), + _CoreModelA(a=1, b="x"), + {}, + id="core-model-no-diff", + ), + pytest.param( + _CoreModelA.__request__(a=1, b="x"), + _CoreModelA.__request__(a=1, b="y"), + {"b": ModelFieldDiff(old="x", new="y")}, + id="core-model-request", + ), + pytest.param( + _CoreModelA.__response__(a=1, b="x"), + _CoreModelA.__response__(a=1, b="y"), + {"b": ModelFieldDiff(old="x", new="y")}, + id="core-model-response", + ), + pytest.param( + _CoreModelA.__request__(a=1, b="x"), + _CoreModelA.__response__(a=1, b="y"), + {"b": ModelFieldDiff(old="x", new="y")}, + id="core-model-request-response", + ), + pytest.param( + _CoreModelA(a=1, b="x"), + _CoreModelA.__response__(a=1, b="y"), + {"b": ModelFieldDiff(old="x", new="y")}, + id="core-model-base-request", + ), + pytest.param( + _CoreModelA(a=1, b="x"), + _CoreModelA.__response__(a=1, b="y"), + {"b": ModelFieldDiff(old="x", new="y")}, + id="core-model-base-response", + ), + ], + ) + def test_diff_models(self, old: BaseModel, new: BaseModel, expected: ModelDiff) -> None: + assert diff_models(old, new) == expected + + @pytest.mark.parametrize( + ("old", "new"), + [ + pytest.param( + _BaseModelA(a=1, b="x"), + _BaseModelB(c=2), + id="different-base-models", + ), + pytest.param( + _BaseModelA(a=1, b="x"), + _BaseModelAB(a=1, b="x", c=2), + id="base-model-and-subclass", + ), + pytest.param( + _CoreModelA(a=1, b="x"), + _CoreModelB(c=2), + id="different-core-models", + ), + pytest.param( + _CoreModelA(a=1, b="x"), + _CoreModelAB(a=1, b="x", c=2), + id="core-model-and-subclass", + ), + ], + ) + def test_type_mismatch(self, old: BaseModel, new: BaseModel) -> None: + with pytest.raises( + TypeError, match="Both instances must be of the same Pydantic model class." + ): + diff_models(old, new) @pytest.mark.parametrize(