Add reference resolution to RSForm operations

This commit is contained in:
IRBorisov 2023-08-21 20:20:03 +03:00
parent 9475a6718b
commit 05645a29e8
8 changed files with 154 additions and 49 deletions

View File

@ -6,7 +6,7 @@ class Graph:
''' Directed graph. ''' ''' Directed graph. '''
def __init__(self, graph: Optional[Dict[str, list[str]]]=None): def __init__(self, graph: Optional[Dict[str, list[str]]]=None):
if graph is None: if graph is None:
self._graph = cast(Dict[str, list[str]], dict()) self._graph = cast(Dict[str, list[str]], {})
else: else:
self._graph = graph self._graph = graph
@ -43,7 +43,7 @@ class Graph:
while position < len(result): while position < len(result):
node_id = result[position] node_id = result[position]
position += 1 position += 1
if (node_id not in marked): if node_id not in marked:
marked.add(node_id) marked.add(node_id)
for child_id in self._graph[node_id]: for child_id in self._graph[node_id]:
if child_id not in marked and child_id not in result: if child_id not in marked and child_id not in result:
@ -55,19 +55,21 @@ class Graph:
result: list[str] = [] result: list[str] = []
marked: set[str] = set() marked: set[str] = set()
for node_id in self._graph.keys(): for node_id in self._graph.keys():
if node_id not in marked: if node_id in marked:
to_visit: list[str] = [node_id] continue
while len(to_visit) > 0: to_visit: list[str] = [node_id]
node = to_visit[-1] while len(to_visit) > 0:
if node in marked: node = to_visit[-1]
if node not in result: if node in marked:
result.append(node) if node not in result:
to_visit.remove(node) result.append(node)
else: to_visit.remove(node)
marked.add(node) else:
if len(self._graph[node]) > 0: marked.add(node)
for child_id in self._graph[node]: if len(self._graph[node]) <= 0:
if child_id not in marked: continue
to_visit.append(child_id) for child_id in self._graph[node]:
if child_id not in marked:
to_visit.append(child_id)
result.reverse() result.reverse()
return result return result

View File

