R: RSForm cache and transaction.atomic
Some checks are pending
Backend CI / build (3.12) (push) Waiting to run

This commit is contained in:
Ivan 2024-08-07 21:54:50 +03:00
parent 30a80de424
commit ec358911fb
13 changed files with 294 additions and 244 deletions

View File

@ -36,7 +36,7 @@ class LibraryItemSerializer(serializers.ModelSerializer):
class LibraryItemCloneSerializer(serializers.ModelSerializer): class LibraryItemCloneSerializer(serializers.ModelSerializer):
''' Serializer: LibraryItem cloning. ''' ''' Serializer: LibraryItem cloning. '''
items = PKField(many=True, required=False, queryset=Constituenta.objects.all()) items = PKField(many=True, required=False, queryset=Constituenta.objects.all().only('pk'))
class Meta: class Meta:
''' serializer metadata. ''' ''' serializer metadata. '''

View File

@ -142,7 +142,7 @@ class TestVersionViews(EndpointTester):
version_id = self._create_version(data=data) version_id = self._create_version(data=data)
invalid_id = version_id + 1337 invalid_id = version_id + 1337
d1.delete() self.owned.delete_cst([d1])
x3 = self.owned.insert_new('X3') x3 = self.owned.insert_new('X3')
x1.order = x3.order x1.order = x3.order
x1.convention = 'Test2' x1.convention = 'Test2'

View File

@ -4,6 +4,7 @@ from typing import cast
from django.db import transaction from django.db import transaction
from django.db.models import Q from django.db.models import Q
from django.http import HttpResponse
from drf_spectacular.utils import extend_schema, extend_schema_view from drf_spectacular.utils import extend_schema, extend_schema_view
from rest_framework import generics from rest_framework import generics
from rest_framework import status as c from rest_framework import status as c
@ -79,7 +80,7 @@ class LibraryViewSet(viewsets.ModelViewSet):
} }
) )
@action(detail=True, methods=['post'], url_path='clone') @action(detail=True, methods=['post'], url_path='clone')
def clone(self, request: Request, pk): def clone(self, request: Request, pk) -> HttpResponse:
''' Endpoint: Create deep copy of library item. ''' ''' Endpoint: Create deep copy of library item. '''
serializer = s.LibraryItemCloneSerializer(data=request.data) serializer = s.LibraryItemCloneSerializer(data=request.data)
serializer.is_valid(raise_exception=True) serializer.is_valid(raise_exception=True)
@ -139,7 +140,7 @@ class LibraryViewSet(viewsets.ModelViewSet):
}, },
) )
@action(detail=True, methods=['delete']) @action(detail=True, methods=['delete'])
def unsubscribe(self, request: Request, pk): def unsubscribe(self, request: Request, pk) -> HttpResponse:
''' Endpoint: Unsubscribe current user from item. ''' ''' Endpoint: Unsubscribe current user from item. '''
item = self._get_item() item = self._get_item()
m.Subscription.unsubscribe(user=cast(int, self.request.user.pk), item=item.pk) m.Subscription.unsubscribe(user=cast(int, self.request.user.pk), item=item.pk)
@ -156,7 +157,7 @@ class LibraryViewSet(viewsets.ModelViewSet):
} }
) )
@action(detail=True, methods=['patch'], url_path='set-owner') @action(detail=True, methods=['patch'], url_path='set-owner')
def set_owner(self, request: Request, pk): def set_owner(self, request: Request, pk) -> HttpResponse:
''' Endpoint: Set item owner. ''' ''' Endpoint: Set item owner. '''
item = self._get_item() item = self._get_item()
serializer = s.UserTargetSerializer(data=request.data) serializer = s.UserTargetSerializer(data=request.data)
@ -188,7 +189,7 @@ class LibraryViewSet(viewsets.ModelViewSet):
} }
) )
@action(detail=True, methods=['patch'], url_path='set-location') @action(detail=True, methods=['patch'], url_path='set-location')
def set_location(self, request: Request, pk): def set_location(self, request: Request, pk) -> HttpResponse:
''' Endpoint: Set item location. ''' ''' Endpoint: Set item location. '''
item = self._get_item() item = self._get_item()
serializer = s.LocationSerializer(data=request.data) serializer = s.LocationSerializer(data=request.data)
@ -222,7 +223,7 @@ class LibraryViewSet(viewsets.ModelViewSet):
} }
) )
@action(detail=True, methods=['patch'], url_path='set-access-policy') @action(detail=True, methods=['patch'], url_path='set-access-policy')
def set_access_policy(self, request: Request, pk): def set_access_policy(self, request: Request, pk) -> HttpResponse:
''' Endpoint: Set item AccessPolicy. ''' ''' Endpoint: Set item AccessPolicy. '''
item = self._get_item() item = self._get_item()
serializer = s.AccessPolicySerializer(data=request.data) serializer = s.AccessPolicySerializer(data=request.data)
@ -253,7 +254,7 @@ class LibraryViewSet(viewsets.ModelViewSet):
} }
) )
@action(detail=True, methods=['patch'], url_path='set-editors') @action(detail=True, methods=['patch'], url_path='set-editors')
def set_editors(self, request: Request, pk): def set_editors(self, request: Request, pk) -> HttpResponse:
''' Endpoint: Set list of editors for item. ''' ''' Endpoint: Set list of editors for item. '''
item = self._get_item() item = self._get_item()
serializer = s.UsersListSerializer(data=request.data) serializer = s.UsersListSerializer(data=request.data)

View File

@ -1,6 +1,7 @@
''' Endpoints for versions. ''' ''' Endpoints for versions. '''
from typing import cast from typing import cast
from django.db import transaction
from django.http import HttpResponse from django.http import HttpResponse
from drf_spectacular.utils import extend_schema, extend_schema_view from drf_spectacular.utils import extend_schema, extend_schema_view
from rest_framework import generics from rest_framework import generics
@ -40,11 +41,12 @@ class VersionViewset(
} }
) )
@action(detail=True, methods=['patch'], url_path='restore') @action(detail=True, methods=['patch'], url_path='restore')
def restore(self, request: Request, pk): def restore(self, request: Request, pk) -> HttpResponse:
''' Restore version data into current item. ''' ''' Restore version data into current item. '''
version = cast(m.Version, self.get_object()) version = cast(m.Version, self.get_object())
item = cast(m.LibraryItem, version.item) item = cast(m.LibraryItem, version.item)
RSFormSerializer(item).restore_from_version(version.data) with transaction.atomic():
RSFormSerializer(item).restore_from_version(version.data)
return Response( return Response(
status=c.HTTP_200_OK, status=c.HTTP_200_OK,
data=RSFormParseSerializer(item).data data=RSFormParseSerializer(item).data
@ -61,7 +63,7 @@ class VersionViewset(
} }
) )
@api_view(['GET']) @api_view(['GET'])
def export_file(request: Request, pk: int): def export_file(request: Request, pk: int) -> HttpResponse:
''' Endpoint: Download Exteor compatible file for versioned data. ''' ''' Endpoint: Download Exteor compatible file for versioned data. '''
try: try:
version = m.Version.objects.get(pk=pk) version = m.Version.objects.get(pk=pk)
@ -88,7 +90,7 @@ def export_file(request: Request, pk: int):
) )
@api_view(['POST']) @api_view(['POST'])
@permission_classes([permissions.GlobalUser]) @permission_classes([permissions.GlobalUser])
def create_version(request: Request, pk_item: int): def create_version(request: Request, pk_item: int) -> HttpResponse:
''' Endpoint: Create new version for RSForm copying current content. ''' ''' Endpoint: Create new version for RSForm copying current content. '''
try: try:
item = m.LibraryItem.objects.get(pk=pk_item) item = m.LibraryItem.objects.get(pk=pk_item)
@ -125,7 +127,7 @@ def create_version(request: Request, pk_item: int):
} }
) )
@api_view(['GET']) @api_view(['GET'])
def retrieve_version(request: Request, pk_item: int, pk_version: int): def retrieve_version(request: Request, pk_item: int, pk_version: int) -> HttpResponse:
''' Endpoint: Retrieve version for RSForm. ''' ''' Endpoint: Retrieve version for RSForm. '''
try: try:
item = m.LibraryItem.objects.get(pk=pk_item) item = m.LibraryItem.objects.get(pk=pk_item)

View File

