Add reference resolution to RSForm operations

This commit is contained in:
IRBorisov 2023-08-21 20:20:03 +03:00
parent 9475a6718b
commit 05645a29e8
8 changed files with 154 additions and 49 deletions

View File

@ -6,14 +6,14 @@ class Graph:
''' Directed graph. '''
def __init__(self, graph: Optional[Dict[str, list[str]]]=None):
if graph is None:
self._graph = cast(Dict[str, list[str]], dict())
self._graph = cast(Dict[str, list[str]], {})
else:
self._graph = graph
def contains(self, node_id: str) -> bool:
''' Check if node is in graph. '''
return node_id in self._graph
def has_edge(self, id_from: str, id_to: str) -> bool:
''' Check if edge is in graph. '''
return self.contains(id_from) and id_to in self._graph[id_from]
@ -43,7 +43,7 @@ class Graph:
while position < len(result):
node_id = result[position]
position += 1
if (node_id not in marked):
if node_id not in marked:
marked.add(node_id)
for child_id in self._graph[node_id]:
if child_id not in marked and child_id not in result:
@ -55,19 +55,21 @@ class Graph:
result: list[str] = []
marked: set[str] = set()
for node_id in self._graph.keys():
if node_id not in marked:
to_visit: list[str] = [node_id]
while len(to_visit) > 0:
node = to_visit[-1]
if node in marked:
if node not in result:
result.append(node)
to_visit.remove(node)
else:
marked.add(node)
if len(self._graph[node]) > 0:
for child_id in self._graph[node]:
if child_id not in marked:
to_visit.append(child_id)
if node_id in marked:
continue
to_visit: list[str] = [node_id]
while len(to_visit) > 0:
node = to_visit[-1]
if node in marked:
if node not in result:
result.append(node)
to_visit.remove(node)
else:
marked.add(node)
if len(self._graph[node]) <= 0:
continue
for child_id in self._graph[node]:
if child_id not in marked:
to_visit.append(child_id)
result.reverse()
return result

View File