@ -1,6 +1,6 @@
''' Models: RSForms for conceptual schemas. ''' ''' Models: RSForms for conceptual schemas. '''
import json import json
from typing import Optional from typing import Iterable, Optional
import pyconcept import pyconcept
from django.db import transaction from django.db import transaction
from django.db.models import ( from django.db.models import (
@ -11,7 +11,8 @@ from django.core.validators import MinValueValidator
from django.core.exceptions import ValidationError from django.core.exceptions import ValidationError
from django.urls import reverse from django.urls import reverse
from apps.users.models import User from apps.users.models import User
from cctext import Resolver, Entity from cctext import Resolver, Entity, extract_entities
from .graph import Graph
class CstType(TextChoices): class CstType(TextChoices):
@ -88,20 +89,34 @@ class RSForm(Model):
return result return result
@transaction.atomic @transaction.atomic
def on_term_change(self, alias: str): def on_term_change(self, changed: Iterable[str]):
''' Trigger cascade resolutions when term changes. ''' ''' Trigger cascade resolutions when term changes. '''
pass graph_terms = self._term_graph()
# void Thesaurus::OnTermChange(const EntityUID target) { expansion = graph_terms.expand_outputs(changed)
# auto expansion = TermGraph().ExpandOutputs({ target }); resolver = self.resolver()
# const auto ordered = TermGraph().Sort(expansion); if len(expansion) > 0:
# for (const auto entity : ordered) { for alias in graph_terms.topological_order():
# storage.at(entity).term.UpdateFrom(Context()); if alias not in expansion:
# } continue
# expansion = DefGraph().ExpandOutputs(expansion); cst = self.constituents().get(alias=alias)
# for (const auto entity : expansion) { resolved = resolver.resolve(cst.term_raw)
# storage.at(entity).definition.UpdateFrom(Context()); if resolved == cst.term_resolved:
# } continue
cst.set_term_resolved(resolved)
cst.save()
resolver.context[cst.alias] = Entity(cst.alias, resolved)
graph_defs = self._definition_graph()
update_defs = set(expansion + graph_defs.expand_outputs(expansion)).union(changed)
if len(update_defs) == 0:
return
for alias in update_defs:
cst = self.constituents().get(alias=alias)
resolved = resolver.resolve(cst.definition_raw)
if resolved == cst.definition_resolved:
continue
cst.definition_resolved = resolved
cst.save()
@transaction.atomic @transaction.atomic
def insert_at(self, position: int, alias: str, insert_type: CstType) -> 'Constituenta': def insert_at(self, position: int, alias: str, insert_type: CstType) -> 'Constituenta':
@ -119,7 +134,7 @@ class RSForm(Model):
alias=alias, alias=alias,
cst_type=insert_type cst_type=insert_type
) )
self._update_from_core() self._update_order()
self.save() self.save()
result.refresh_from_db() result.refresh_from_db()
return result return result
@ -136,7 +151,7 @@ class RSForm(Model):
alias=alias, alias=alias,
cst_type=insert_type cst_type=insert_type
) )
self._update_from_core() self._update_order()
self.save() self.save()
result.refresh_from_db() result.refresh_from_db()
return result return result
@ -162,7 +177,7 @@ class RSForm(Model):
count_moved += 1 count_moved += 1
update_list.append(cst) update_list.append(cst)
Constituenta.objects.bulk_update(update_list, ['order']) Constituenta.objects.bulk_update(update_list, ['order'])
self._update_from_core() self._update_order()
self.save() self.save()
@transaction.atomic @transaction.atomic
@ -170,7 +185,8 @@ class RSForm(Model):
''' Delete multiple constituents. Do not check if listCst are from this schema ''' ''' Delete multiple constituents. Do not check if listCst are from this schema '''
for cst in listCst: for cst in listCst:
cst.delete() cst.delete()
self._update_from_core() self._update_order()
self._resolve_all_text()
self.save() self.save()
@transaction.atomic @transaction.atomic
@ -182,12 +198,15 @@ class RSForm(Model):
cst.definition_formal = data.get('definition_formal', '') cst.definition_formal = data.get('definition_formal', '')
cst.term_raw = data.get('term_raw', '') cst.term_raw = data.get('term_raw', '')
if cst.term_raw != '': if cst.term_raw != '':
cst.term_resolved = resolver.resolve(cst.term_raw) resolved = resolver.resolve(cst.term_raw)
cst.term_resolved = resolved
resolver.context[cst.alias] = Entity(cst.alias, resolved)
cst.definition_raw = data.get('definition_raw', '') cst.definition_raw = data.get('definition_raw', '')
if cst.definition_raw != '': if cst.definition_raw != '':
cst.definition_resolved = resolver.resolve(cst.definition_raw) cst.definition_resolved = resolver.resolve(cst.definition_raw)
cst.save() cst.save()
self.on_term_change(cst.alias) self.on_term_change([cst.alias])
cst.refresh_from_db()
return cst return cst
def _insert_new(self, data: dict, insert_after: Optional[str]=None) -> 'Constituenta': def _insert_new(self, data: dict, insert_after: Optional[str]=None) -> 'Constituenta':
@ -223,7 +242,8 @@ class RSForm(Model):
if prev_cst.pk not in loaded_ids: if prev_cst.pk not in loaded_ids:
prev_cst.delete() prev_cst.delete()
if not skip_update: if not skip_update:
self._update_from_core() self._update_order()
self._resolve_all_text()
self.save() self.save()
@staticmethod @staticmethod
@ -264,8 +284,7 @@ class RSForm(Model):
} }
@transaction.atomic @transaction.atomic
def _update_from_core(self): def _update_order(self):
# TODO: resolve text refs
checked = json.loads(pyconcept.check_schema(json.dumps(self.to_trs()))) checked = json.loads(pyconcept.check_schema(json.dumps(self.to_trs())))
update_list = self.constituents().only('id', 'order') update_list = self.constituents().only('id', 'order')
if len(checked['items']) != update_list.count(): if len(checked['items']) != update_list.count():
@ -288,6 +307,44 @@ class RSForm(Model):
cst_object.save() cst_object.save()
order += 1 order += 1
def _resolve_all_text(self):
    ''' Resolve term and definition texts of all constituents starting from an empty context. '''
    resolver = Resolver({})

    # Terms are processed in the order produced by the term graph, and every resolved term
    # is registered in the resolver context before the next term is resolved.
    for alias in self._term_graph().topological_order():
        item = self.constituents().get(alias=alias)
        term = resolver.resolve(item.term_raw)
        resolver.context[item.alias] = Entity(item.alias, term)
        if term == item.term_resolved:
            continue  # stored value is already up to date - skip the write
        item.term_resolved = term
        item.save()

    # Definitions are resolved afterwards, once the context holds every term.
    for item in self.constituents():
        text = resolver.resolve(item.definition_raw)
        if text == item.definition_resolved:
            continue
        item.definition_resolved = text
        item.save()
