mirror of
https://github.com/IRBorisov/ConceptPortal.git
synced 2025-06-26 13:00:39 +03:00
Add reference resolution to RSForm operations
This commit is contained in:
parent
9475a6718b
commit
05645a29e8
|
@ -6,14 +6,14 @@ class Graph:
|
|||
''' Directed graph. '''
|
||||
def __init__(self, graph: Optional[Dict[str, list[str]]]=None):
    ''' Create graph from an adjacency mapping.

    graph -- optional mapping of node id to the list of its child ids.
    The mapping is stored as is (not copied). When omitted, the graph starts empty.
    '''
    if graph is None:
        # Single assignment of a fresh dict per instance;
        # cast only pins the value type for static checkers.
        self._graph = cast(Dict[str, list[str]], {})
    else:
        self._graph = graph
|
||||
|
||||
def contains(self, node_id: str) -> bool:
    ''' Report whether node_id is a registered node of the graph. '''
    known_nodes = self._graph
    return node_id in known_nodes
|
||||
|
||||
|
||||
def has_edge(self, id_from: str, id_to: str) -> bool:
    ''' Report whether the graph holds the edge id_from -> id_to. '''
    # Unknown source node cannot have outgoing edges.
    if not self.contains(id_from):
        return False
    return id_to in self._graph[id_from]
