diff --git a/tests/test_core.py b/tests/test_core.py index 33fd53845..3daedc5b7 100644 --- a/tests/test_core.py +++ b/tests/test_core.py @@ -49,6 +49,20 @@ class TestStateProxy: sp = StateProxy(raw_state_dict) + @classmethod + def count_initial_mutations(cls, d, count=0): + """ + Counts the number of mutations that will be performed for a given dictionary + when it is converted into a StateProxy. + """ + for key, value in d.items(): + if not key.startswith('_'): + count += 1 # Increment for each key-value pair + if isinstance(value, dict): + count = TestStateProxy.count_initial_mutations(value, count) + # Recurse for nested dictionaries + return count + def test_read(self) -> None: d = self.sp.to_dict() assert d.get("name") == "Robert" @@ -57,10 +71,31 @@ def test_read(self) -> None: assert d.get("utfࠀ") == 23 def test_mutations(self) -> None: + m = self.sp.get_mutations_as_dict() + assert len(m) == TestStateProxy.count_initial_mutations(self.sp.to_dict()) + # Mutated after initialization from raw_state_dict + self.sp["age"] = 2 - self.sp["interests"].append("dogs") + m = self.sp.get_mutations_as_dict() + assert m.get("age") == 2 + assert len(m) == 1 + + self.sp["interests"] += ["dogs"] self.sp["features"]["height"] = "short" + m = self.sp.get_mutations_as_dict() + assert m.get("interests") == ["lamps", "cars", "dogs"] + assert m.get("features.height") == "short" + assert len(m) == 2 + self.sp["state.with.dots"]["photo.jpeg"] = "Corrupted" + m = self.sp.get_mutations_as_dict() + assert m.get("state\\.with\\.dots.photo\\.jpeg") == "Corrupted" + assert len(m) == 1 + + self.sp["new.state.with.dots"] = {"test": "test"} + m = self.sp.get_mutations_as_dict() + assert len(m) == 1 + d = self.sp.to_dict() assert d.get("age") == 2 assert d.get("interests") == ["lamps", "cars", "dogs"] @@ -70,8 +105,7 @@ def test_mutations(self) -> None: self.sp.apply("age") m = self.sp.get_mutations_as_dict() assert m.get("age") == 2 - assert m.get("features.height") == "short" - assert m.get("state\\.with\\.dots.photo\\.jpeg") == "Corrupted" + assert len(m) == 1 def test_private_members(self) -> None: d = self.sp.to_dict()