def _term_graph(self) -> Graph:
    ''' Build graph of references between terms: edge goes from referenced alias to referencing one. '''
    items = self.constituents().only('order', 'alias', 'term_raw').order_by('order')
    graph = Graph()
    # Register all nodes first so that edges are only created between known aliases.
    for item in items:
        graph.add_node(item.alias)
    for item in items:
        for referenced in extract_entities(item.term_raw):
            if not graph.contains(referenced):
                continue  # reference to an alias that is not part of this schema
            graph.add_edge(id_from=referenced, id_to=item.alias)
    return graph
def _definition_graph(self) -> Graph:
    ''' Build graph of references in text definitions: edge goes from referenced alias to referencing one. '''
    items = self.constituents().only('order', 'alias', 'definition_raw').order_by('order')
    graph = Graph()
    # Register all nodes first so that edges are only created between known aliases.
    for item in items:
        graph.add_node(item.alias)
    for item in items:
        for referenced in extract_entities(item.definition_raw):
            if not graph.contains(referenced):
                continue  # reference to an alias that is not part of this schema
            graph.add_edge(id_from=referenced, id_to=item.alias)
    return graph
class Constituenta(Model): class Constituenta(Model):
''' Constituenta is the base unit for every conceptual schema ''' ''' Constituenta is the base unit for every conceptual schema '''
@ -353,11 +410,19 @@ class Constituenta(Model):
verbose_name_plural = 'Конституенты' verbose_name_plural = 'Конституенты'
def get_absolute_url(self): def get_absolute_url(self):
''' URL access. '''
return reverse('constituenta-detail', kwargs={'pk': self.pk}) return reverse('constituenta-detail', kwargs={'pk': self.pk})
def __str__(self): def __str__(self):
return self.alias return self.alias
def set_term_resolved(self, new_term: str):
    ''' Store new resolved term; term forms are cleared only when the term actually changes. '''
    if new_term != self.term_resolved:
        self.term_resolved = new_term
        self.term_forms = []
@staticmethod @staticmethod
def create_from_trs(data: dict, schema: RSForm, order: int) -> 'Constituenta': def create_from_trs(data: dict, schema: RSForm, order: int) -> 'Constituenta':
''' Create constituenta from TRS json ''' ''' Create constituenta from TRS json '''

View File

@ -110,7 +110,7 @@ class ConstituentaSerializer(serializers.ModelSerializer):
term_changed = validated_data['term_resolved'] != instance.term_resolved term_changed = validated_data['term_resolved'] != instance.term_resolved
result: Constituenta = super().update(instance, validated_data) result: Constituenta = super().update(instance, validated_data)
if term_changed: if term_changed:
schema.on_term_change(result.alias) schema.on_term_change([result.alias])
result.refresh_from_db() result.refresh_from_db()
schema.save() schema.save()
return result return result

View File