|
||||
|
@ -43,7 +43,7 @@ class Graph:
|
|||
while position < len(result):
|
||||
node_id = result[position]
|
||||
position += 1
|
||||
if (node_id not in marked):
|
||||
if node_id not in marked:
|
||||
marked.add(node_id)
|
||||
for child_id in self._graph[node_id]:
|
||||
if child_id not in marked and child_id not in result:
|
||||
|
@ -55,19 +55,21 @@ class Graph:
|
|||
result: list[str] = []
|
||||
marked: set[str] = set()
|
||||
for node_id in self._graph.keys():
|
||||
if node_id not in marked:
|
||||
to_visit: list[str] = [node_id]
|
||||
while len(to_visit) > 0:
|
||||
node = to_visit[-1]
|
||||
if node in marked:
|
||||
if node not in result:
|
||||
result.append(node)
|
||||
to_visit.remove(node)
|
||||
else:
|
||||
marked.add(node)
|
||||
if len(self._graph[node]) > 0:
|
||||
for child_id in self._graph[node]:
|
||||
if child_id not in marked:
|
||||
to_visit.append(child_id)
|
||||
if node_id in marked:
|
||||
continue
|
||||
to_visit: list[str] = [node_id]
|
||||
while len(to_visit) > 0:
|
||||
node = to_visit[-1]
|
||||
if node in marked:
|
||||
if node not in result:
|
||||
result.append(node)
|
||||
to_visit.remove(node)
|
||||
else:
|
||||
marked.add(node)
|
||||
if len(self._graph[node]) <= 0:
|
||||
continue
|
||||
for child_id in self._graph[node]:
|
||||
if child_id not in marked:
|
||||
to_visit.append(child_id)
|
||||
result.reverse()
|
||||
return result
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
''' Models: RSForms for conceptual schemas. '''
|
||||
import json
|
||||
from typing import Optional
|
||||
from typing import Iterable, Optional
|
||||
import pyconcept
|
||||
from django.db import transaction
|
||||
from django.db.models import (
|
||||
|
@ -11,7 +11,8 @@ from django.core.validators import MinValueValidator
|
|||
from django.core.exceptions import ValidationError
|
||||
from django.urls import reverse
|
||||
from apps.users.models import User
|
||||
from cctext import Resolver, Entity
|
||||
from cctext import Resolver, Entity, extract_entities
|
||||
from .graph import Graph
|
||||
|
||||
|
||||
class CstType(TextChoices):
|
||||
|
@ -88,20 +89,34 @@ class RSForm(Model):
|
|||
return result
|
||||
|
||||
@transaction.atomic
def on_term_change(self, changed: Iterable[str]):
    ''' Trigger cascade resolutions when term changes.

    changed -- aliases of the constituents whose terms were modified.

    First re-resolves terms of the aliases reported by the term graph expansion
    of 'changed', then re-resolves definitions of every alias found in that
    expansion, in its definition graph expansion, or in 'changed' itself.
    Only constituents whose resolved text actually differs are saved.
    '''
    # 'changed' is read twice below; materialize it so a one-shot iterator is not exhausted.
    changed = list(changed)
    graph_terms = self._term_graph()
    expansion = graph_terms.expand_outputs(changed)
    resolver = self.resolver()
    if expansion:
        affected = set(expansion)
        # Follow topological order of the term graph while updating terms.
        for alias in graph_terms.topological_order():
            if alias not in affected:
                continue
            cst = self.constituents().get(alias=alias)
            resolved = resolver.resolve(cst.term_raw)
            if resolved == cst.term_resolved:
                continue
            cst.set_term_resolved(resolved)
            cst.save()
            # Keep resolver context in sync so following resolutions use the new text.
            resolver.context[cst.alias] = Entity(cst.alias, resolved)

    graph_defs = self._definition_graph()
    update_defs = set(expansion + graph_defs.expand_outputs(expansion)).union(changed)
    for alias in update_defs:
        cst = self.constituents().get(alias=alias)
        resolved = resolver.resolve(cst.definition_raw)
        if resolved == cst.definition_resolved:
            continue
        cst.definition_resolved = resolved
        cst.save()
|
||||
|
||||
@transaction.atomic
|
||||
def insert_at(self, position: int, alias: str, insert_type: CstType) -> 'Constituenta':
|
||||
|
@ -119,7 +134,7 @@ class RSForm(Model):
|
|||
alias=alias,
|
||||
cst_type=insert_type
|
||||
)
|
||||
self._update_from_core()
|
||||
self._update_order()
|
||||
self.save()
|
||||
result.refresh_from_db()
|
||||
return result
|
||||
|
@ -136,7 +151,7 @@ class RSForm(Model):
|
|||
alias=alias,
|
||||
cst_type=insert_type
|
||||
)
|
||||
self._update_from_core()
|
||||
self._update_order()
|
||||
self.save()
|
||||
result.refresh_from_db()
|
||||
return result
|
||||
|
@ -162,7 +177,7 @@ class RSForm(Model):
|
|||
count_moved += 1
|
||||
update_list.append(cst)
|
||||
Constituenta.objects.bulk_update(update_list, ['order'])
|
||||
self._update_from_core()
|
||||
self._update_order()
|
||||
self.save()
|
||||
|
||||
@transaction.atomic
|
||||
|
@ -170,7 +185,8 @@ class RSForm(Model):
|
|||
''' Delete multiple constituents. Do not check if listCst are from this schema '''
|
||||
for cst in listCst:
|
||||
cst.delete()
|
||||
self._update_from_core()
|
||||
self._update_order()
|
||||
self._resolve_all_text()
|
||||
self.save()
|
||||
|
||||
@transaction.atomic
|
||||
|
@ -182,12 +198,15 @@ class RSForm(Model):
|
|||
cst.definition_formal = data.get('definition_formal', '')
|
||||
cst.term_raw = data.get('term_raw', '')
|
||||
if cst.term_raw != '':
|
||||
cst.term_resolved = resolver.resolve(cst.term_raw)
|
||||
resolved = resolver.resolve(cst.term_raw)
|
||||
cst.term_resolved = resolved
|
||||
resolver.context[cst.alias] = Entity(cst.alias, resolved)
|
||||
cst.definition_raw = data.get('definition_raw', '')
|
||||
if cst.definition_raw != '':
|
||||
cst.definition_resolved = resolver.resolve(cst.definition_raw)
|
||||
cst.save()
|
||||
self.on_term_change(cst.alias)
|
||||
self.on_term_change([cst.alias])
|
||||
cst.refresh_from_db()
|
||||
return cst
|
||||
|
||||
def _insert_new(self, data: dict, insert_after: Optional[str]=None) -> 'Constituenta':
|
||||
|
@ -223,7 +242,8 @@ class RSForm(Model):
|
|||
if prev_cst.pk not in loaded_ids:
|
||||
prev_cst.delete()
|
||||
if not skip_update:
|
||||
self._update_from_core()
|
||||
self._update_order()
|
||||
self._resolve_all_text()
|
||||
self.save()
|
||||
|
||||
@staticmethod
|
||||
|
@ -264,8 +284,7 @@ class RSForm(Model):
|
|||
}
|
||||
|
||||
@transaction.atomic
|
||||
def _update_from_core(self):
|
||||
# TODO: resolve text refs
|
||||
def _update_order(self):
|
||||
checked = json.loads(pyconcept.check_schema(json.dumps(self.to_trs())))
|
||||
update_list = self.constituents().only('id', 'order')
|
||||
if len(checked['items']) != update_list.count():
|
||||
|
@ -288,6 +307,44 @@ class RSForm(Model):
|
|||
cst_object.save()
|
||||
order += 1
|
||||
|
||||
def _resolve_all_text(self):
    ''' Resolve term and definition texts of all constituents using a fresh resolver. '''
    term_order = self._term_graph().topological_order()
    resolver = Resolver({})
    # Terms go first: every resolved term is registered in the resolver context
    # (presumably so that terms referencing it are resolved afterwards - confirm
    # against Graph.topological_order).
    for alias in term_order:
        cst = self.constituents().get(alias=alias)
        new_term = resolver.resolve(cst.term_raw)
        resolver.context[cst.alias] = Entity(cst.alias, new_term)
        if new_term == cst.term_resolved:
            continue
        cst.term_resolved = new_term
        cst.save()
    # Definitions are resolved against the context filled above.
    for cst in self.constituents():
        new_definition = resolver.resolve(cst.definition_raw)
        if new_definition == cst.definition_resolved:
            continue
        cst.definition_resolved = new_definition
        cst.save()
