diff --git a/rsconcept/backend/apps/oss/models/OperationSchema.py b/rsconcept/backend/apps/oss/models/OperationSchema.py index 39eeae94..86120a06 100644 --- a/rsconcept/backend/apps/oss/models/OperationSchema.py +++ b/rsconcept/backend/apps/oss/models/OperationSchema.py @@ -131,56 +131,73 @@ class OperationSchema: self.after_create_cst(list(rsform.constituents()), rsform) self.save(update_fields=['time_update']) - def set_arguments(self, operation: Operation, arguments: list[Operation]) -> None: + def set_arguments(self, target: int, arguments: list[Operation]) -> None: ''' Set arguments to operation. ''' + self.cache.ensure_loaded() + operation = self.cache.operation_by_id[target] processed: list[Operation] = [] - changed = False + deleted: list[Argument] = [] for current in operation.getArguments(): if current.argument not in arguments: - changed = True - current.delete() + deleted.append(current) else: processed.append(current.argument) + if len(deleted) > 0: + self.before_delete_arguments(operation, [x.argument for x in deleted]) + for deleted_arg in deleted: + self.cache.remove_argument(deleted_arg) + Argument.objects.filter(pk__in=[x.pk for x in deleted]).delete() + + added: list[Operation] = [] for arg in arguments: if arg not in processed: - changed = True processed.append(arg) - Argument.objects.create(operation=operation, argument=arg) - if not changed: - return - # TODO: trigger on_change effects - self.save(update_fields=['time_update']) + new_arg = Argument.objects.create(operation=operation, argument=arg) + self.cache.insert_argument(new_arg) + added.append(arg) + if len(added) > 0: + self.after_create_arguments(operation, added) + if len(added) > 0 or len(deleted) > 0: + self.save(update_fields=['time_update']) - def set_substitutions(self, target: Operation, substitutes: list[dict]) -> None: + def set_substitutions(self, target: int, substitutes: list[dict]) -> None: ''' Clear all arguments for operation. ''' + self.cache.ensure_loaded() + operation = self.cache.operation_by_id[target] + schema = self.cache.get_schema(operation) processed: list[dict] = [] - changed = False - - for current in target.getSubstitutions(): + deleted: list[Substitution] = [] + for current in operation.getSubstitutions(): subs = [ x for x in substitutes if x['original'] == current.original and x['substitution'] == current.substitution ] if len(subs) == 0: - changed = True - current.delete() + deleted.append(current) else: processed.append(subs[0]) + if len(deleted) > 0: + if schema is not None: + for sub in deleted: + self._undo_substitution(sub, schema) + else: + for sub in deleted: + self.cache.remove_substitution(sub) + Substitution.objects.filter(pk__in=[x.pk for x in deleted]).delete() - for sub in substitutes: - if sub not in processed: - changed = True - Substitution.objects.create( - operation=target, - original=sub['original'], - substitution=sub['substitution'] + added: list[Substitution] = [] + for sub_item in substitutes: + if sub_item not in processed: + new_sub = Substitution.objects.create( + operation=operation, + original=sub_item['original'], + substitution=sub_item['substitution'] ) + added.append(new_sub) + self._process_added_substitutions(added, schema) - if not changed: - return - # TODO: trigger on_change effects - - self.save(update_fields=['time_update']) + if len(added) > 0 or len(deleted) > 0: + self.save(update_fields=['time_update']) def create_input(self, operation: Operation) -> RSForm: ''' Create input RSForm. ''' @@ -242,7 +259,7 @@ class OperationSchema: def after_create_cst(self, cst_list: list[Constituenta], source: RSForm) -> None: ''' Trigger cascade resolutions when new constituent is created. ''' - self.cache.insert(source) + self.cache.insert_schema(source) inserted_aliases = [cst.alias for cst in cst_list] depend_aliases: set[str] = set() for new_cst in cst_list: @@ -258,13 +275,13 @@ class OperationSchema: def after_change_cst_type(self, target: Constituenta, source: RSForm) -> None: ''' Trigger cascade resolutions when constituenta type is changed. ''' - self.cache.insert(source) + self.cache.insert_schema(source) operation = self.cache.get_operation(source.model.pk) self._cascade_change_cst_type(target.pk, target.cst_type, operation.pk) def after_update_cst(self, target: Constituenta, data: dict, old_data: dict, source: RSForm) -> None: ''' Trigger cascade resolutions when constituenta data is changed. ''' - self.cache.insert(source) + self.cache.insert_schema(source) operation = self.cache.get_operation(source.model.pk) depend_aliases = self._extract_data_references(data, old_data) alias_mapping: CstMapping = {} @@ -282,16 +299,26 @@ class OperationSchema: def before_delete_cst(self, target: list[Constituenta], source: RSForm) -> None: ''' Trigger cascade resolutions before constituents are deleted. ''' - self.cache.insert(source) + self.cache.insert_schema(source) operation = self.cache.get_operation(source.model.pk) self._cascade_before_delete(target, operation.pk) def before_substitute(self, substitutions: CstSubstitution, source: RSForm) -> None: ''' Trigger cascade resolutions before constituents are substituted. ''' - self.cache.insert(source) + self.cache.insert_schema(source) operation = self.cache.get_operation(source.model.pk) self._cascade_before_substitute(substitutions, operation) + def before_delete_arguments(self, target: Operation, arguments: list[Operation]) -> None: + ''' Trigger cascade resolutions before arguments are deleted. ''' + pass + + def after_create_arguments(self, target: Operation, arguments: list[Operation]) -> None: + ''' Trigger cascade resolutions after arguments are created. ''' + schema = self.cache.get_schema(target) + if schema is None: + return + def _cascade_create_cst(self, cst_list: list[Constituenta], operation: Operation, mapping: CstMapping) -> None: children = self.cache.graph.outputs[operation.pk] if len(children) == 0: @@ -411,12 +438,12 @@ class OperationSchema: self, mapping: CstMapping, target: list[int], - operation: Operation, + operation: int, schema: RSForm ) -> None: alias_mapping = OperationSchema._produce_alias_mapping(mapping) schema.apply_partial_mapping(alias_mapping, target) - children = self.cache.graph.outputs[operation.pk] + children = self.cache.graph.outputs[operation] if len(children) == 0: return self.cache.ensure_loaded() @@ -431,7 +458,7 @@ class OperationSchema: new_target = self.cache.get_inheritors_list(target, child_id) if len(new_target) == 0: continue - self._cascade_partial_mapping(new_mapping, new_target, child_operation, child_schema) + self._cascade_partial_mapping(new_mapping, new_target, child_id, child_schema) @staticmethod def _produce_alias_mapping(mapping: CstMapping) -> dict[str, str]: @@ -557,35 +584,64 @@ class OperationSchema: for sub in to_process: self._undo_substitution(sub, schema, target_ids) - def _undo_substitution(self, target: Substitution, schema: RSForm, ignore_parents: list[int]) -> None: - operation = self.cache.operation_by_id[target.operation_id] + def _undo_substitution( + self, + target: Substitution, + schema: RSForm, + ignore_parents: Optional[list[int]] = None + ) -> None: + if ignore_parents is None: + ignore_parents = [] + operation_id = target.operation_id original_schema, _, original_cst, substitution_cst = self.cache.unfold_sub(target) dependant = [] for cst_id in original_schema.get_dependant([original_cst.pk]): if cst_id not in ignore_parents: - inheritor_id = self.cache.get_inheritor(cst_id, operation.pk) + inheritor_id = self.cache.get_inheritor(cst_id, operation_id) if inheritor_id is not None: dependant.append(inheritor_id) - self.cache.substitutions[operation.pk].remove(target) + self.cache.substitutions[operation_id].remove(target) target.delete() new_original: Optional[Constituenta] = None if original_cst.pk not in ignore_parents: full_cst = Constituenta.objects.get(pk=original_cst.pk) self.after_create_cst([full_cst], original_schema) - new_original_id = self.cache.get_inheritor(original_cst.pk, operation.pk) + new_original_id = self.cache.get_inheritor(original_cst.pk, operation_id) assert new_original_id is not None new_original = schema.cache.by_id[new_original_id] if len(dependant) == 0: return - substitution_id = self.cache.get_inheritor(substitution_cst.pk, operation.pk) + substitution_id = self.cache.get_inheritor(substitution_cst.pk, operation_id) assert substitution_id is not None substitution_inheritor = schema.cache.by_id[substitution_id] mapping = {cast(str, substitution_inheritor.alias): new_original} - self._cascade_partial_mapping(mapping, dependant, operation, schema) + self._cascade_partial_mapping(mapping, dependant, operation_id, schema) + + def _process_added_substitutions(self, added: list[Substitution], schema: Optional[RSForm]) -> None: + if len(added) == 0: + return + if schema is None: + for sub in added: + self.cache.insert_substitution(sub) + return + + cst_mapping: CstSubstitution = [] + for sub in added: + original_id = self.cache.get_inheritor(sub.original_id, sub.operation_id) + substitution_id = self.cache.get_inheritor(sub.substitution_id, sub.operation_id) + if original_id is None or substitution_id is None: + raise ValueError('Substitutions not found.') + original_cst = schema.cache.by_id[original_id] + substitution_cst = schema.cache.by_id[substitution_id] + cst_mapping.append((original_cst, substitution_cst)) + self.before_substitute(cst_mapping, schema) + schema.substitute(cst_mapping) + for sub in added: + self.cache.insert_substitution(sub) class OssCache: @@ -608,11 +664,18 @@ class OssCache: self.substitutions: dict[int, list[Substitution]] = {} self.inheritance: dict[int, list[Inheritance]] = {} - def insert(self, schema: RSForm) -> None: - ''' Insert new schema. ''' - if not self._schema_by_id.get(schema.model.pk): - schema.cache.ensure_loaded() - self._insert_new(schema) + def ensure_loaded(self) -> None: + ''' Ensure cache is fully loaded. ''' + if self.is_loaded: + return + self.is_loaded = True + for operation in self.operations: + self.inheritance[operation.pk] = [] + self.substitutions[operation.pk] = [] + for sub in self._oss.substitutions().only('operation_id', 'original_id', 'substitution_id'): + self.substitutions[sub.operation_id].append(sub) + for item in self._oss.inheritance().only('operation_id', 'parent_id', 'child_id'): + self.inheritance[item.operation_id].append(item) def get_schema(self, operation: Operation) -> Optional[RSForm]: ''' Get schema by Operation. ''' @@ -633,19 +696,6 @@ class OssCache: return operation raise ValueError(f'Operation for schema {schema} not found') - def ensure_loaded(self) -> None: - ''' Ensure cache is fully loaded. ''' - if self.is_loaded: - return - self.is_loaded = True - for operation in self.operations: - self.inheritance[operation.pk] = [] - self.substitutions[operation.pk] = [] - for sub in self._oss.substitutions().only('operation_id', 'original_id', 'substitution_id'): - self.substitutions[sub.operation_id].append(sub) - for item in self._oss.inheritance().only('operation_id', 'parent_id', 'child_id'): - self.inheritance[item.operation_id].append(item) - def get_inheritor(self, parent_cst: int, operation: int) -> Optional[int]: ''' Get child for parent inside target RSFrom. ''' for item in self.inheritance[operation]: @@ -668,6 +718,12 @@ class OssCache: return self.get_inheritor(sub.substitution_id, operation) return self.get_inheritor(parent_cst, operation) + def insert_schema(self, schema: RSForm) -> None: + ''' Insert new schema. ''' + if not self._schema_by_id.get(schema.model.pk): + schema.cache.ensure_loaded() + self._insert_new(schema) + def insert_operation(self, operation: Operation) -> None: ''' Insert new operation. ''' self.operations.append(operation) @@ -677,6 +733,10 @@ class OssCache: self.substitutions[operation.pk] = [] self.inheritance[operation.pk] = [] + def insert_argument(self, argument: Argument) -> None: + ''' Insert new argument. ''' + self.graph.add_edge(argument.operation_id, argument.argument_id) + def insert_inheritance(self, inheritance: Inheritance) -> None: ''' Insert new inheritance. ''' self.inheritance[inheritance.operation_id].append(inheritance) @@ -697,9 +757,15 @@ class OssCache: for item in inherit_to_delete: self.inheritance[operation].remove(item) + def remove_schema(self, schema: RSForm) -> None: + ''' Remove schema from cache. ''' + self._schemas.remove(schema) + del self._schema_by_id[schema.model.pk] + def remove_operation(self, operation: int) -> None: ''' Remove operation from cache. ''' target = self.operation_by_id[operation] + self.graph.remove_node(operation) if target.result_id in self._schema_by_id: self._schemas.remove(self._schema_by_id[target.result_id]) del self._schema_by_id[target.result_id] @@ -709,10 +775,13 @@ class OssCache: del self.substitutions[operation] del self.inheritance[operation] - def remove_schema(self, schema: RSForm) -> None: - ''' Remove schema from cache. ''' - self._schemas.remove(schema) - del self._schema_by_id[schema.model.pk] + def remove_argument(self, argument: Argument) -> None: + ''' Remove argument from cache. ''' + self.graph.remove_edge(argument.operation_id, argument.argument_id) + + def remove_substitution(self, target: Substitution) -> None: + ''' Remove substitution from cache. ''' + self.substitutions[target.operation_id].remove(target) def unfold_sub(self, sub: Substitution) -> tuple[RSForm, RSForm, Constituenta, Constituenta]: operation = self.operation_by_id[sub.operation_id] diff --git a/rsconcept/backend/apps/oss/tests/s_propagation/t_constituents.py b/rsconcept/backend/apps/oss/tests/s_propagation/t_constituents.py index 0d9ec899..a7b9f559 100644 --- a/rsconcept/backend/apps/oss/tests/s_propagation/t_constituents.py +++ b/rsconcept/backend/apps/oss/tests/s_propagation/t_constituents.py @@ -51,7 +51,7 @@ class TestChangeConstituents(EndpointTester): alias='3', operation_type=OperationType.SYNTHESIS ) - self.owned.set_arguments(self.operation3, [self.operation1, self.operation2]) + self.owned.set_arguments(self.operation3.pk, [self.operation1, self.operation2]) self.owned.execute_operation(self.operation3) self.operation3.refresh_from_db() self.ks3 = RSForm(self.operation3.result) diff --git a/rsconcept/backend/apps/oss/tests/s_propagation/t_operations.py b/rsconcept/backend/apps/oss/tests/s_propagation/t_operations.py index 987030b3..b95c5319 100644 --- a/rsconcept/backend/apps/oss/tests/s_propagation/t_operations.py +++ b/rsconcept/backend/apps/oss/tests/s_propagation/t_operations.py @@ -71,8 +71,8 @@ class TestChangeOperations(EndpointTester): alias='4', operation_type=OperationType.SYNTHESIS ) - self.owned.set_arguments(self.operation4, [self.operation1, self.operation2]) - self.owned.set_substitutions(self.operation4, [{ + self.owned.set_arguments(self.operation4.pk, [self.operation1, self.operation2]) + self.owned.set_substitutions(self.operation4.pk, [{ 'original': self.ks1X1, 'substitution': self.ks2S1 }]) @@ -92,8 +92,8 @@ class TestChangeOperations(EndpointTester): alias='5', operation_type=OperationType.SYNTHESIS ) - self.owned.set_arguments(self.operation5, [self.operation4, self.operation3]) - self.owned.set_substitutions(self.operation5, [{ + self.owned.set_arguments(self.operation5.pk, [self.operation4, self.operation3]) + self.owned.set_substitutions(self.operation5.pk, [{ 'original': self.ks4X1, 'substitution': self.ks3X1 }]) @@ -249,3 +249,38 @@ class TestChangeOperations(EndpointTester): self.assertEqual(self.ks5.constituents().count(), 8) self.assertEqual(self.ks4D2.definition_formal, r'X1 X2 X3 S1 D1') self.assertEqual(self.ks5D4.definition_formal, r'X1 X2 X3 S1 D1 D2 D3') + + @decl_endpoint('/api/oss/{item}/update-operation', method='patch') + def test_change_substitutions(self): + data = { + 'target': self.operation4.pk, + 'item_data': { + 'alias': 'Test4 mod', + 'title': 'Test title mod', + 'comment': 'Comment mod' + }, + 'positions': [], + 'arguments': [self.operation1.pk, self.operation2.pk], + 'substitutions': [ + { + 'original': self.ks1X1.pk, + 'substitution': self.ks2X2.pk + }, + { + 'original': self.ks2X1.pk, + 'substitution': self.ks1D1.pk + } + ] + } + + self.executeOK(data=data, item=self.owned_id) + self.ks4D2.refresh_from_db() + self.ks5D4.refresh_from_db() + subs1_2 = self.operation4.getSubstitutions() + self.assertEqual(subs1_2.count(), 2) + subs3_4 = self.operation5.getSubstitutions() + self.assertEqual(subs3_4.count(), 1) + self.assertEqual(self.ks4.constituents().count(), 5) + self.assertEqual(self.ks5.constituents().count(), 7) + self.assertEqual(self.ks4D2.definition_formal, r'X1 D1 X3 S1 D1') + self.assertEqual(self.ks5D4.definition_formal, r'X1 D2 X3 S1 D1 D2 D3') diff --git a/rsconcept/backend/apps/oss/tests/s_propagation/t_substitutions.py b/rsconcept/backend/apps/oss/tests/s_propagation/t_substitutions.py index 7cc5ae26..1a383a42 100644 --- a/rsconcept/backend/apps/oss/tests/s_propagation/t_substitutions.py +++ b/rsconcept/backend/apps/oss/tests/s_propagation/t_substitutions.py @@ -71,8 +71,8 @@ class TestChangeSubstitutions(EndpointTester): alias='4', operation_type=OperationType.SYNTHESIS ) - self.owned.set_arguments(self.operation4, [self.operation1, self.operation2]) - self.owned.set_substitutions(self.operation4, [{ + self.owned.set_arguments(self.operation4.pk, [self.operation1, self.operation2]) + self.owned.set_substitutions(self.operation4.pk, [{ 'original': self.ks1X1, 'substitution': self.ks2S1 }]) @@ -92,8 +92,8 @@ class TestChangeSubstitutions(EndpointTester): alias='5', operation_type=OperationType.SYNTHESIS ) - self.owned.set_arguments(self.operation5, [self.operation4, self.operation3]) - self.owned.set_substitutions(self.operation5, [{ + self.owned.set_arguments(self.operation5.pk, [self.operation4, self.operation3]) + self.owned.set_substitutions(self.operation5.pk, [{ 'original': self.ks4X1, 'substitution': self.ks3X1 }]) diff --git a/rsconcept/backend/apps/oss/tests/s_views/t_oss.py b/rsconcept/backend/apps/oss/tests/s_views/t_oss.py index 6937d9bb..ba52d035 100644 --- a/rsconcept/backend/apps/oss/tests/s_views/t_oss.py +++ b/rsconcept/backend/apps/oss/tests/s_views/t_oss.py @@ -55,8 +55,8 @@ class TestOssViewset(EndpointTester): alias='3', operation_type=OperationType.SYNTHESIS ) - self.owned.set_arguments(self.operation3, [self.operation1, self.operation2]) - self.owned.set_substitutions(self.operation3, [{ + self.owned.set_arguments(self.operation3.pk, [self.operation1, self.operation2]) + self.owned.set_substitutions(self.operation3.pk, [{ 'original': self.ks1X1, 'substitution': self.ks2X1 }]) diff --git a/rsconcept/backend/apps/oss/views/oss.py b/rsconcept/backend/apps/oss/views/oss.py index 6a4a4dc5..a80d8de9 100644 --- a/rsconcept/backend/apps/oss/views/oss.py +++ b/rsconcept/backend/apps/oss/views/oss.py @@ -129,7 +129,7 @@ class OssViewSet(viewsets.GenericViewSet, generics.ListAPIView, generics.Retriev oss.create_input(new_operation) if new_operation.operation_type != m.OperationType.INPUT and 'arguments' in serializer.validated_data: oss.set_arguments( - operation=new_operation, + target=new_operation.pk, arguments=serializer.validated_data['arguments'] ) return Response( @@ -305,9 +305,9 @@ class OssViewSet(viewsets.GenericViewSet, generics.ListAPIView, generics.Retriev operation.result.comment = operation.comment operation.result.save() if 'arguments' in serializer.validated_data: - oss.set_arguments(operation, serializer.validated_data['arguments']) + oss.set_arguments(operation.pk, serializer.validated_data['arguments']) if 'substitutions' in serializer.validated_data: - oss.set_substitutions(operation, serializer.validated_data['substitutions']) + oss.set_substitutions(operation.pk, serializer.validated_data['substitutions']) return Response( status=c.HTTP_200_OK, data=s.OperationSchemaSerializer(oss.model).data diff --git a/rsconcept/backend/apps/rsform/graph.py b/rsconcept/backend/apps/rsform/graph.py index b6a749ad..bdb6ce6e 100644 --- a/rsconcept/backend/apps/rsform/graph.py +++ b/rsconcept/backend/apps/rsform/graph.py @@ -42,6 +42,28 @@ class Graph(Generic[ItemType]): if src not in self.inputs[dest]: self.inputs[dest].append(src) + def remove_edge(self, src: ItemType, dest: ItemType): + ''' Remove edge from graph. ''' + if not self.contains(src) or not self.contains(dest): + return + if dest in self.outputs[src]: + self.outputs[src].remove(dest) + if src in self.inputs[dest]: + self.inputs[dest].remove(src) + + def remove_node(self, target: ItemType): + ''' Remove node from graph. ''' + if not self.contains(target): + return + del self.outputs[target] + del self.inputs[target] + for list_out in self.outputs.values(): + if target in list_out: + list_out.remove(target) + for list_in in self.inputs.values(): + if target in list_in: + list_in.remove(target) + def expand_inputs(self, origin: Iterable[ItemType]) -> list[ItemType]: ''' Expand origin nodes forward through graph edges. ''' result: list[ItemType] = [] diff --git a/rsconcept/backend/apps/rsform/tests/t_graph.py b/rsconcept/backend/apps/rsform/tests/t_graph.py index 63713f95..3c99462a 100644 --- a/rsconcept/backend/apps/rsform/tests/t_graph.py +++ b/rsconcept/backend/apps/rsform/tests/t_graph.py @@ -26,6 +26,32 @@ class TestGraph(unittest.TestCase): self.assertTrue(graph.has_edge(1, 3)) self.assertTrue(graph.has_edge(2, 1)) + def test_remove_node(self): + graph = Graph({ + 1: [2], + 2: [3, 5], + 3: [], + 5: [] + }) + self.assertEqual(len(graph.outputs), 4) + graph.remove_node(0) + graph.remove_node(2) + self.assertEqual(graph.outputs[1], []) + self.assertEqual(len(graph.outputs), 3) + + def test_remove_edge(self): + graph = Graph({ + 1: [2], + 2: [3, 5], + 3: [], + 5: [] + }) + graph.remove_edge(0, 1) + graph.remove_edge(2, 1) + self.assertEqual(graph.outputs[1], [2]) + graph.remove_edge(1, 2) + self.assertEqual(graph.outputs[1], []) + graph.remove_edge(1, 2) def test_expand_outputs(self): graph = Graph({