@ -1,7 +1,6 @@
''' Models: OSS API. ''' ''' Models: OSS API. '''
from typing import Optional from typing import Optional
from django.db import transaction
from django.db.models import QuerySet from django.db.models import QuerySet
from apps.library.models import Editor, LibraryItem, LibraryItemType from apps.library.models import Editor, LibraryItem, LibraryItemType
@ -31,11 +30,11 @@ class OperationSchema:
model = LibraryItem.objects.get(pk=pk) model = LibraryItem.objects.get(pk=pk)
return OperationSchema(model) return OperationSchema(model)
def save(self, *args, **kwargs): def save(self, *args, **kwargs) -> None:
''' Save wrapper. ''' ''' Save wrapper. '''
self.model.save(*args, **kwargs) self.model.save(*args, **kwargs)
def refresh_from_db(self): def refresh_from_db(self) -> None:
''' Model wrapper. ''' ''' Model wrapper. '''
self.model.refresh_from_db() self.model.refresh_from_db()
@ -59,7 +58,7 @@ class OperationSchema:
location=self.model.location location=self.model.location
) )
def update_positions(self, data: list[dict]): def update_positions(self, data: list[dict]) -> None:
''' Update positions. ''' ''' Update positions. '''
lookup = {x['id']: x for x in data} lookup = {x['id']: x for x in data}
operations = self.operations() operations = self.operations()
@ -69,7 +68,6 @@ class OperationSchema:
item.position_y = lookup[item.pk]['position_y'] item.position_y = lookup[item.pk]['position_y']
Operation.objects.bulk_update(operations, ['position_x', 'position_y']) Operation.objects.bulk_update(operations, ['position_x', 'position_y'])
@transaction.atomic
def create_operation(self, **kwargs) -> Operation: def create_operation(self, **kwargs) -> Operation:
''' Insert new operation. ''' ''' Insert new operation. '''
result = Operation.objects.create(oss=self.model, **kwargs) result = Operation.objects.create(oss=self.model, **kwargs)
@ -77,7 +75,6 @@ class OperationSchema:
result.refresh_from_db() result.refresh_from_db()
return result return result
@transaction.atomic
def delete_operation(self, operation: Operation): def delete_operation(self, operation: Operation):
''' Delete operation. ''' ''' Delete operation. '''
operation.delete() operation.delete()
@ -87,8 +84,7 @@ class OperationSchema:
self.save() self.save()
@transaction.atomic def set_input(self, target: Operation, schema: Optional[LibraryItem]) -> None:
def set_input(self, target: Operation, schema: Optional[LibraryItem]):
''' Set input schema for operation. ''' ''' Set input schema for operation. '''
if schema == target.result: if schema == target.result:
return return
@ -104,8 +100,7 @@ class OperationSchema:
self.save() self.save()
@transaction.atomic def set_arguments(self, operation: Operation, arguments: list[Operation]) -> None:
def set_arguments(self, operation: Operation, arguments: list[Operation]):
''' Set arguments to operation. ''' ''' Set arguments to operation. '''
processed: list[Operation] = [] processed: list[Operation] = []
changed = False changed = False
@ -125,8 +120,7 @@ class OperationSchema:
# TODO: trigger on_change effects # TODO: trigger on_change effects
self.save() self.save()
@transaction.atomic def set_substitutions(self, target: Operation, substitutes: list[dict]) -> None:
def set_substitutions(self, target: Operation, substitutes: list[dict]):
''' Clear all arguments for operation. ''' ''' Clear all arguments for operation. '''
processed: list[dict] = [] processed: list[dict] = []
changed = False changed = False
@ -157,7 +151,6 @@ class OperationSchema:
self.save() self.save()
@transaction.atomic
def create_input(self, operation: Operation) -> RSForm: def create_input(self, operation: Operation) -> RSForm:
''' Create input RSForm. ''' ''' Create input RSForm. '''
schema = RSForm.create( schema = RSForm.create(
@ -175,7 +168,6 @@ class OperationSchema:
self.save() self.save()
return schema return schema
@transaction.atomic
def execute_operation(self, operation: Operation) -> bool: def execute_operation(self, operation: Operation) -> bool:
''' Execute target operation. ''' ''' Execute target operation. '''
schemas: list[LibraryItem] = [arg.argument.result for arg in operation.getArguments()] schemas: list[LibraryItem] = [arg.argument.result for arg in operation.getArguments()]

View File

@ -2,6 +2,7 @@
from typing import cast from typing import cast
from django.db import transaction from django.db import transaction
from django.http import HttpResponse
from drf_spectacular.utils import extend_schema, extend_schema_view from drf_spectacular.utils import extend_schema, extend_schema_view
from rest_framework import generics, serializers from rest_framework import generics, serializers
from rest_framework import status as c from rest_framework import status as c
@ -61,7 +62,7 @@ class OssViewSet(viewsets.GenericViewSet, generics.ListAPIView, generics.Retriev
} }
) )
@action(detail=True, methods=['get'], url_path='details') @action(detail=True, methods=['get'], url_path='details')
def details(self, request: Request, pk): def details(self, request: Request, pk) -> HttpResponse:
''' Endpoint: Detailed OSS data. ''' ''' Endpoint: Detailed OSS data. '''
serializer = s.OperationSchemaSerializer(self._get_item()) serializer = s.OperationSchemaSerializer(self._get_item())
return Response( return Response(
@ -80,7 +81,7 @@ class OssViewSet(viewsets.GenericViewSet, generics.ListAPIView, generics.Retriev
} }
) )
@action(detail=True, methods=['patch'], url_path='update-positions') @action(detail=True, methods=['patch'], url_path='update-positions')
def update_positions(self, request: Request, pk): def update_positions(self, request: Request, pk) -> HttpResponse:
''' Endpoint: Update operations positions. ''' ''' Endpoint: Update operations positions. '''
serializer = s.PositionsSerializer(data=request.data) serializer = s.PositionsSerializer(data=request.data)
serializer.is_valid(raise_exception=True) serializer.is_valid(raise_exception=True)
@ -99,7 +100,7 @@ class OssViewSet(viewsets.GenericViewSet, generics.ListAPIView, generics.Retriev
} }
) )
@action(detail=True, methods=['post'], url_path='create-operation') @action(detail=True, methods=['post'], url_path='create-operation')
def create_operation(self, request: Request, pk): def create_operation(self, request: Request, pk) -> HttpResponse:
''' Create new operation. ''' ''' Create new operation. '''
serializer = s.OperationCreateSerializer(data=request.data) serializer = s.OperationCreateSerializer(data=request.data)
serializer.is_valid(raise_exception=True) serializer.is_valid(raise_exception=True)
@ -135,7 +136,7 @@ class OssViewSet(viewsets.GenericViewSet, generics.ListAPIView, generics.Retriev
} }
) )
@action(detail=True, methods=['patch'], url_path='delete-operation') @action(detail=True, methods=['patch'], url_path='delete-operation')
def delete_operation(self, request: Request, pk): def delete_operation(self, request: Request, pk) -> HttpResponse:
''' Endpoint: Delete operation. ''' ''' Endpoint: Delete operation. '''
serializer = s.OperationTargetSerializer( serializer = s.OperationTargetSerializer(
data=request.data, data=request.data,
@ -165,7 +166,7 @@ class OssViewSet(viewsets.GenericViewSet, generics.ListAPIView, generics.Retriev
} }
) )
@action(detail=True, methods=['patch'], url_path='create-input') @action(detail=True, methods=['patch'], url_path='create-input')
def create_input(self, request: Request, pk): def create_input(self, request: Request, pk) -> HttpResponse:
''' Create new input RSForm. ''' ''' Create new input RSForm. '''
serializer = s.OperationTargetSerializer( serializer = s.OperationTargetSerializer(
data=request.data, data=request.data,
@ -208,7 +209,7 @@ class OssViewSet(viewsets.GenericViewSet, generics.ListAPIView, generics.Retriev
} }
) )
@action(detail=True, methods=['patch'], url_path='set-input') @action(detail=True, methods=['patch'], url_path='set-input')
def set_input(self, request: Request, pk): def set_input(self, request: Request, pk) -> HttpResponse:
''' Set input schema for target operation. ''' ''' Set input schema for target operation. '''
serializer = s.SetOperationInputSerializer( serializer = s.SetOperationInputSerializer(
data=request.data, data=request.data,
@ -238,7 +239,7 @@ class OssViewSet(viewsets.GenericViewSet, generics.ListAPIView, generics.Retriev
} }
) )
@action(detail=True, methods=['patch'], url_path='update-operation') @action(detail=True, methods=['patch'], url_path='update-operation')
def update_operation(self, request: Request, pk): def update_operation(self, request: Request, pk) -> HttpResponse:
''' Update operation arguments and parameters. ''' ''' Update operation arguments and parameters. '''
serializer = s.OperationUpdateSerializer( serializer = s.OperationUpdateSerializer(
data=request.data, data=request.data,
@ -283,7 +284,7 @@ class OssViewSet(viewsets.GenericViewSet, generics.ListAPIView, generics.Retriev
} }
) )
@action(detail=True, methods=['post'], url_path='execute-operation') @action(detail=True, methods=['post'], url_path='execute-operation')
def execute_operation(self, request: Request, pk): def execute_operation(self, request: Request, pk) -> HttpResponse:
''' Execute operation. ''' ''' Execute operation. '''
serializer = s.OperationTargetSerializer( serializer = s.OperationTargetSerializer(
data=request.data, data=request.data,
@ -323,7 +324,7 @@ class OssViewSet(viewsets.GenericViewSet, generics.ListAPIView, generics.Retriev
} }
) )
@action(detail=False, methods=['post'], url_path='get-predecessor') @action(detail=False, methods=['post'], url_path='get-predecessor')
def get_predecessor(self, request: Request): def get_predecessor(self, request: Request) -> HttpResponse:
''' Get predecessor. ''' ''' Get predecessor. '''
# TODO: add tests for this method # TODO: add tests for this method
serializer = CstTargetSerializer(data=request.data) serializer = CstTargetSerializer(data=request.data)