|
||||
|
||||
def _term_graph(self) -> Graph:
    ''' Build graph of term references.

    Every constituent alias becomes a node; an edge goes from a referenced alias
    to the alias of the constituent whose raw term mentions it.
    '''
    graph = Graph()
    items = self.constituents().only('order', 'alias', 'term_raw').order_by('order')
    for item in items:
        graph.add_node(item.alias)
    for item in items:
        # References to aliases outside of this schema are ignored.
        known_refs = (ref for ref in extract_entities(item.term_raw) if graph.contains(ref))
        for ref in known_refs:
            graph.add_edge(id_from=ref, id_to=item.alias)
    return graph
|
||||
|
||||
def _definition_graph(self) -> Graph:
    ''' Build graph of definition references.

    Every constituent alias becomes a node; an edge goes from a referenced alias
    to the alias of the constituent whose raw definition mentions it.
    '''
    graph = Graph()
    items = self.constituents().only('order', 'alias', 'definition_raw').order_by('order')
    for item in items:
        graph.add_node(item.alias)
    for item in items:
        # References to aliases outside of this schema are ignored.
        known_refs = (ref for ref in extract_entities(item.definition_raw) if graph.contains(ref))
        for ref in known_refs:
            graph.add_edge(id_from=ref, id_to=item.alias)
    return graph
|
||||
|
||||
|
||||
class Constituenta(Model):
|
||||
''' Constituenta is the base unit for every conceptual schema '''
|
||||
|
@ -353,10 +410,18 @@ class Constituenta(Model):
|
|||
verbose_name_plural = 'Конституенты'
|
||||
|
||||
def get_absolute_url(self):
    ''' Build URL of the 'constituenta-detail' route for this object. '''
    route_args = {'pk': self.pk}
    return reverse('constituenta-detail', kwargs=route_args)
|
||||
|
||||
def __str__(self):
    ''' Text representation: the alias. '''
    alias = self.alias
    return alias
|
||||
|
||||
def set_term_resolved(self, new_term: str):
    ''' Assign resolved term; when the value changes, stored term forms are cleared. '''
    if new_term != self.term_resolved:
        self.term_resolved = new_term
        self.term_forms = []