@ -1,6 +1,6 @@
''' Models: RSForms for conceptual schemas. '''
import json
from typing import Optional
from typing import Iterable, Optional
import pyconcept
from django.db import transaction
from django.db.models import (
@ -11,7 +11,8 @@ from django.core.validators import MinValueValidator
from django.core.exceptions import ValidationError
from django.urls import reverse
from apps.users.models import User
from cctext import Resolver, Entity
from cctext import Resolver, Entity, extract_entities
from .graph import Graph
class CstType(TextChoices):
@ -88,20 +89,34 @@ class RSForm(Model):
return result
@transaction.atomic
def on_term_change(self, alias: str):
def on_term_change(self, changed: Iterable[str]):
''' Trigger cascade resolutions when term changes. '''
pass
# void Thesaurus::OnTermChange(const EntityUID target) {
# auto expansion = TermGraph().ExpandOutputs({ target });
# const auto ordered = TermGraph().Sort(expansion);
# for (const auto entity : ordered) {
# storage.at(entity).term.UpdateFrom(Context());
# }
# expansion = DefGraph().ExpandOutputs(expansion);
# for (const auto entity : expansion) {
# storage.at(entity).definition.UpdateFrom(Context());
# }
graph_terms = self._term_graph()
expansion = graph_terms.expand_outputs(changed)
resolver = self.resolver()
if len(expansion) > 0:
for alias in graph_terms.topological_order():
if alias not in expansion:
continue
cst = self.constituents().get(alias=alias)
resolved = resolver.resolve(cst.term_raw)
if resolved == cst.term_resolved:
continue
cst.set_term_resolved(resolved)
cst.save()
resolver.context[cst.alias] = Entity(cst.alias, resolved)
graph_defs = self._definition_graph()
update_defs = set(expansion + graph_defs.expand_outputs(expansion)).union(changed)
if len(update_defs) == 0:
return
for alias in update_defs:
cst = self.constituents().get(alias=alias)
resolved = resolver.resolve(cst.definition_raw)
if resolved == cst.definition_resolved:
continue
cst.definition_resolved = resolved
cst.save()
@transaction.atomic
def insert_at(self, position: int, alias: str, insert_type: CstType) -> 'Constituenta':
@ -119,7 +134,7 @@ class RSForm(Model):
alias=alias,
cst_type=insert_type
)
self._update_from_core()
self._update_order()
self.save()
result.refresh_from_db()
return result
@ -136,7 +151,7 @@ class RSForm(Model):
alias=alias,
cst_type=insert_type
)
self._update_from_core()
self._update_order()
self.save()
result.refresh_from_db()
return result
@ -162,7 +177,7 @@ class RSForm(Model):
count_moved += 1
update_list.append(cst)
Constituenta.objects.bulk_update(update_list, ['order'])
self._update_from_core()
self._update_order()
self.save()
@transaction.atomic
@ -170,7 +185,8 @@ class RSForm(Model):
''' Delete multiple constituents. Do not check if listCst are from this schema '''
for cst in listCst:
cst.delete()
self._update_from_core()
self._update_order()
self._resolve_all_text()
self.save()
@transaction.atomic
@ -182,12 +198,15 @@ class RSForm(Model):
cst.definition_formal = data.get('definition_formal', '')
cst.term_raw = data.get('term_raw', '')
if cst.term_raw != '':
cst.term_resolved = resolver.resolve(cst.term_raw)
resolved = resolver.resolve(cst.term_raw)
cst.term_resolved = resolved
resolver.context[cst.alias] = Entity(cst.alias, resolved)
cst.definition_raw = data.get('definition_raw', '')
if cst.definition_raw != '':
cst.definition_resolved = resolver.resolve(cst.definition_raw)
cst.save()
self.on_term_change(cst.alias)
self.on_term_change([cst.alias])
cst.refresh_from_db()
return cst
def _insert_new(self, data: dict, insert_after: Optional[str]=None) -> 'Constituenta':
@ -223,7 +242,8 @@ class RSForm(Model):
if prev_cst.pk not in loaded_ids:
prev_cst.delete()
if not skip_update:
self._update_from_core()
self._update_order()
self._resolve_all_text()
self.save()
@staticmethod
@ -264,8 +284,7 @@ class RSForm(Model):
}
@transaction.atomic
def _update_from_core(self):
# TODO: resolve text refs
def _update_order(self):
checked = json.loads(pyconcept.check_schema(json.dumps(self.to_trs())))
update_list = self.constituents().only('id', 'order')
if len(checked['items']) != update_list.count():
@ -288,6 +307,44 @@ class RSForm(Model):
cst_object.save()
order += 1
def _resolve_all_text(self):
graph_terms = self._term_graph()
resolver = Resolver({})
for alias in graph_terms.topological_order():
cst = self.constituents().get(alias=alias)
resolved = resolver.resolve(cst.term_raw)
resolver.context[cst.alias] = Entity(cst.alias, resolved)
if resolved != cst.term_resolved:
cst.term_resolved = resolved
cst.save()
for cst in self.constituents():
resolved = resolver.resolve(cst.definition_raw)
if resolved != cst.definition_resolved:
cst.definition_resolved = resolved
cst.save()
def _term_graph(self) -> Graph:
result = Graph()
cst_list = self.constituents().only('order', 'alias', 'term_raw').order_by('order')
for cst in cst_list:
result.add_node(cst.alias)
for cst in cst_list:
for alias in extract_entities(cst.term_raw):
if result.contains(alias):
result.add_edge(id_from=alias, id_to=cst.alias)
return result
def _definition_graph(self) -> Graph:
result = Graph()
cst_list = self.constituents().only('order', 'alias', 'definition_raw').order_by('order')
for cst in cst_list:
result.add_node(cst.alias)
for cst in cst_list:
for alias in extract_entities(cst.definition_raw):
if result.contains(alias):
result.add_edge(id_from=alias, id_to=cst.alias)
return result
class Constituenta(Model):
''' Constituenta is the base unit for every conceptual schema '''
@ -353,10 +410,18 @@ class Constituenta(Model):
verbose_name_plural = 'Конституенты'
def get_absolute_url(self):
''' URL access. '''
return reverse('constituenta-detail', kwargs={'pk': self.pk})
def __str__(self):
return self.alias
def set_term_resolved(self, new_term: str):
''' Set term and reset forms if needed. '''
if new_term == self.term_resolved:
return
self.term_resolved = new_term
self.term_forms = []
@staticmethod
def create_from_trs(data: dict, schema: RSForm, order: int) -> 'Constituenta':

View File

@ -110,7 +110,7 @@ class ConstituentaSerializer(serializers.ModelSerializer):
term_changed = validated_data['term_resolved'] != instance.term_resolved
result: Constituenta = super().update(instance, validated_data)
if term_changed:
schema.on_term_change(result.alias)
schema.on_term_change([result.alias])
result.refresh_from_db()
schema.save()
return result

View File