View File

@ -99,8 +99,6 @@ class Constituenta(Model):
def set_term_resolved(self, new_term: str): def set_term_resolved(self, new_term: str):
''' Set term and reset forms if needed. ''' ''' Set term and reset forms if needed. '''
if new_term == self.term_resolved:
return
self.term_resolved = new_term self.term_resolved = new_term
self.term_forms = [] self.term_forms = []
@ -113,10 +111,6 @@ class Constituenta(Model):
if expression != self.definition_formal: if expression != self.definition_formal:
modified = True modified = True
self.definition_formal = expression self.definition_formal = expression
convention = apply_pattern(self.convention, mapping, _GLOBAL_ID_PATTERN)
if convention != self.convention:
modified = True
self.convention = convention
term = apply_pattern(self.term_raw, mapping, _REF_ENTITY_PATTERN) term = apply_pattern(self.term_raw, mapping, _REF_ENTITY_PATTERN)
if term != self.term_raw: if term != self.term_raw:
modified = True modified = True

View File

@ -1,10 +1,9 @@
''' Models: RSForm API. ''' ''' Models: RSForm API. '''
from copy import deepcopy from copy import deepcopy
from typing import Optional, cast from typing import Iterable, Optional, cast
from cctext import Entity, Resolver, TermForm, extract_entities, split_grams from cctext import Entity, Resolver, TermForm, extract_entities, split_grams
from django.core.exceptions import ValidationError from django.core.exceptions import ValidationError
from django.db import transaction
from django.db.models import QuerySet from django.db.models import QuerySet
from apps.library.models import LibraryItem, LibraryItemType, Version from apps.library.models import LibraryItem, LibraryItemType, Version
@ -30,8 +29,71 @@ _INSERT_LAST: int = -1
class RSForm: class RSForm:
''' RSForm is math form of conceptual schema. ''' ''' RSForm is math form of conceptual schema. '''
class Cache:
''' Cache for RSForm constituents. '''
def __init__(self, schema: 'RSForm'):
self._schema = schema
self.constituents: list[Constituenta] = []
self.by_id: dict[int, Constituenta] = {}
self.by_alias: dict[str, Constituenta] = {}
self.is_loaded = False
def reload(self) -> None:
self.constituents = list(
self._schema.constituents().only(
'order',
'alias',
'cst_type',
'definition_formal',
'term_raw',
'definition_raw'
).order_by('order')
)
self.by_id = {cst.pk: cst for cst in self.constituents}
self.by_alias = {cst.alias: cst for cst in self.constituents}
self.is_loaded = True
def ensure(self) -> None:
if not self.is_loaded:
self.reload()
def clear(self) -> None:
self.constituents = []
self.by_id = {}
self.by_alias = {}
self.is_loaded = False
def insert(self, cst: Constituenta) -> None:
if self.is_loaded:
self.constituents.insert(cst.order - 1, cst)
self.by_id[cst.pk] = cst
self.by_alias[cst.alias] = cst
def insert_multi(self, items: Iterable[Constituenta]) -> None:
if self.is_loaded:
for cst in items:
self.constituents.insert(cst.order - 1, cst)
self.by_id[cst.pk] = cst
self.by_alias[cst.alias] = cst
def remove(self, target: Constituenta) -> None:
if self.is_loaded:
self.constituents.remove(target)
del self.by_id[target.pk]
del self.by_alias[target.alias]
def remove_multi(self, target: Iterable[Constituenta]) -> None:
if self.is_loaded:
for cst in target:
self.constituents.remove(cst)
del self.by_id[cst.pk]
del self.by_alias[cst.alias]
def __init__(self, model: LibraryItem): def __init__(self, model: LibraryItem):
self.model = model self.model = model
self.cache: RSForm.Cache = RSForm.Cache(self)
@staticmethod @staticmethod
def create(**kwargs) -> 'RSForm': def create(**kwargs) -> 'RSForm':
@ -45,11 +107,11 @@ class RSForm:
model = LibraryItem.objects.get(pk=pk) model = LibraryItem.objects.get(pk=pk)
return RSForm(model) return RSForm(model)
def save(self, *args, **kwargs): def save(self, *args, **kwargs) -> None:
''' Model wrapper. ''' ''' Model wrapper. '''
self.model.save(*args, **kwargs) self.model.save(*args, **kwargs)
def refresh_from_db(self): def refresh_from_db(self) -> None:
''' Model wrapper. ''' ''' Model wrapper. '''
self.model.refresh_from_db() self.model.refresh_from_db()
@ -60,7 +122,7 @@ class RSForm:
def resolver(self) -> Resolver: def resolver(self) -> Resolver:
''' Create resolver for text references based on schema terms. ''' ''' Create resolver for text references based on schema terms. '''
result = Resolver({}) result = Resolver({})
for cst in self.constituents(): for cst in self.constituents().only('alias', 'term_resolved', 'term_forms'):
entity = Entity( entity = Entity(
alias=cst.alias, alias=cst.alias,
nominal=cst.term_resolved, nominal=cst.term_resolved,
@ -76,49 +138,53 @@ class RSForm:
''' Access semantic information on constituents. ''' ''' Access semantic information on constituents. '''
return SemanticInfo(self) return SemanticInfo(self)
@transaction.atomic def on_term_change(self, changed: list[int]) -> None:
def on_term_change(self, changed: list[int]):
''' Trigger cascade resolutions when term changes. ''' ''' Trigger cascade resolutions when term changes. '''
self.cache.ensure()
graph_terms = self._graph_term() graph_terms = self._graph_term()
expansion = graph_terms.expand_outputs(changed) expansion = graph_terms.expand_outputs(changed)
expanded_change = changed + expansion expanded_change = changed + expansion
update_list: list[Constituenta] = []
resolver = self.resolver() resolver = self.resolver()
if len(expansion) > 0: if len(expansion) > 0:
for cst_id in graph_terms.topological_order(): for cst_id in graph_terms.topological_order():
if cst_id not in expansion: if cst_id not in expansion:
continue continue
cst = self.constituents().get(id=cst_id) cst = self.cache.by_id[cst_id]
resolved = resolver.resolve(cst.term_raw) resolved = resolver.resolve(cst.term_raw)
if resolved == cst.term_resolved: if resolved == resolver.context[cst.alias].get_nominal():
continue continue
cst.set_term_resolved(resolved) cst.set_term_resolved(resolved)
cst.save() update_list.append(cst)
resolver.context[cst.alias] = Entity(cst.alias, resolved) resolver.context[cst.alias] = Entity(cst.alias, resolved)
Constituenta.objects.bulk_update(update_list, ['term_resolved'])
graph_defs = self._graph_text() graph_defs = self._graph_text()
update_defs = set(expansion + graph_defs.expand_outputs(expanded_change)).union(changed) update_defs = set(expansion + graph_defs.expand_outputs(expanded_change)).union(changed)
update_list = []
if len(update_defs) == 0: if len(update_defs) == 0:
return return
for cst_id in update_defs: for cst_id in update_defs:
cst = self.constituents().get(id=cst_id) cst = self.cache.by_id[cst_id]
resolved = resolver.resolve(cst.definition_raw) resolved = resolver.resolve(cst.definition_raw)
if resolved == cst.definition_resolved:
continue
cst.definition_resolved = resolved cst.definition_resolved = resolved
cst.save() update_list.append(cst)
Constituenta.objects.bulk_update(update_list, ['definition_resolved'])
def get_max_index(self, cst_type: CstType) -> int: def get_max_index(self, cst_type: CstType) -> int:
''' Get maximum alias index for specific CstType. ''' ''' Get maximum alias index for specific CstType. '''
result: int = 0 result: int = 0
items = Constituenta.objects \ cst_list: Iterable[Constituenta] = []
.filter(schema=self.model, cst_type=cst_type) \ if not self.cache.is_loaded:
.order_by('-alias') \ cst_list = Constituenta.objects \
.values_list('alias', flat=True) .filter(schema=self.model, cst_type=cst_type) \
for alias in items: .only('alias')
result = max(result, int(alias[1:])) else:
cst_list = [cst for cst in self.cache.constituents if cst.cst_type == cst_type]
for cst in cst_list:
result = max(result, int(cst.alias[1:]))
return result return result
@transaction.atomic
def create_cst(self, data: dict, insert_after: Optional[Constituenta] = None) -> Constituenta: def create_cst(self, data: dict, insert_after: Optional[Constituenta] = None) -> Constituenta:
''' Create new cst from data. ''' ''' Create new cst from data. '''
if insert_after is None: if insert_after is None:
@ -142,11 +208,11 @@ class RSForm:
result.definition_resolved = resolver.resolve(result.definition_raw) result.definition_resolved = resolver.resolve(result.definition_raw)
result.save() result.save()
self.cache.insert(result)
self.on_term_change([result.pk]) self.on_term_change([result.pk])
result.refresh_from_db() result.refresh_from_db()
return result return result
@transaction.atomic
def insert_new( def insert_new(
self, self,
alias: str, alias: str,
@ -169,17 +235,17 @@ class RSForm:
cst_type=cst_type, cst_type=cst_type,
**kwargs **kwargs
) )
self.cache.insert(result)
self.save() self.save()
result.refresh_from_db()
return result return result
@transaction.atomic
def insert_copy(self, items: list[Constituenta], position: int = _INSERT_LAST) -> list[Constituenta]: def insert_copy(self, items: list[Constituenta], position: int = _INSERT_LAST) -> list[Constituenta]:
''' Insert copy of target constituents updating references. ''' ''' Insert copy of target constituents updating references. '''
count = len(items) count = len(items)
if count == 0: if count == 0:
return [] return []
self.cache.ensure()
position = self._get_insert_position(position) position = self._get_insert_position(position)
self._shift_positions(position, count) self._shift_positions(position, count)
@ -200,62 +266,65 @@ class RSForm:
cst.order = position cst.order = position
cst.alias = mapping[cst.alias] cst.alias = mapping[cst.alias]
cst.apply_mapping(mapping) cst.apply_mapping(mapping)
cst.save()
position = position + 1 position = position + 1
new_cst = Constituenta.objects.bulk_create(result)
self.cache.insert_multi(new_cst)
self.save() self.save()
return result return result
@transaction.atomic def move_cst(self, target: list[Constituenta], destination: int) -> None:
def move_cst(self, listCst: list[Constituenta], target: int):
''' Move list of constituents to specific position ''' ''' Move list of constituents to specific position '''
count_moved = 0 count_moved = 0
count_top = 0 count_top = 0
count_bot = 0 count_bot = 0
size = len(listCst) size = len(target)
update_list = []
for cst in self.constituents().only('order').order_by('order'): cst_list: Iterable[Constituenta] = []
if cst not in listCst: if not self.cache.is_loaded:
if count_top + 1 < target: cst_list = self.constituents().only('order').order_by('order')
cst.order = count_top + 1 else:
count_top += 1 cst_list = self.cache.constituents
else: for cst in cst_list:
cst.order = target + size + count_bot if cst in target:
count_bot += 1 cst.order = destination + count_moved
else:
cst.order = target + count_moved
count_moved += 1 count_moved += 1
update_list.append(cst) elif count_top + 1 < destination:
Constituenta.objects.bulk_update(update_list, ['order']) cst.order = count_top + 1
count_top += 1
else:
cst.order = destination + size + count_bot
count_bot += 1
Constituenta.objects.bulk_update(cst_list, ['order'])
self.save() self.save()
@transaction.atomic def delete_cst(self, target: Iterable[Constituenta]) -> None:
def delete_cst(self, listCst):
''' Delete multiple constituents. Do not check if listCst are from this schema. ''' ''' Delete multiple constituents. Do not check if listCst are from this schema. '''
for cst in listCst: self.cache.remove_multi(target)
cst.delete() Constituenta.objects.filter(pk__in=[cst.pk for cst in target]).delete()
self._reset_order() self._reset_order()
self.resolve_all_text() self.resolve_all_text()
self.save() self.save()
@transaction.atomic
def substitute( def substitute(
self, self,
original: Constituenta, original: Constituenta,
substitution: Constituenta substitution: Constituenta
): ) -> None:
''' Execute constituenta substitution. ''' ''' Execute constituenta substitution. '''
assert original.pk != substitution.pk assert original.pk != substitution.pk
mapping = {original.alias: substitution.alias} mapping = {original.alias: substitution.alias}
self.apply_mapping(mapping) self.apply_mapping(mapping)
self.cache.remove(self.cache.by_id[original.pk])
original.delete() original.delete()
self.on_term_change([substitution.pk]) self.on_term_change([substitution.pk])
def restore_order(self): def restore_order(self) -> None:
''' Restore order based on types and term graph. ''' ''' Restore order based on types and term graph. '''
manager = _OrderManager(self) manager = _OrderManager(self)
manager.restore_order() manager.restore_order()
def reset_aliases(self): def reset_aliases(self) -> None:
''' Recreate all aliases based on constituents order. ''' ''' Recreate all aliases based on constituents order. '''
mapping = self._create_reset_mapping() mapping = self._create_reset_mapping()
self.apply_mapping(mapping, change_aliases=True) self.apply_mapping(mapping, change_aliases=True)
@ -273,33 +342,36 @@ class RSForm:
mapping[cst.alias] = alias mapping[cst.alias] = alias
return mapping return mapping
@transaction.atomic def apply_mapping(self, mapping: dict[str, str], change_aliases: bool = False) -> None:
def apply_mapping(self, mapping: dict[str, str], change_aliases: bool = False):
''' Apply rename mapping. ''' ''' Apply rename mapping. '''
cst_list = self.constituents().order_by('order') self.cache.ensure()
for cst in cst_list: update_list: list[Constituenta] = []
for cst in self.cache.constituents:
if cst.apply_mapping(mapping, change_aliases): if cst.apply_mapping(mapping, change_aliases):
cst.save() update_list.append(cst)
Constituenta.objects.bulk_update(update_list, ['alias', 'definition_formal', 'term_raw', 'definition_raw'])
self.save()
@transaction.atomic def resolve_all_text(self) -> None:
def resolve_all_text(self):
''' Trigger reference resolution for all texts. ''' ''' Trigger reference resolution for all texts. '''
self.cache.ensure()
graph_terms = self._graph_term() graph_terms = self._graph_term()
resolver = Resolver({}) resolver = Resolver({})
update_list: list[Constituenta] = []
for cst_id in graph_terms.topological_order(): for cst_id in graph_terms.topological_order():
cst = self.constituents().get(id=cst_id) cst = self.cache.by_id[cst_id]
resolved = resolver.resolve(cst.term_raw) resolved = resolver.resolve(cst.term_raw)
resolver.context[cst.alias] = Entity(cst.alias, resolved) resolver.context[cst.alias] = Entity(cst.alias, resolved)
if resolved != cst.term_resolved: cst.term_resolved = resolved
cst.term_resolved = resolved update_list.append(cst)
cst.save() Constituenta.objects.bulk_update(update_list, ['term_resolved'])
for cst in self.constituents():
resolved = resolver.resolve(cst.definition_raw) for cst in self.cache.constituents:
if resolved != cst.definition_resolved: resolved = resolver.resolve(cst.definition_raw)
cst.definition_resolved = resolved cst.definition_resolved = resolved
cst.save() Constituenta.objects.bulk_update(self.cache.constituents, ['definition_resolved'])
@transaction.atomic
def create_version(self, version: str, description: str, data) -> Version: def create_version(self, version: str, description: str, data) -> Version:
''' Creates version for current state. ''' ''' Creates version for current state. '''
return Version.objects.create( return Version.objects.create(
@ -309,7 +381,6 @@ class RSForm:
data=data data=data
) )
@transaction.atomic
def produce_structure(self, target: Constituenta, parse: dict) -> list[int]: def produce_structure(self, target: Constituenta, parse: dict) -> list[int]:
''' Add constituents for each structural element of the target. ''' ''' Add constituents for each structural element of the target. '''
expressions = generate_structure( expressions = generate_structure(
@ -320,9 +391,10 @@ class RSForm:
count_new = len(expressions) count_new = len(expressions)
if count_new == 0: if count_new == 0:
return [] return []
position = target.order + 1
self._shift_positions(position, count_new)
position = target.order + 1
self.cache.ensure()
self._shift_positions(position, count_new)
result = [] result = []
cst_type = CstType.TERM if len(parse['args']) == 0 else CstType.FUNCTION cst_type = CstType.TERM if len(parse['args']) == 0 else CstType.FUNCTION
free_index = self.get_max_index(cst_type) + 1 free_index = self.get_max_index(cst_type) + 1
@ -339,97 +411,86 @@ class RSForm:
free_index = free_index + 1 free_index = free_index + 1
position = position + 1 position = position + 1
self.cache.clear()
self.save() self.save()
return result return result
def _shift_positions(self, start: int, shift: int): def _shift_positions(self, start: int, shift: int) -> None:
if shift == 0: if shift == 0:
return return
update_list = \ update_list: Iterable[Constituenta] = []
Constituenta.objects \ if not self.cache.is_loaded:
.only('order') \ update_list = Constituenta.objects \
.filter(schema=self.model, order__gte=start) .only('order') \
.filter(schema=self.model, order__gte=start)
else:
update_list = [cst for cst in self.cache.constituents if cst.order >= start]
for cst in update_list: for cst in update_list:
cst.order += shift cst.order += shift
Constituenta.objects.bulk_update(update_list, ['order']) Constituenta.objects.bulk_update(update_list, ['order'])
def _get_last_position(self):
if self.constituents().exists():
return self.constituents().count()
else:
return 0
def _get_insert_position(self, position: int) -> int: def _get_insert_position(self, position: int) -> int:
if position <= 0 and position != _INSERT_LAST: if position <= 0 and position != _INSERT_LAST:
raise ValidationError(msg.invalidPosition()) raise ValidationError(msg.invalidPosition())
lastPosition = self._get_last_position() lastPosition = self.constituents().count()
if position == _INSERT_LAST: if position == _INSERT_LAST:
position = lastPosition + 1 position = lastPosition + 1
else: else:
position = max(1, min(position, lastPosition + 1)) position = max(1, min(position, lastPosition + 1))
return position return position
@transaction.atomic def _reset_order(self) -> None:
def _reset_order(self):
order = 1 order = 1
for cst in self.constituents().only('order').order_by('order'): changed: list[Constituenta] = []
cst_list: Iterable[Constituenta] = []
if not self.cache.is_loaded:
cst_list = self.constituents().only('order').order_by('order')
else:
cst_list = self.cache.constituents
for cst in cst_list:
if cst.order != order: if cst.order != order:
cst.order = order cst.order = order
cst.save() changed.append(cst)
order += 1 order += 1
Constituenta.objects.bulk_update(changed, ['order'])
def _graph_formal(self) -> Graph[int]: def _graph_formal(self) -> Graph[int]:
''' Graph based on formal definitions. ''' ''' Graph based on formal definitions. '''
self.cache.ensure()
result: Graph[int] = Graph() result: Graph[int] = Graph()
cst_list = \ for cst in self.cache.constituents:
self.constituents() \
.only('alias', 'definition_formal') \
.order_by('order')
for cst in cst_list:
result.add_node(cst.pk) result.add_node(cst.pk)
for cst in cst_list: for cst in self.cache.constituents:
for alias in extract_globals(cst.definition_formal): for alias in extract_globals(cst.definition_formal):
try: child = self.cache.by_alias.get(alias)
child = cst_list.get(alias=alias) if child is not None:
result.add_edge(src=child.pk, dest=cst.pk) result.add_edge(src=child.pk, dest=cst.pk)
except Constituenta.DoesNotExist:
pass
return result return result
def _graph_term(self) -> Graph[int]: def _graph_term(self) -> Graph[int]:
''' Graph based on term texts. ''' ''' Graph based on term texts. '''
self.cache.ensure()
result: Graph[int] = Graph() result: Graph[int] = Graph()
cst_list = \ for cst in self.cache.constituents:
self.constituents() \
.only('alias', 'term_raw') \
.order_by('order')
for cst in cst_list:
result.add_node(cst.pk) result.add_node(cst.pk)
for cst in cst_list: for cst in self.cache.constituents:
for alias in extract_entities(cst.term_raw): for alias in extract_entities(cst.term_raw):
try: child = self.cache.by_alias.get(alias)
child = cst_list.get(alias=alias) if child is not None:
result.add_edge(src=child.pk, dest=cst.pk) result.add_edge(src=child.pk, dest=cst.pk)
except Constituenta.DoesNotExist:
pass
return result return result
def _graph_text(self) -> Graph[int]: def _graph_text(self) -> Graph[int]:
''' Graph based on definition texts. ''' ''' Graph based on definition texts. '''
self.cache.ensure()
result: Graph[int] = Graph() result: Graph[int] = Graph()
cst_list = \ for cst in self.cache.constituents:
self.constituents() \
.only('alias', 'definition_raw') \
.order_by('order')
for cst in cst_list:
result.add_node(cst.pk) result.add_node(cst.pk)
for cst in cst_list: for cst in self.cache.constituents:
for alias in extract_entities(cst.definition_raw): for alias in extract_entities(cst.definition_raw):
try: child = self.cache.by_alias.get(alias)
child = cst_list.get(alias=alias) if child is not None:
result.add_edge(src=child.pk, dest=cst.pk) result.add_edge(src=child.pk, dest=cst.pk)
except Constituenta.DoesNotExist:
pass
return result return result
@ -437,14 +498,11 @@ class SemanticInfo:
''' Semantic information derived from constituents. ''' ''' Semantic information derived from constituents. '''
def __init__(self, schema: RSForm): def __init__(self, schema: RSForm):
schema.cache.ensure()
self._graph = schema._graph_formal() self._graph = schema._graph_formal()
self._items = list( self._items = schema.cache.constituents
schema.constituents() self._cst_by_ID = schema.cache.by_id
.only('alias', 'cst_type', 'definition_formal') self._cst_by_alias = schema.cache.by_alias
.order_by('order')
)
self._cst_by_alias = {cst.alias: cst for cst in self._items}
self._cst_by_ID = {cst.pk: cst for cst in self._items}
self.info = { self.info = {
cst.pk: { cst.pk: {
'is_simple': False, 'is_simple': False,
@ -452,7 +510,7 @@ class SemanticInfo:
'parent': cst.pk, 'parent': cst.pk,
'children': [] 'children': []
} }
for cst in self._items for cst in schema.cache.constituents
} }
self._calculate_attributes() self._calculate_attributes()
@ -475,7 +533,7 @@ class SemanticInfo:
''' Access "children" attribute. ''' ''' Access "children" attribute. '''
return cast(list[int], self.info[target]['children']) return cast(list[int], self.info[target]['children'])
def _calculate_attributes(self): def _calculate_attributes(self) -> None:
for cst_id in self._graph.topological_order(): for cst_id in self._graph.topological_order():
cst = self._cst_by_ID[cst_id] cst = self._cst_by_ID[cst_id]
self.info[cst_id]['is_template'] = infer_template(cst.definition_formal) self.info[cst_id]['is_template'] = infer_template(cst.definition_formal)
@ -485,7 +543,7 @@ class SemanticInfo:
parent = self._infer_parent(cst) parent = self._infer_parent(cst)
self.info[cst_id]['parent'] = parent self.info[cst_id]['parent'] = parent
if parent != cst_id: if parent != cst_id:
self.info[parent]['children'].append(cst_id) cast(list[int], self.info[parent]['children']).append(cst_id)
def _infer_simple_expression(self, target: Constituenta) -> bool: def _infer_simple_expression(self, target: Constituenta) -> bool:
if target.cst_type == CstType.STRUCTURED or is_base_set(target.cst_type): if target.cst_type == CstType.STRUCTURED or is_base_set(target.cst_type):
@ -565,12 +623,8 @@ class _OrderManager:
def __init__(self, schema: RSForm): def __init__(self, schema: RSForm):
self._semantic = schema.semantic() self._semantic = schema.semantic()
self._graph = schema._graph_formal() self._graph = schema._graph_formal()
self._items = list( self._items = schema.cache.constituents
schema.constituents() self._cst_by_ID = schema.cache.by_id
.only('order', 'alias', 'cst_type', 'definition_formal')
.order_by('order')
)
self._cst_by_ID = {cst.pk: cst for cst in self._items}
def restore_order(self) -> None: def restore_order(self) -> None:
''' Implement order restoration process. ''' ''' Implement order restoration process. '''
@ -615,10 +669,9 @@ class _OrderManager:
result.append(child) result.append(child)
self._items = result self._items = result
@transaction.atomic
def _save_order(self) -> None: def _save_order(self) -> None:
order = 1 order = 1
for cst in self._items: for cst in self._items:
cst.order = order cst.order = order
cst.save()
order += 1 order += 1
Constituenta.objects.bulk_update(self._items, ['order'])

View File

@ -3,7 +3,6 @@ from typing import Optional, cast
from django.contrib.auth.models import User from django.contrib.auth.models import User
from django.core.exceptions import PermissionDenied from django.core.exceptions import PermissionDenied
from django.db import transaction
from django.db.models import Q from django.db.models import Q
from rest_framework import serializers from rest_framework import serializers
from rest_framework.serializers import PrimaryKeyRelatedField as PKField from rest_framework.serializers import PrimaryKeyRelatedField as PKField
@ -72,7 +71,12 @@ class CstDetailsSerializer(serializers.ModelSerializer):
class CstCreateSerializer(serializers.ModelSerializer): class CstCreateSerializer(serializers.ModelSerializer):
''' Serializer: Constituenta creation. ''' ''' Serializer: Constituenta creation. '''
insert_after = serializers.IntegerField(required=False, allow_null=True) insert_after = PKField(
many=False,
allow_null=True,
required=False,
queryset=Constituenta.objects.all().only('schema_id', 'order')
)
alias = serializers.CharField(max_length=8) alias = serializers.CharField(max_length=8)
cst_type = serializers.ChoiceField(CstType.choices) cst_type = serializers.ChoiceField(CstType.choices)
@ -149,7 +153,6 @@ class RSFormSerializer(serializers.ModelSerializer):
result['version'] = version result['version'] = version
return result | data return result | data
@transaction.atomic
def restore_from_version(self, data: dict): def restore_from_version(self, data: dict):
''' Load data from version. ''' ''' Load data from version. '''
schema = RSForm(cast(LibraryItem, self.instance)) schema = RSForm(cast(LibraryItem, self.instance))
@ -312,9 +315,9 @@ class CstSubstituteSerializer(serializers.Serializer):
raise serializers.ValidationError({ raise serializers.ValidationError({
f'{original_cst.pk}': msg.substituteDouble(original_cst.alias) f'{original_cst.pk}': msg.substituteDouble(original_cst.alias)
}) })
if original_cst.alias == substitution_cst.alias: if original_cst.pk == substitution_cst.pk:
raise serializers.ValidationError({ raise serializers.ValidationError({
'alias': msg.substituteTrivial(original_cst.alias) 'original': msg.substituteTrivial(original_cst.alias)
}) })
if original_cst.schema_id != schema.pk: if original_cst.schema_id != schema.pk:
raise serializers.ValidationError({ raise serializers.ValidationError({

View File

@ -1,15 +1,16 @@
''' Testing models: api_RSForm. ''' ''' Testing models: api_RSForm. '''
from django.forms import ValidationError from django.forms import ValidationError
from django.test import TestCase
from apps.rsform.models import Constituenta, CstType, RSForm from apps.rsform.models import Constituenta, CstType, RSForm
from apps.users.models import User from apps.users.models import User
from shared.DBTester import DBTester
class TestRSForm(TestCase): class TestRSForm(DBTester):
''' Testing RSForm wrapper. ''' ''' Testing RSForm wrapper. '''
def setUp(self): def setUp(self):
super().setUp()
self.user1 = User.objects.create(username='User1') self.user1 = User.objects.create(username='User1')
self.user2 = User.objects.create(username='User2') self.user2 = User.objects.create(username='User2')
self.schema = RSForm.create(title='Test') self.schema = RSForm.create(title='Test')
@ -180,7 +181,6 @@ class TestRSForm(TestCase):
alias='D1', alias='D1',
definition_formal='X1 = X11 = X2', definition_formal='X1 = X11 = X2',
definition_raw='@{X11|sing}', definition_raw='@{X11|sing}',
convention='X1',
term_raw='@{X1|plur}' term_raw='@{X1|plur}'
) )
@ -188,7 +188,6 @@ class TestRSForm(TestCase):
d1.refresh_from_db() d1.refresh_from_db()
self.assertEqual(d1.definition_formal, 'X3 = X4 = X2', msg='Map IDs in expression') self.assertEqual(d1.definition_formal, 'X3 = X4 = X2', msg='Map IDs in expression')
self.assertEqual(d1.definition_raw, '@{X4|sing}', msg='Map IDs in definition') self.assertEqual(d1.definition_raw, '@{X4|sing}', msg='Map IDs in definition')
self.assertEqual(d1.convention, 'X3', msg='Map IDs in convention')
self.assertEqual(d1.term_raw, '@{X3|plur}', msg='Map IDs in term') self.assertEqual(d1.term_raw, '@{X3|plur}', msg='Map IDs in term')
self.assertEqual(d1.term_resolved, '', msg='Do not run resolve on mapping') self.assertEqual(d1.term_resolved, '', msg='Do not run resolve on mapping')
self.assertEqual(d1.definition_resolved, '', msg='Do not run resolve on mapping') self.assertEqual(d1.definition_resolved, '', msg='Do not run resolve on mapping')
@ -320,7 +319,6 @@ class TestRSForm(TestCase):
x2 = self.schema.insert_new('X21') x2 = self.schema.insert_new('X21')
d1 = self.schema.insert_new( d1 = self.schema.insert_new(
alias='D11', alias='D11',
convention='D11 - cool',
definition_formal='X21=X21', definition_formal='X21=X21',
term_raw='@{X21|sing}', term_raw='@{X21|sing}',
definition_raw='@{X11|datv}', definition_raw='@{X11|datv}',
@ -335,7 +333,6 @@ class TestRSForm(TestCase):
self.assertEqual(x1.alias, 'X1') self.assertEqual(x1.alias, 'X1')
self.assertEqual(x2.alias, 'X2') self.assertEqual(x2.alias, 'X2')
self.assertEqual(d1.alias, 'D1') self.assertEqual(d1.alias, 'D1')
self.assertEqual(d1.convention, 'D1 - cool')
self.assertEqual(d1.term_raw, '@{X2|sing}') self.assertEqual(d1.term_raw, '@{X2|sing}')
self.assertEqual(d1.definition_raw, '@{X1|datv}') self.assertEqual(d1.definition_raw, '@{X1|datv}')
self.assertEqual(d1.definition_resolved, 'test') self.assertEqual(d1.definition_resolved, 'test')

View File

@ -74,22 +74,20 @@ class RSFormViewSet(viewsets.GenericViewSet, generics.ListAPIView, generics.Retr
} }
) )
@action(detail=True, methods=['post'], url_path='create-cst') @action(detail=True, methods=['post'], url_path='create-cst')
def create_cst(self, request: Request, pk): def create_cst(self, request: Request, pk) -> HttpResponse:
''' Create new constituenta. ''' ''' Create new constituenta. '''
schema = self._get_item() schema = self._get_item()
serializer = s.CstCreateSerializer(data=request.data) serializer = s.CstCreateSerializer(data=request.data)
serializer.is_valid(raise_exception=True) serializer.is_valid(raise_exception=True)
data = serializer.validated_data data = serializer.validated_data
if 'insert_after' in data and data['insert_after'] is not None: if 'insert_after' not in data:
try:
insert_after = m.Constituenta.objects.get(pk=data['insert_after'])
except LibraryItem.DoesNotExist:
return Response(status=c.HTTP_404_NOT_FOUND)
else:
insert_after = None insert_after = None
new_cst = m.RSForm(schema).create_cst(data, insert_after) else:
insert_after = data['insert_after']
with transaction.atomic():
new_cst = m.RSForm(schema).create_cst(data, insert_after)
schema.refresh_from_db()
return Response( return Response(
status=c.HTTP_201_CREATED, status=c.HTTP_201_CREATED,
data={ data={
@ -110,7 +108,7 @@ class RSFormViewSet(viewsets.GenericViewSet, generics.ListAPIView, generics.Retr
} }
) )
@action(detail=True, methods=['patch'], url_path='update-cst') @action(detail=True, methods=['patch'], url_path='update-cst')
def update_cst(self, request: Request, pk): def update_cst(self, request: Request, pk) -> HttpResponse:
''' Update persistent attributes of a given constituenta. ''' ''' Update persistent attributes of a given constituenta. '''
schema = self._get_item() schema = self._get_item()
serializer = s.CstSerializer(data=request.data, partial=True) serializer = s.CstSerializer(data=request.data, partial=True)
@ -140,7 +138,7 @@ class RSFormViewSet(viewsets.GenericViewSet, generics.ListAPIView, generics.Retr
} }
) )
@action(detail=True, methods=['patch'], url_path='produce-structure') @action(detail=True, methods=['patch'], url_path='produce-structure')
def produce_structure(self, request: Request, pk): def produce_structure(self, request: Request, pk) -> HttpResponse:
''' Produce a term for every element of the target constituenta typification. ''' ''' Produce a term for every element of the target constituenta typification. '''
schema = self._get_item() schema = self._get_item()
@ -159,8 +157,8 @@ class RSFormViewSet(viewsets.GenericViewSet, generics.ListAPIView, generics.Retr
status=c.HTTP_400_BAD_REQUEST, status=c.HTTP_400_BAD_REQUEST,
data={f'{cst.pk}': msg.constituentaNoStructure()} data={f'{cst.pk}': msg.constituentaNoStructure()}
) )
with transaction.atomic():
result = m.RSForm(schema).produce_structure(cst, cst_parse) result = m.RSForm(schema).produce_structure(cst, cst_parse)
return Response( return Response(
status=c.HTTP_200_OK, status=c.HTTP_200_OK,
data={ data={
@ -181,21 +179,20 @@ class RSFormViewSet(viewsets.GenericViewSet, generics.ListAPIView, generics.Retr
} }
) )
@action(detail=True, methods=['patch'], url_path='rename-cst') @action(detail=True, methods=['patch'], url_path='rename-cst')
def rename_cst(self, request: Request, pk): def rename_cst(self, request: Request, pk) -> HttpResponse:
''' Rename constituenta possibly changing type. ''' ''' Rename constituenta possibly changing type. '''
schema = self._get_item() schema = self._get_item()
serializer = s.CstRenameSerializer(data=request.data, context={'schema': schema}) serializer = s.CstRenameSerializer(data=request.data, context={'schema': schema})
serializer.is_valid(raise_exception=True) serializer.is_valid(raise_exception=True)
cst = cast(m.Constituenta, serializer.validated_data['target']) cst = cast(m.Constituenta, serializer.validated_data['target'])
old_alias = cst.alias mapping = {cst.alias: serializer.validated_data['alias']}
cst.alias = serializer.validated_data['alias'] cst.alias = serializer.validated_data['alias']
cst.cst_type = serializer.validated_data['cst_type'] cst.cst_type = serializer.validated_data['cst_type']
with transaction.atomic(): with transaction.atomic():
cst.save() cst.save()
m.RSForm(schema).apply_mapping(mapping={old_alias: cst.alias}, change_aliases=False) m.RSForm(schema).apply_mapping(mapping=mapping, change_aliases=False)
schema.refresh_from_db() schema.refresh_from_db()
cst.refresh_from_db() cst.refresh_from_db()
@ -219,7 +216,7 @@ class RSFormViewSet(viewsets.GenericViewSet, generics.ListAPIView, generics.Retr
} }
) )
@action(detail=True, methods=['patch'], url_path='substitute') @action(detail=True, methods=['patch'], url_path='substitute')
def substitute(self, request: Request, pk): def substitute(self, request: Request, pk) -> HttpResponse:
''' Substitute occurrences of constituenta with another one. ''' ''' Substitute occurrences of constituenta with another one. '''
schema = self._get_item() schema = self._get_item()
serializer = s.CstSubstituteSerializer( serializer = s.CstSubstituteSerializer(
@ -252,7 +249,7 @@ class RSFormViewSet(viewsets.GenericViewSet, generics.ListAPIView, generics.Retr
} }
) )
@action(detail=True, methods=['patch'], url_path='delete-multiple-cst') @action(detail=True, methods=['patch'], url_path='delete-multiple-cst')
def delete_multiple_cst(self, request: Request, pk): def delete_multiple_cst(self, request: Request, pk) -> HttpResponse:
''' Endpoint: Delete multiple constituents. ''' ''' Endpoint: Delete multiple constituents. '''
schema = self._get_item() schema = self._get_item()
serializer = s.CstListSerializer( serializer = s.CstListSerializer(
@ -260,9 +257,8 @@ class RSFormViewSet(viewsets.GenericViewSet, generics.ListAPIView, generics.Retr
context={'schema': schema} context={'schema': schema}
) )
serializer.is_valid(raise_exception=True) serializer.is_valid(raise_exception=True)
m.RSForm(schema).delete_cst(serializer.validated_data['items']) with transaction.atomic():
m.RSForm(schema).delete_cst(serializer.validated_data['items'])
schema.refresh_from_db()
return Response( return Response(
status=c.HTTP_200_OK, status=c.HTTP_200_OK,
data=s.RSFormParseSerializer(schema).data data=s.RSFormParseSerializer(schema).data
@ -280,7 +276,7 @@ class RSFormViewSet(viewsets.GenericViewSet, generics.ListAPIView, generics.Retr
} }
) )
@action(detail=True, methods=['patch'], url_path='move-cst') @action(detail=True, methods=['patch'], url_path='move-cst')
def move_cst(self, request: Request, pk): def move_cst(self, request: Request, pk) -> HttpResponse:
''' Endpoint: Move multiple constituents. ''' ''' Endpoint: Move multiple constituents. '''
schema = self._get_item() schema = self._get_item()
serializer = s.CstMoveSerializer( serializer = s.CstMoveSerializer(
@ -288,10 +284,11 @@ class RSFormViewSet(viewsets.GenericViewSet, generics.ListAPIView, generics.Retr
context={'schema': schema} context={'schema': schema}
) )
serializer.is_valid(raise_exception=True) serializer.is_valid(raise_exception=True)
m.RSForm(schema).move_cst( with transaction.atomic():
listCst=serializer.validated_data['items'], m.RSForm(schema).move_cst(
target=serializer.validated_data['move_to'] target=serializer.validated_data['items'],
) destination=serializer.validated_data['move_to']
)
return Response( return Response(
status=c.HTTP_200_OK, status=c.HTTP_200_OK,
data=s.RSFormParseSerializer(schema).data data=s.RSFormParseSerializer(schema).data
@ -308,7 +305,7 @@ class RSFormViewSet(viewsets.GenericViewSet, generics.ListAPIView, generics.Retr
} }
) )
@action(detail=True, methods=['patch'], url_path='reset-aliases') @action(detail=True, methods=['patch'], url_path='reset-aliases')
def reset_aliases(self, request: Request, pk): def reset_aliases(self, request: Request, pk) -> HttpResponse:
''' Endpoint: Recreate all aliases based on order. ''' ''' Endpoint: Recreate all aliases based on order. '''
schema = self._get_item() schema = self._get_item()
m.RSForm(schema).reset_aliases() m.RSForm(schema).reset_aliases()
@ -328,7 +325,7 @@ class RSFormViewSet(viewsets.GenericViewSet, generics.ListAPIView, generics.Retr
} }
) )
@action(detail=True, methods=['patch'], url_path='restore-order') @action(detail=True, methods=['patch'], url_path='restore-order')
def restore_order(self, request: Request, pk): def restore_order(self, request: Request, pk) -> HttpResponse:
''' Endpoint: Restore order based on types and term graph. ''' ''' Endpoint: Restore order based on types and term graph. '''
schema = self._get_item() schema = self._get_item()
m.RSForm(schema).restore_order() m.RSForm(schema).restore_order()
@ -349,7 +346,7 @@ class RSFormViewSet(viewsets.GenericViewSet, generics.ListAPIView, generics.Retr
} }
) )
@action(detail=True, methods=['patch'], url_path='load-trs') @action(detail=True, methods=['patch'], url_path='load-trs')
def load_trs(self, request: Request, pk): def load_trs(self, request: Request, pk) -> HttpResponse:
''' Endpoint: Load data from file and replace current schema. ''' ''' Endpoint: Load data from file and replace current schema. '''
input_serializer = s.RSFormUploadSerializer(data=request.data) input_serializer = s.RSFormUploadSerializer(data=request.data)
input_serializer.is_valid(raise_exception=True) input_serializer.is_valid(raise_exception=True)
@ -380,7 +377,7 @@ class RSFormViewSet(viewsets.GenericViewSet, generics.ListAPIView, generics.Retr
} }
) )
@action(detail=True, methods=['get'], url_path='contents') @action(detail=True, methods=['get'], url_path='contents')
def contents(self, request: Request, pk): def contents(self, request: Request, pk) -> HttpResponse:
''' Endpoint: View schema db contents (including constituents). ''' ''' Endpoint: View schema db contents (including constituents). '''
serializer = s.RSFormSerializer(self.get_object()) serializer = s.RSFormSerializer(self.get_object())
return Response( return Response(
@ -398,7 +395,7 @@ class RSFormViewSet(viewsets.GenericViewSet, generics.ListAPIView, generics.Retr
} }
) )
@action(detail=True, methods=['get'], url_path='details') @action(detail=True, methods=['get'], url_path='details')
def details(self, request: Request, pk): def details(self, request: Request, pk) -> HttpResponse:
''' Endpoint: Detailed schema view including statuses and parse. ''' ''' Endpoint: Detailed schema view including statuses and parse. '''
serializer = s.RSFormParseSerializer(self.get_object()) serializer = s.RSFormParseSerializer(self.get_object())
return Response( return Response(
@ -416,7 +413,7 @@ class RSFormViewSet(viewsets.GenericViewSet, generics.ListAPIView, generics.Retr
}, },
) )
@action(detail=True, methods=['post'], url_path='check') @action(detail=True, methods=['post'], url_path='check')
def check(self, request: Request, pk): def check(self, request: Request, pk) -> HttpResponse:
''' Endpoint: Check RSLang expression against schema context. ''' ''' Endpoint: Check RSLang expression against schema context. '''
serializer = s.ExpressionSerializer(data=request.data) serializer = s.ExpressionSerializer(data=request.data)
serializer.is_valid(raise_exception=True) serializer.is_valid(raise_exception=True)
@ -438,7 +435,7 @@ class RSFormViewSet(viewsets.GenericViewSet, generics.ListAPIView, generics.Retr
} }
) )
@action(detail=True, methods=['post'], url_path='resolve') @action(detail=True, methods=['post'], url_path='resolve')
def resolve(self, request: Request, pk): def resolve(self, request: Request, pk) -> HttpResponse:
''' Endpoint: Resolve references in text against schema terms context. ''' ''' Endpoint: Resolve references in text against schema terms context. '''
serializer = s.TextSerializer(data=request.data) serializer = s.TextSerializer(data=request.data)
serializer.is_valid(raise_exception=True) serializer.is_valid(raise_exception=True)
@ -460,7 +457,7 @@ class RSFormViewSet(viewsets.GenericViewSet, generics.ListAPIView, generics.Retr
} }
) )
@action(detail=True, methods=['get'], url_path='export-trs') @action(detail=True, methods=['get'], url_path='export-trs')
def export_trs(self, request: Request, pk): def export_trs(self, request: Request, pk) -> HttpResponse:
''' Endpoint: Download Exteor compatible file. ''' ''' Endpoint: Download Exteor compatible file. '''
schema = self._get_item() schema = self._get_item()
data = s.RSFormTRSSerializer(m.RSForm(schema)).data data = s.RSFormTRSSerializer(m.RSForm(schema)).data
@ -485,7 +482,7 @@ class TrsImportView(views.APIView):
c.HTTP_403_FORBIDDEN: None c.HTTP_403_FORBIDDEN: None
} }
) )
def post(self, request: Request): def post(self, request: Request) -> HttpResponse:
data = utility.read_zipped_json(request.FILES['file'].file, utils.EXTEOR_INNER_FILENAME) data = utility.read_zipped_json(request.FILES['file'].file, utils.EXTEOR_INNER_FILENAME)
owner = cast(User, self.request.user) owner = cast(User, self.request.user)
_prepare_rsform_data(data, request, owner) _prepare_rsform_data(data, request, owner)
@ -512,7 +509,7 @@ class TrsImportView(views.APIView):
} }
) )
@api_view(['POST']) @api_view(['POST'])
def create_rsform(request: Request): def create_rsform(request: Request) -> HttpResponse:
''' Endpoint: Create RSForm from user input and/or trs file. ''' ''' Endpoint: Create RSForm from user input and/or trs file. '''
owner = cast(User, request.user) if not request.user.is_anonymous else None owner = cast(User, request.user) if not request.user.is_anonymous else None
if 'file' not in request.FILES: if 'file' not in request.FILES:
@ -564,7 +561,7 @@ def _prepare_rsform_data(data: dict, request: Request, owner: Union[User, None])
responses={c.HTTP_200_OK: s.RSFormParseSerializer} responses={c.HTTP_200_OK: s.RSFormParseSerializer}
) )
@api_view(['PATCH']) @api_view(['PATCH'])
def inline_synthesis(request: Request): def inline_synthesis(request: Request) -> HttpResponse:
''' Endpoint: Inline synthesis. ''' ''' Endpoint: Inline synthesis. '''
serializer = s.InlineSynthesisSerializer( serializer = s.InlineSynthesisSerializer(
data=request.data, data=request.data,
@ -581,10 +578,10 @@ def inline_synthesis(request: Request):
original = cast(m.Constituenta, substitution['original']) original = cast(m.Constituenta, substitution['original'])
replacement = cast(m.Constituenta, substitution['substitution']) replacement = cast(m.Constituenta, substitution['substitution'])
if original in items: if original in items:
index = next(i for (i, cst) in enumerate(items) if cst == original) index = next(i for (i, cst) in enumerate(items) if cst.pk == original.pk)
original = new_items[index] original = new_items[index]
else: else:
index = next(i for (i, cst) in enumerate(items) if cst == replacement) index = next(i for (i, cst) in enumerate(items) if cst.pk == replacement.pk)
replacement = new_items[index] replacement = new_items[index]
receiver.substitute(original, replacement) receiver.substitute(original, replacement)
receiver.restore_order() receiver.restore_order()

View File

@ -0,0 +1,23 @@
''' Utils: tester for database operations. '''
import logging
from django.db import connection
from rest_framework.test import APITestCase
class DBTester(APITestCase):
''' Abstract base class for Testing database. '''
def setUp(self):
self.logger = logging.getLogger('django.db.backends')
self.logger.setLevel(logging.DEBUG)
def start_db_log(self):
''' Warning! Do not use this second time before calling stop_db_log. '''
''' Warning! Do not forget to enable global logging in settings. '''
logging.disable(logging.NOTSET)
connection.force_debug_cursor = True
def stop_db_log(self):
connection.force_debug_cursor = False
logging.disable(logging.CRITICAL)

View File

@ -1,13 +1,12 @@
''' Utils: base tester class for endpoints. ''' ''' Utils: base tester class for endpoints. '''
import logging
from django.db import connection
from rest_framework import status from rest_framework import status
from rest_framework.test import APIClient, APIRequestFactory, APITestCase from rest_framework.test import APIClient, APIRequestFactory
from apps.library.models import Editor, LibraryItem from apps.library.models import Editor, LibraryItem
from apps.users.models import User from apps.users.models import User
from .DBTester import DBTester
def decl_endpoint(endpoint: str, method: str): def decl_endpoint(endpoint: str, method: str):
''' Decorator for EndpointTester methods to provide API attributes. ''' ''' Decorator for EndpointTester methods to provide API attributes. '''
@ -25,10 +24,11 @@ def decl_endpoint(endpoint: str, method: str):
return set_endpoint_inner return set_endpoint_inner
class EndpointTester(APITestCase): class EndpointTester(DBTester):
''' Abstract base class for Testing endpoints. ''' ''' Abstract base class for Testing endpoints. '''
def setUp(self): def setUp(self):
super().setUp()
self.factory = APIRequestFactory() self.factory = APIRequestFactory()
self.user = User.objects.create( self.user = User.objects.create(
username='UserTest', username='UserTest',
@ -43,9 +43,6 @@ class EndpointTester(APITestCase):
self.client = APIClient() self.client = APIClient()
self.client.force_authenticate(user=self.user) self.client.force_authenticate(user=self.user)
self.logger = logging.getLogger('django.db.backends')
self.logger.setLevel(logging.DEBUG)
def setUpFullUsers(self): def setUpFullUsers(self):
self.factory = APIRequestFactory() self.factory = APIRequestFactory()
self.user = User.objects.create_user( self.user = User.objects.create_user(
@ -77,16 +74,6 @@ class EndpointTester(APITestCase):
def logout(self): def logout(self):
self.client.logout() self.client.logout()
def start_db_log(self):
''' Warning! Do not use this second time before calling stop_db_log. '''
''' Warning! Do not forget to enable global logging in settings. '''
logging.disable(logging.NOTSET)
connection.force_debug_cursor = True
def stop_db_log(self):
connection.force_debug_cursor = False
logging.disable(logging.CRITICAL)
def set_params(self, **kwargs): def set_params(self, **kwargs):
''' Given named argument values resolve current endpoint_mask. ''' ''' Given named argument values resolve current endpoint_mask. '''
if self.endpoint_mask and len(kwargs) > 0: if self.endpoint_mask and len(kwargs) > 0: