M: Implement substitution updates on cst_delete

This commit is contained in:
Ivan 2024-08-13 23:38:19 +03:00
parent cb3fd32e78
commit 5210c6b811
4 changed files with 227 additions and 71 deletions

View File

@ -7,6 +7,7 @@ from rest_framework.serializers import ValidationError
from apps.library.models import LibraryItem
from apps.rsform.graph import Graph
from apps.rsform.models import (
DELETED_ALIAS,
INSERT_LAST,
Constituenta,
CstType,
@ -21,7 +22,7 @@ from .Operation import Operation, OperationType
from .OperationSchema import OperationSchema
from .Substitution import Substitution
CstMapping = dict[str, Constituenta]
CstMapping = dict[str, Optional[Constituenta]]
CstSubstitution = list[tuple[Constituenta, Constituenta]]
@ -78,27 +79,6 @@ class ChangeManager:
operation = self.cache.get_operation(source)
self._cascade_before_substitute(substitutions, operation)
def _cascade_before_substitute(
self,
substitutions: CstSubstitution,
operation: Operation
) -> None:
children = self.cache.graph.outputs[operation.pk]
if len(children) == 0:
return
self.cache.ensure_loaded()
for child_id in children:
child_operation = self.cache.operation_by_id[child_id]
child_schema = self.cache.get_schema(child_operation)
if child_schema is None:
continue
child_schema.cache.ensure_loaded()
new_substitutions = self._transform_substitutions(substitutions, child_operation, child_schema)
if len(new_substitutions) == 0:
continue
self._cascade_before_substitute(new_substitutions, child_operation)
child_schema.substitute(new_substitutions)
def _cascade_create_cst(self, cst_list: list[Constituenta], operation: Operation, mapping: CstMapping) -> None:
children = self.cache.graph.outputs[operation.pk]
if len(children) == 0:
@ -115,7 +95,7 @@ class ChangeManager:
self.cache.ensure_loaded()
new_mapping = self._transform_mapping(mapping, child_operation, child_schema)
alias_mapping = {alias: cst.alias for alias, cst in new_mapping.items()}
alias_mapping = ChangeManager._produce_alias_mapping(new_mapping)
insert_where = self._determine_insert_position(cst_list[0], child_operation, source_schema, child_schema)
new_cst_list = child_schema.insert_copy(cst_list, insert_where, alias_mapping)
for index, cst in enumerate(new_cst_list):
@ -161,7 +141,7 @@ class ChangeManager:
child_schema = self.cache.get_schema(child_operation)
assert child_schema is not None
new_mapping = self._transform_mapping(mapping, child_operation, child_schema)
alias_mapping = {alias: cst.alias for alias, cst in new_mapping.items()}
alias_mapping = ChangeManager._produce_alias_mapping(new_mapping)
successor = child_schema.cache.by_id.get(successor_id)
if successor is None:
continue
@ -184,26 +164,74 @@ class ChangeManager:
child_schema = self.cache.get_schema(child_operation)
if child_schema is None:
continue
child_schema.cache.ensure_loaded()
# TODO: check if substitutions are affected. Undo substitutions before deletion
child_target_cst = []
child_target_ids = []
for cst in target:
successor_id = self.cache.get_inheritor(cst.pk, child_id)
if successor_id is not None:
child_target_ids.append(successor_id)
child_target_cst.append(child_schema.cache.by_id[successor_id])
self._undo_substitutions_cst(target, child_operation, child_schema)
child_target_ids = self.cache.get_inheritors_list([cst.pk for cst in target], child_id)
child_target_cst = [child_schema.cache.by_id[cst_id] for cst_id in child_target_ids]
self._cascade_before_delete(child_target_cst, child_operation)
self.cache.remove_cst(child_target_ids, child_id)
child_schema.delete_cst(child_target_cst)
if len(child_target_cst) > 0:
self.cache.remove_cst(child_target_ids, child_id)
child_schema.delete_cst(child_target_cst)
def _cascade_before_substitute(self, substitutions: CstSubstitution, operation: Operation) -> None:
children = self.cache.graph.outputs[operation.pk]
if len(children) == 0:
return
self.cache.ensure_loaded()
for child_id in children:
child_operation = self.cache.operation_by_id[child_id]
child_schema = self.cache.get_schema(child_operation)
if child_schema is None:
continue
new_substitutions = self._transform_substitutions(substitutions, child_id, child_schema)
if len(new_substitutions) == 0:
continue
self._cascade_before_substitute(new_substitutions, child_operation)
child_schema.substitute(new_substitutions)
def _cascade_partial_mapping(
self,
mapping: CstMapping,
target: list[int],
operation: Operation,
schema: RSForm
) -> None:
alias_mapping = ChangeManager._produce_alias_mapping(mapping)
schema.apply_partial_mapping(alias_mapping, target)
children = self.cache.graph.outputs[operation.pk]
if len(children) == 0:
return
self.cache.ensure_loaded()
for child_id in children:
child_operation = self.cache.operation_by_id[child_id]
child_schema = self.cache.get_schema(child_operation)
if child_schema is None:
continue
new_mapping = self._transform_mapping(mapping, child_operation, child_schema)
if not new_mapping:
continue
new_target = self.cache.get_inheritors_list(target, child_id)
if len(new_target) == 0:
continue
self._cascade_partial_mapping(new_mapping, new_target, child_operation, child_schema)
@staticmethod
def _produce_alias_mapping(mapping: CstMapping) -> dict[str, str]:
result: dict[str, str] = {}
for alias, cst in mapping.items():
if cst is None:
result[alias] = DELETED_ALIAS
else:
result[alias] = cst.alias
return result
def _transform_mapping(self, mapping: CstMapping, operation: Operation, schema: RSForm) -> CstMapping:
if len(mapping) == 0:
return mapping
result: CstMapping = {}
for alias, cst in mapping.items():
if cst is None:
result[alias] = None
continue
successor_id = self.cache.get_successor(cst.pk, operation.pk)
if successor_id is None:
continue
@ -263,36 +291,36 @@ class ChangeManager:
def _transform_substitutions(
self,
target: CstSubstitution,
operation: Operation,
operation: int,
schema: RSForm
) -> CstSubstitution:
result: CstSubstitution = []
for current_sub in target:
sub_replaced = False
new_substitution_id = self.cache.get_inheritor(current_sub[1].pk, operation.pk)
new_substitution_id = self.cache.get_inheritor(current_sub[1].pk, operation)
if new_substitution_id is None:
for sub in self.cache.substitutions[operation.pk]:
for sub in self.cache.substitutions[operation]:
if sub.original_id == current_sub[1].pk:
sub_replaced = True
new_substitution_id = self.cache.get_inheritor(sub.original_id, operation.pk)
new_substitution_id = self.cache.get_inheritor(sub.original_id, operation)
break
new_original_id = self.cache.get_inheritor(current_sub[0].pk, operation.pk)
new_original_id = self.cache.get_inheritor(current_sub[0].pk, operation)
original_replaced = False
if new_original_id is None:
for sub in self.cache.substitutions[operation.pk]:
for sub in self.cache.substitutions[operation]:
if sub.original_id == current_sub[0].pk:
original_replaced = True
sub.original_id = current_sub[1].pk
sub.save()
new_original_id = new_substitution_id
new_substitution_id = self.cache.get_inheritor(sub.substitution_id, operation.pk)
new_substitution_id = self.cache.get_inheritor(sub.substitution_id, operation)
break
if sub_replaced and original_replaced:
raise ValidationError({'propagation': 'Substitution breaks OSS substitutions.'})
for sub in self.cache.substitutions[operation.pk]:
for sub in self.cache.substitutions[operation]:
if sub.substitution_id == current_sub[0].pk:
sub.substitution_id = current_sub[1].pk
sub.save()
@ -301,6 +329,45 @@ class ChangeManager:
result.append((schema.cache.by_id[new_original_id], schema.cache.by_id[new_substitution_id]))
return result
def _undo_substitutions_cst(self, target: list[Constituenta], operation: Operation, schema: RSForm) -> None:
target_ids = [cst.pk for cst in target]
to_process = []
for sub in self.cache.substitutions[operation.pk]:
if sub.original_id in target_ids or sub.substitution_id in target_ids:
to_process.append(sub)
for sub in to_process:
self._undo_substitution(sub, schema, target_ids)
def _undo_substitution(self, target: Substitution, schema: RSForm, ignore_parents: list[int]) -> None:
operation = self.cache.operation_by_id[target.operation_id]
original_schema, _, original_cst, substitution_cst = self.cache.unfold_sub(target)
dependant = []
for cst_id in original_schema.get_dependant([original_cst.pk]):
if cst_id not in ignore_parents:
inheritor_id = self.cache.get_inheritor(cst_id, operation.pk)
if inheritor_id is not None:
dependant.append(inheritor_id)
self.cache.substitutions[operation.pk].remove(target)
target.delete()
new_original: Optional[Constituenta] = None
if original_cst.pk not in ignore_parents:
full_cst = Constituenta.objects.get(pk=original_cst.pk)
self.after_create_cst([full_cst], original_schema)
new_original_id = self.cache.get_inheritor(original_cst.pk, operation.pk)
assert new_original_id is not None
new_original = schema.cache.by_id[new_original_id]
if len(dependant) == 0:
return
substitution_id = self.cache.get_inheritor(substitution_cst.pk, operation.pk)
assert substitution_id is not None
substitution_inheritor = schema.cache.by_id[substitution_id]
mapping = {cast(str, substitution_inheritor.alias): new_original}
self._cascade_partial_mapping(mapping, dependant, operation, schema)
class OssCache:
''' Cache for OSS data. '''
@ -325,6 +392,7 @@ class OssCache:
def insert(self, schema: RSForm) -> None:
''' Insert new schema. '''
if not self._schema_by_id.get(schema.model.pk):
schema.cache.ensure_loaded()
self._insert_new(schema)
def get_schema(self, operation: Operation) -> Optional[RSForm]:
@ -367,6 +435,14 @@ class OssCache:
return item.child_id
return None
def get_inheritors_list(self, target: list[int], operation: int) -> list[int]:
''' Get child for parent inside target RSFrom. '''
result = []
for item in self.inheritance[operation]:
if item.parent_id in target:
result.append(item.child_id)
return result
def get_successor(self, parent_cst: int, operation: int) -> Optional[int]:
''' Get child for parent inside target RSFrom including substitutions. '''
for sub in self.substitutions[operation]:
@ -395,6 +471,27 @@ class OssCache:
for item in inherit_to_delete:
self.inheritance[operation].remove(item)
def unfold_sub(self, sub: Substitution) -> tuple[RSForm, RSForm, Constituenta, Constituenta]:
operation = self.operation_by_id[sub.operation_id]
parents = self.graph.inputs[operation.pk]
original_cst = None
substitution_cst = None
original_schema = None
substitution_schema = None
for parent_id in parents:
parent_schema = self.get_schema(self.operation_by_id[parent_id])
if parent_schema is None:
continue
if sub.original_id in parent_schema.cache.by_id:
original_schema = parent_schema
original_cst = original_schema.cache.by_id[sub.original_id]
if sub.substitution_id in parent_schema.cache.by_id:
substitution_schema = parent_schema
substitution_cst = substitution_schema.cache.by_id[sub.substitution_id]
if original_schema is None or substitution_schema is None or original_cst is None or substitution_cst is None:
raise ValueError(f'Parent schema for Substitution-{sub.pk} not found.')
return original_schema, substitution_schema, original_cst, substitution_cst
def _insert_new(self, schema: RSForm) -> None:
self._schemas.append(schema)
self._schema_by_id[schema.model.pk] = schema

View File

@ -13,36 +13,36 @@ def _get_oss_hosts(item: LibraryItem) -> list[LibraryItem]:
class PropagationFacade:
''' Change propagation API. '''
@classmethod
def after_create_cst(cls, new_cst: list[Constituenta], source: RSForm) -> None:
@staticmethod
def after_create_cst(new_cst: list[Constituenta], source: RSForm) -> None:
''' Trigger cascade resolutions when new constituent is created. '''
hosts = _get_oss_hosts(source.model)
for host in hosts:
ChangeManager(host).after_create_cst(new_cst, source)
@classmethod
def after_change_cst_type(cls, target: Constituenta, source: RSForm) -> None:
@staticmethod
def after_change_cst_type(target: Constituenta, source: RSForm) -> None:
''' Trigger cascade resolutions when constituenta type is changed. '''
hosts = _get_oss_hosts(source.model)
for host in hosts:
ChangeManager(host).after_change_cst_type(target, source)
@classmethod
def after_update_cst(cls, target: Constituenta, data: dict, old_data: dict, source: RSForm) -> None:
@staticmethod
def after_update_cst(target: Constituenta, data: dict, old_data: dict, source: RSForm) -> None:
''' Trigger cascade resolutions when constituenta data is changed. '''
hosts = _get_oss_hosts(source.model)
for host in hosts:
ChangeManager(host).after_update_cst(target, data, old_data, source)
@classmethod
def before_delete(cls, target: list[Constituenta], source: RSForm) -> None:
@staticmethod
def before_delete(target: list[Constituenta], source: RSForm) -> None:
''' Trigger cascade resolutions before constituents are deleted. '''
hosts = _get_oss_hosts(source.model)
for host in hosts:
ChangeManager(host).before_delete(target, source)
@classmethod
def before_substitute(cls, substitutions: list[tuple[Constituenta, Constituenta]], source: RSForm) -> None:
@staticmethod
def before_substitute(substitutions: list[tuple[Constituenta, Constituenta]], source: RSForm) -> None:
''' Trigger cascade resolutions before constituents are substituted. '''
hosts = _get_oss_hosts(source.model)
for host in hosts:

View File

@ -158,3 +158,33 @@ class TestChangeSubstitutions(EndpointTester):
self.assertEqual(self.ks4D1.definition_formal, r'X2 X1')
self.assertEqual(self.ks4D2.definition_formal, r'X1 X2 X3 X2 D1')
self.assertEqual(self.ks5D4.definition_formal, r'X1 X2 X3 X2 D1 D2 D3')
@decl_endpoint('/api/rsforms/{schema}/delete-multiple-cst', method='patch')
def test_delete_original(self):
data = {'items': [self.ks1X1.pk, self.ks1D1.pk]}
self.executeOK(data=data, schema=self.ks1.model.pk)
self.ks4D2.refresh_from_db()
self.ks5D4.refresh_from_db()
subs1_2 = self.operation4.getSubstitutions()
self.assertEqual(subs1_2.count(), 0)
subs3_4 = self.operation5.getSubstitutions()
self.assertEqual(subs3_4.count(), 1)
self.assertEqual(self.ks5.constituents().count(), 7)
self.assertEqual(self.ks4D2.definition_formal, r'X1 X2 X3 S1 DEL')
self.assertEqual(self.ks5D4.definition_formal, r'X1 X2 X3 S1 D1 DEL D3')
@decl_endpoint('/api/rsforms/{schema}/delete-multiple-cst', method='patch')
def test_delete_substitution(self):
data = {'items': [self.ks2S1.pk, self.ks2X2.pk]}
self.executeOK(data=data, schema=self.ks2.model.pk)
self.ks4D1.refresh_from_db()
self.ks4D2.refresh_from_db()
self.ks5D4.refresh_from_db()
subs1_2 = self.operation4.getSubstitutions()
self.assertEqual(subs1_2.count(), 0)
subs3_4 = self.operation5.getSubstitutions()
self.assertEqual(subs3_4.count(), 1)
self.assertEqual(self.ks5.constituents().count(), 7)
self.assertEqual(self.ks4D1.definition_formal, r'X4 X1')
self.assertEqual(self.ks4D2.definition_formal, r'X1 X2 DEL DEL D1')
self.assertEqual(self.ks5D4.definition_formal, r'X1 X2 DEL DEL D1 D2 D3')

View File

@ -107,6 +107,18 @@ class RSForm:
model = LibraryItem.objects.get(pk=pk)
return RSForm(model)
def get_dependant(self, target: Iterable[int]) -> set[int]:
''' Get list of constituents depending on target (only 1st degree). '''
result: set[int] = set()
terms = self._graph_term()
formal = self._graph_formal()
definitions = self._graph_text()
for cst_id in target:
result.update(formal.outputs[cst_id])
result.update(terms.outputs[cst_id])
result.update(definitions.outputs[cst_id])
return result
def save(self, *args, **kwargs) -> None:
''' Model wrapper. '''
self.model.save(*args, **kwargs)
@ -239,8 +251,12 @@ class RSForm:
self.save(update_fields=['time_update'])
return result
def insert_copy(self, items: list[Constituenta], position: int = INSERT_LAST,
initial_mapping: Optional[dict[str, str]] = None) -> list[Constituenta]:
def insert_copy(
self,
items: list[Constituenta],
position: int = INSERT_LAST,
initial_mapping: Optional[dict[str, str]] = None
) -> list[Constituenta]:
''' Insert copy of target constituents updating references. '''
count = len(items)
if count == 0:
@ -252,10 +268,12 @@ class RSForm:
indices: dict[str, int] = {}
for (value, _) in CstType.choices:
indices[value] = self.get_max_index(cast(CstType, value))
indices[value] = -1
mapping: dict[str, str] = initial_mapping.copy() if initial_mapping else {}
for cst in items:
if indices[cst.cst_type] == -1:
indices[cst.cst_type] = self.get_max_index(cst.cst_type)
indices[cst.cst_type] = indices[cst.cst_type] + 1
newAlias = f'{get_type_prefix(cst.cst_type)}{indices[cst.cst_type]}'
mapping[cst.alias] = newAlias
@ -382,19 +400,6 @@ class RSForm:
mapping = self._create_reset_mapping()
self.apply_mapping(mapping, change_aliases=True)
def _create_reset_mapping(self) -> dict[str, str]:
bases = cast(dict[str, int], {})
mapping = cast(dict[str, str], {})
for cst_type in CstType.values:
bases[cst_type] = 1
cst_list = self.constituents().order_by('order')
for cst in cst_list:
alias = f'{get_type_prefix(cst.cst_type)}{bases[cst.cst_type]}'
bases[cst.cst_type] += 1
if cst.alias != alias:
mapping[cst.alias] = alias
return mapping
def change_cst_type(self, target: int, new_type: CstType) -> bool:
''' Change type of constituenta generating alias automatically. '''
self.cache.ensure_loaded()
@ -419,6 +424,17 @@ class RSForm:
Constituenta.objects.bulk_update(update_list, ['alias', 'definition_formal', 'term_raw', 'definition_raw'])
self.save(update_fields=['time_update'])
def apply_partial_mapping(self, mapping: dict[str, str], target: list[int]) -> None:
''' Apply rename mapping to target constituents. '''
self.cache.ensure_loaded()
update_list: list[Constituenta] = []
for cst in self.cache.constituents:
if cst.pk in target:
if cst.apply_mapping(mapping):
update_list.append(cst)
Constituenta.objects.bulk_update(update_list, ['definition_formal', 'term_raw', 'definition_raw'])
self.save(update_fields=['time_update'])
def resolve_all_text(self) -> None:
''' Trigger reference resolution for all texts. '''
self.cache.ensure_loaded()
@ -482,6 +498,19 @@ class RSForm:
self.save(update_fields=['time_update'])
return result
def _create_reset_mapping(self) -> dict[str, str]:
bases = cast(dict[str, int], {})
mapping = cast(dict[str, str], {})
for cst_type in CstType.values:
bases[cst_type] = 1
cst_list = self.constituents().order_by('order')
for cst in cst_list:
alias = f'{get_type_prefix(cst.cst_type)}{bases[cst.cst_type]}'
bases[cst.cst_type] += 1
if cst.alias != alias:
mapping[cst.alias] = alias
return mapping
def _shift_positions(self, start: int, shift: int) -> None:
if shift == 0:
return