@ -6,8 +6,8 @@ from apps.rsform.graph import Graph
class TestGraph(unittest.TestCase): class TestGraph(unittest.TestCase):
''' Test class for graph. ''' ''' Test class for graph. '''
def test_construction(self): def test_construction(self):
''' Test graph construction methods. '''
graph = Graph() graph = Graph()
self.assertFalse(graph.contains('X1')) self.assertFalse(graph.contains('X1'))
@ -27,7 +27,6 @@ class TestGraph(unittest.TestCase):
self.assertTrue(graph.has_edge('X2', 'X1')) self.assertTrue(graph.has_edge('X2', 'X1'))
def test_expand_outputs(self): def test_expand_outputs(self):
''' Test Method: Graph.expand_outputs. '''
graph = Graph({ graph = Graph({
'X1': ['X2'], 'X1': ['X2'],
'X2': ['X3', 'X5'], 'X2': ['X3', 'X5'],
@ -42,7 +41,6 @@ class TestGraph(unittest.TestCase):
self.assertEqual(graph.expand_outputs(['X2', 'X5']), ['X3', 'X6', 'X1']) self.assertEqual(graph.expand_outputs(['X2', 'X5']), ['X3', 'X6', 'X1'])
def test_topological_order(self): def test_topological_order(self):
''' Test Method: Graph.topological_order. '''
self.assertEqual(Graph().topological_order(), []) self.assertEqual(Graph().topological_order(), [])
graph = Graph({ graph = Graph({
'X1': [], 'X1': [],

View File

@ -169,6 +169,24 @@ class TestRSForm(TestCase):
self.assertEqual(cst2.schema, schema) self.assertEqual(cst2.schema, schema)
self.assertEqual(cst1.order, 1) self.assertEqual(cst1.order, 1)
def test_create_cst_resolve(self):
    schema = RSForm.objects.create(title='Test')

    # X1 references X2 (not created yet) and itself.
    first = schema.insert_last('X1', CstType.BASE)
    first.term_raw = '@{X2|datv}'
    first.definition_raw = '@{X1|datv} @{X2|datv}'
    first.save()

    # Creating X2 is expected to update the resolved texts of X1 as well.
    second = schema.create_cst({
        'alias': 'X2',
        'cst_type': CstType.BASE,
        'term_raw': 'слон',
        'definition_raw': '@{X1|plur} @{X2|plur}'
    })
    first.refresh_from_db()

    self.assertEqual(first.term_resolved, 'слону')
    self.assertEqual(first.definition_resolved, 'слону слону')
    self.assertEqual(second.term_resolved, 'слон')
    self.assertEqual(second.definition_resolved, 'слонам слоны')
def test_delete_cst(self): def test_delete_cst(self):
schema = RSForm.objects.create(title='Test') schema = RSForm.objects.create(title='Test')
x1 = schema.insert_last('X1', CstType.BASE) x1 = schema.insert_last('X1', CstType.BASE)
@ -255,7 +273,7 @@ class TestRSForm(TestCase):
'"comment": "Test", "items": ' '"comment": "Test", "items": '
'[{"entityUID": "' + str(x2.id) + '", "cstType": "basic", "alias": "X1", "convention": "test", ' '[{"entityUID": "' + str(x2.id) + '", "cstType": "basic", "alias": "X1", "convention": "test", '
'"term": {"raw": "t1", "resolved": "t2"}, ' '"term": {"raw": "t1", "resolved": "t2"}, '
'"definition": {"formal": "123", "text": {"raw": "t3", "resolved": "t4"}}}]}' '"definition": {"formal": "123", "text": {"raw": "@{X1|datv}", "resolved": "t4"}}}]}'
) )
schema.load_trs(input, sync_metadata=True, skip_update=True) schema.load_trs(input, sync_metadata=True, skip_update=True)
x2.refresh_from_db() x2.refresh_from_db()
@ -266,7 +284,7 @@ class TestRSForm(TestCase):
self.assertEqual(x2.alias, input['items'][0]['alias']) self.assertEqual(x2.alias, input['items'][0]['alias'])
self.assertEqual(x2.convention, input['items'][0]['convention']) self.assertEqual(x2.convention, input['items'][0]['convention'])
self.assertEqual(x2.term_raw, input['items'][0]['term']['raw']) self.assertEqual(x2.term_raw, input['items'][0]['term']['raw'])
self.assertEqual(x2.term_resolved, input['items'][0]['term']['resolved']) self.assertEqual(x2.term_resolved, input['items'][0]['term']['raw'])
self.assertEqual(x2.definition_formal, input['items'][0]['definition']['formal']) self.assertEqual(x2.definition_formal, input['items'][0]['definition']['formal'])
self.assertEqual(x2.definition_raw, input['items'][0]['definition']['text']['raw']) self.assertEqual(x2.definition_raw, input['items'][0]['definition']['text']['raw'])
self.assertEqual(x2.definition_resolved, input['items'][0]['definition']['text']['resolved']) self.assertEqual(x2.definition_resolved, input['items'][0]['term']['raw'])

View File

@ -5,7 +5,7 @@ from .rumodel import Morphology, SemanticRole, WordTag, morpho, split_grams, com
from .ruparser import PhraseParser, WordToken, Collation from .ruparser import PhraseParser, WordToken, Collation
from .reference import EntityReference, ReferenceType, SyntacticReference, parse_reference from .reference import EntityReference, ReferenceType, SyntacticReference, parse_reference
from .context import TermForm, Entity, TermContext from .context import TermForm, Entity, TermContext
from .resolver import Position, Resolver, ResolvedReference, resolve_entity, resolve_syntactic from .resolver import Position, Resolver, ResolvedReference, resolve_entity, resolve_syntactic, extract_entities
from .conceptapi import ( from .conceptapi import (
parse, normalize, parse, normalize,

View File

@ -7,6 +7,17 @@ from .conceptapi import inflect_dependant
from .context import TermContext from .context import TermContext
from .reference import EntityReference, SyntacticReference, parse_reference, Reference from .reference import EntityReference, SyntacticReference, parse_reference, Reference
# Entity reference looks like "@{ALIAS|tags}".
# ALIAS must not start with a digit or '-' (e.g. "@{-1|...}" is not an entity reference)
# and must not contain '{', '}' or '|'. Without the latter restriction a malformed
# reference such as "@{X1}" would swallow all text up to the '|' of the next reference.
_REF_ENTITY_PATTERN = re.compile(r'@{([^0-9\-\{\}\|][^\{\}\|]*)\|.*?}')


def extract_entities(text: str) -> list[str]:
    ''' Extract list of entities that are referenced.

    Returns aliases of entity references found in text. Each alias is listed once,
    in order of its first appearance. Empty list is returned when nothing is found.
    '''
    # dict preserves insertion order, so it removes duplicates without changing the order.
    aliases = dict.fromkeys(match.group(1) for match in _REF_ENTITY_PATTERN.finditer(text))
    return list(aliases)
def resolve_entity(ref: EntityReference, context: TermContext) -> str: def resolve_entity(ref: EntityReference, context: TermContext) -> str:
''' Resolve entity reference. ''' ''' Resolve entity reference. '''

View File

@ -5,9 +5,20 @@ from typing import cast
from cctext import ( from cctext import (
EntityReference, TermContext, Entity, SyntacticReference, EntityReference, TermContext, Entity, SyntacticReference,
Resolver, ResolvedReference, Position, Resolver, ResolvedReference, Position,
resolve_entity, resolve_syntactic resolve_entity, resolve_syntactic, extract_entities
) )
class TestUtils(unittest.TestCase):
    ''' Test utility methods. '''

    def test_extract_entities(self):
        # (input text, expected aliases) pairs, checked in order.
        cases = [
            ('', []),
            ('@{-1|черны}', []),
            ('@{X1|nomn}', ['X1']),
            ('@{X1|datv}', ['X1']),
            ('@{X1|datv} @{X1|datv} @{X2|datv}', ['X1', 'X2']),
        ]
        for text, expected in cases:
            self.assertEqual(extract_entities(text), expected)
class TestResolver(unittest.TestCase): class TestResolver(unittest.TestCase):
'''Test reference Resolver.''' '''Test reference Resolver.'''
def setUp(self): def setUp(self):