@ -6,8 +6,8 @@ from apps.rsform.graph import Graph
class TestGraph(unittest.TestCase):
''' Test class for graph. '''
def test_construction(self):
''' Test graph construction methods. '''
graph = Graph()
self.assertFalse(graph.contains('X1'))
@ -27,7 +27,6 @@ class TestGraph(unittest.TestCase):
self.assertTrue(graph.has_edge('X2', 'X1'))
def test_expand_outputs(self):
''' Test Method: Graph.expand_outputs. '''
graph = Graph({
'X1': ['X2'],
'X2': ['X3', 'X5'],
@ -42,7 +41,6 @@ class TestGraph(unittest.TestCase):
self.assertEqual(graph.expand_outputs(['X2', 'X5']), ['X3', 'X6', 'X1'])
def test_topological_order(self):
''' Test Method: Graph.topological_order. '''
self.assertEqual(Graph().topological_order(), [])
graph = Graph({
'X1': [],

View File

@ -169,6 +169,24 @@ class TestRSForm(TestCase):
self.assertEqual(cst2.schema, schema)
self.assertEqual(cst1.order, 1)
def test_create_cst_resolve(self):
schema = RSForm.objects.create(title='Test')
cst1 = schema.insert_last('X1', CstType.BASE)
cst1.term_raw = '@{X2|datv}'
cst1.definition_raw = '@{X1|datv} @{X2|datv}'
cst1.save()
cst2 = schema.create_cst({
'alias': 'X2',
'cst_type': CstType.BASE,
'term_raw': 'слон',
'definition_raw': '@{X1|plur} @{X2|plur}'
})
cst1.refresh_from_db()
self.assertEqual(cst1.term_resolved, 'слону')
self.assertEqual(cst1.definition_resolved, 'слону слону')
self.assertEqual(cst2.term_resolved, 'слон')
self.assertEqual(cst2.definition_resolved, 'слонам слоны')
def test_delete_cst(self):
schema = RSForm.objects.create(title='Test')
x1 = schema.insert_last('X1', CstType.BASE)
@ -255,7 +273,7 @@ class TestRSForm(TestCase):
'"comment": "Test", "items": '
'[{"entityUID": "' + str(x2.id) + '", "cstType": "basic", "alias": "X1", "convention": "test", '
'"term": {"raw": "t1", "resolved": "t2"}, '
'"definition": {"formal": "123", "text": {"raw": "t3", "resolved": "t4"}}}]}'
'"definition": {"formal": "123", "text": {"raw": "@{X1|datv}", "resolved": "t4"}}}]}'
)
schema.load_trs(input, sync_metadata=True, skip_update=True)
x2.refresh_from_db()
@ -266,7 +284,7 @@ class TestRSForm(TestCase):
self.assertEqual(x2.alias, input['items'][0]['alias'])
self.assertEqual(x2.convention, input['items'][0]['convention'])
self.assertEqual(x2.term_raw, input['items'][0]['term']['raw'])
self.assertEqual(x2.term_resolved, input['items'][0]['term']['resolved'])
self.assertEqual(x2.term_resolved, input['items'][0]['term']['raw'])
self.assertEqual(x2.definition_formal, input['items'][0]['definition']['formal'])
self.assertEqual(x2.definition_raw, input['items'][0]['definition']['text']['raw'])
self.assertEqual(x2.definition_resolved, input['items'][0]['definition']['text']['resolved'])
self.assertEqual(x2.definition_resolved, input['items'][0]['term']['raw'])

View File

@ -5,7 +5,7 @@ from .rumodel import Morphology, SemanticRole, WordTag, morpho, split_grams, com
from .ruparser import PhraseParser, WordToken, Collation
from .reference import EntityReference, ReferenceType, SyntacticReference, parse_reference
from .context import TermForm, Entity, TermContext
from .resolver import Position, Resolver, ResolvedReference, resolve_entity, resolve_syntactic
from .resolver import Position, Resolver, ResolvedReference, resolve_entity, resolve_syntactic, extract_entities
from .conceptapi import (
parse, normalize,

View File

@ -7,6 +7,17 @@ from .conceptapi import inflect_dependant
from .context import TermContext
from .reference import EntityReference, SyntacticReference, parse_reference, Reference
_REF_ENTITY_PATTERN = re.compile(r'@{([^0-9\-].*?)\|.*?}')
def extract_entities(text: str) -> list[str]:
''' Extract list of entities that are referenced. '''
result: list[str] = []
for segment in re.finditer(_REF_ENTITY_PATTERN, text):
entity = segment.group(1)
if entity not in result:
result.append(entity)
return result
def resolve_entity(ref: EntityReference, context: TermContext) -> str:
''' Resolve entity reference. '''

View File

@ -5,9 +5,20 @@ from typing import cast
from cctext import (
EntityReference, TermContext, Entity, SyntacticReference,
Resolver, ResolvedReference, Position,
resolve_entity, resolve_syntactic
resolve_entity, resolve_syntactic, extract_entities
)
class TestUtils(unittest.TestCase):
''' Test utilitiy methods. '''
def test_extract_entities(self):
self.assertEqual(extract_entities(''), [])
self.assertEqual(extract_entities('@{-1|черны}'), [])
self.assertEqual(extract_entities('@{X1|nomn}'), ['X1'])
self.assertEqual(extract_entities('@{X1|datv}'), ['X1'])
self.assertEqual(extract_entities('@{X1|datv} @{X1|datv} @{X2|datv}'), ['X1', 'X2'])
class TestResolver(unittest.TestCase):
'''Test reference Resolver.'''
def setUp(self):