|
||||
|
||||
@staticmethod
|
||||
def create_from_trs(data: dict, schema: RSForm, order: int) -> 'Constituenta':
|
||||
|
|
|
@ -110,7 +110,7 @@ class ConstituentaSerializer(serializers.ModelSerializer):
|
|||
term_changed = validated_data['term_resolved'] != instance.term_resolved
|
||||
result: Constituenta = super().update(instance, validated_data)
|
||||
if term_changed:
|
||||
schema.on_term_change(result.alias)
|
||||
schema.on_term_change([result.alias])
|
||||
result.refresh_from_db()
|
||||
schema.save()
|
||||
return result
|
||||
|
|
|
@ -6,8 +6,8 @@ from apps.rsform.graph import Graph
|
|||
|
||||
class TestGraph(unittest.TestCase):
|
||||
''' Test class for graph. '''
|
||||
|
||||
def test_construction(self):
|
||||
''' Test graph construction methods. '''
|
||||
graph = Graph()
|
||||
self.assertFalse(graph.contains('X1'))
|
||||
|
||||
|
@ -27,7 +27,6 @@ class TestGraph(unittest.TestCase):
|
|||
self.assertTrue(graph.has_edge('X2', 'X1'))
|
||||
|
||||
def test_expand_outputs(self):
|
||||
''' Test Method: Graph.expand_outputs. '''
|
||||
graph = Graph({
|
||||
'X1': ['X2'],
|
||||
'X2': ['X3', 'X5'],
|
||||
|
@ -42,7 +41,6 @@ class TestGraph(unittest.TestCase):
|
|||
self.assertEqual(graph.expand_outputs(['X2', 'X5']), ['X3', 'X6', 'X1'])
|
||||
|
||||
def test_topological_order(self):
|
||||
''' Test Method: Graph.topological_order. '''
|
||||
self.assertEqual(Graph().topological_order(), [])
|
||||
graph = Graph({
|
||||
'X1': [],
|
||||
|
|
|
@ -169,6 +169,24 @@ class TestRSForm(TestCase):
|
|||
self.assertEqual(cst2.schema, schema)
|
||||
self.assertEqual(cst1.order, 1)
|
||||
|
||||
def test_create_cst_resolve(self):
    ''' Test: creating a constituent resolves references in existing and new texts. '''
    schema = RSForm.objects.create(title='Test')

    # X1 refers to X2 before X2 exists.
    first = schema.insert_last('X1', CstType.BASE)
    first.term_raw = '@{X2|datv}'
    first.definition_raw = '@{X1|datv} @{X2|datv}'
    first.save()

    new_data = {
        'alias': 'X2',
        'cst_type': CstType.BASE,
        'term_raw': 'слон',
        'definition_raw': '@{X1|plur} @{X2|plur}'
    }
    second = schema.create_cst(new_data)
    first.refresh_from_db()

    self.assertEqual(first.term_resolved, 'слону')
    self.assertEqual(first.definition_resolved, 'слону слону')
    self.assertEqual(second.term_resolved, 'слон')
    self.assertEqual(second.definition_resolved, 'слонам слоны')
|
||||
|
||||
def test_delete_cst(self):
|
||||
schema = RSForm.objects.create(title='Test')
|
||||
x1 = schema.insert_last('X1', CstType.BASE)
|
||||
|
@ -255,7 +273,7 @@ class TestRSForm(TestCase):
|
|||
'"comment": "Test", "items": '
|
||||
'[{"entityUID": "' + str(x2.id) + '", "cstType": "basic", "alias": "X1", "convention": "test", '
|
||||
'"term": {"raw": "t1", "resolved": "t2"}, '
|
||||
'"definition": {"formal": "123", "text": {"raw": "t3", "resolved": "t4"}}}]}'
|
||||
'"definition": {"formal": "123", "text": {"raw": "@{X1|datv}", "resolved": "t4"}}}]}'
|
||||
)
|
||||
schema.load_trs(input, sync_metadata=True, skip_update=True)
|
||||
x2.refresh_from_db()
|
||||
|
@ -266,7 +284,7 @@ class TestRSForm(TestCase):
|
|||
self.assertEqual(x2.alias, input['items'][0]['alias'])
|
||||
self.assertEqual(x2.convention, input['items'][0]['convention'])
|
||||
self.assertEqual(x2.term_raw, input['items'][0]['term']['raw'])
|
||||
self.assertEqual(x2.term_resolved, input['items'][0]['term']['resolved'])
|
||||
self.assertEqual(x2.term_resolved, input['items'][0]['term']['raw'])
|
||||
self.assertEqual(x2.definition_formal, input['items'][0]['definition']['formal'])
|
||||
self.assertEqual(x2.definition_raw, input['items'][0]['definition']['text']['raw'])
|
||||
self.assertEqual(x2.definition_resolved, input['items'][0]['definition']['text']['resolved'])
|
||||
self.assertEqual(x2.definition_resolved, input['items'][0]['term']['raw'])
|
||||
|
|
|
@ -5,7 +5,7 @@ from .rumodel import Morphology, SemanticRole, WordTag, morpho, split_grams, com
|
|||
from .ruparser import PhraseParser, WordToken, Collation
|
||||
from .reference import EntityReference, ReferenceType, SyntacticReference, parse_reference
|
||||
from .context import TermForm, Entity, TermContext
|
||||
from .resolver import Position, Resolver, ResolvedReference, resolve_entity, resolve_syntactic
|
||||
from .resolver import Position, Resolver, ResolvedReference, resolve_entity, resolve_syntactic, extract_entities
|
||||
|
||||
from .conceptapi import (
|
||||
parse, normalize,
|
||||
|
|
|
@ -7,6 +7,17 @@ from .conceptapi import inflect_dependant
|
|||
from .context import TermContext
|
||||
from .reference import EntityReference, SyntacticReference, parse_reference, Reference
|
||||
|
||||
# Entity reference looks like '@{<alias>|<tags>}'.
# The alias may not start with a digit or '-' (e.g. '@{-1|...}' is not an entity reference).
# Neither the alias nor the tags may run past the braces of their own reference,
# otherwise a malformed '@{X1}' would swallow text up to the next '|'.
_REF_ENTITY_PATTERN = re.compile(r'@\{([^0-9\-][^\}\|\{]*?)\|[^\}]*?\}')


def extract_entities(text: str) -> list[str]:
    ''' Extract list of entities that are referenced.

    Returns aliases in order of first appearance, without duplicates.
    '''
    aliases = (match.group(1) for match in _REF_ENTITY_PATTERN.finditer(text))
    # dict keeps insertion order, which gives ordered deduplication in a single pass.
    return list(dict.fromkeys(aliases))
|
||||
|
||||
|
||||
def resolve_entity(ref: EntityReference, context: TermContext) -> str:
|
||||
''' Resolve entity reference. '''
|
||||
|
|
|
@ -5,9 +5,20 @@ from typing import cast
|
|||
from cctext import (
|
||||
EntityReference, TermContext, Entity, SyntacticReference,
|
||||
Resolver, ResolvedReference, Position,
|
||||
resolve_entity, resolve_syntactic
|
||||
resolve_entity, resolve_syntactic, extract_entities
|
||||
)
|
||||
|
||||
|
||||
class TestUtils(unittest.TestCase):
    ''' Test utility methods. '''

    def test_extract_entities(self):
        ''' Test function: extract_entities. '''
        expectations = [
            ('', []),
            ('@{-1|черны}', []),
            ('@{X1|nomn}', ['X1']),
            ('@{X1|datv}', ['X1']),
            ('@{X1|datv} @{X1|datv} @{X2|datv}', ['X1', 'X2']),
        ]
        for text, entities in expectations:
            self.assertEqual(extract_entities(text), entities)
|
||||
|
||||
|
||||
class TestResolver(unittest.TestCase):
|
||||
'''Test reference Resolver.'''
|
||||
def setUp(self):
|
||||
|
|
Loading…
Reference in New Issue
Block a user