R: Cleanup model API: move logic to view/serializer

This commit is contained in:
Ivan 2024-08-08 15:31:32 +03:00
parent 1647b3c0e8
commit 5c4c0b38d5
14 changed files with 88 additions and 98 deletions

View File

@ -1,7 +1,6 @@
''' Models: LibraryItem. ''' ''' Models: LibraryItem. '''
import re import re
from django.db import transaction
from django.db.models import ( from django.db.models import (
SET_NULL, SET_NULL,
BooleanField, BooleanField,
@ -16,7 +15,6 @@ from django.db.models import (
from apps.users.models import User from apps.users.models import User
from .Subscription import Subscription
from .Version import Version from .Version import Version
@ -125,34 +123,3 @@ class LibraryItem(Model):
def versions(self) -> QuerySet[Version]: def versions(self) -> QuerySet[Version]:
''' Get all Versions of this item. ''' ''' Get all Versions of this item. '''
return Version.objects.filter(item=self.pk).order_by('-time_create') return Version.objects.filter(item=self.pk).order_by('-time_create')
# TODO: move to View layer
@transaction.atomic
def save(self, *args, **kwargs):
''' Save updating subscriptions and connected operations. '''
if not self._state.adding:
self._update_connected_operations()
subscribe = self._state.adding and self.owner
super().save(*args, **kwargs)
if subscribe:
Subscription.subscribe(user=self.owner_id, item=self.pk)
def _update_connected_operations(self):
# using method level import to prevent circular dependency
from apps.oss.models import Operation # pylint: disable=import-outside-toplevel
operations = Operation.objects.filter(result__pk=self.pk)
if not operations.exists():
return
for operation in operations:
changed = False
if operation.alias != self.alias:
operation.alias = self.alias
changed = True
if operation.title != self.title:
operation.title = self.title
changed = True
if operation.comment != self.comment:
operation.comment = self.comment
changed = True
if changed:
operation.save()

View File

@ -68,7 +68,6 @@ class TestLibraryItem(TestCase):
self.assertEqual(item.alias, 'KS1') self.assertEqual(item.alias, 'KS1')
self.assertEqual(item.comment, 'Test comment') self.assertEqual(item.comment, 'Test comment')
self.assertEqual(item.location, LocationHead.COMMON) self.assertEqual(item.location, LocationHead.COMMON)
self.assertTrue(Subscription.objects.filter(user=item.owner, item=item).exists())
class TestLocation(TestCase): class TestLocation(TestCase):

View File

@ -21,9 +21,7 @@ class TestSubscription(TestCase):
def test_default(self): def test_default(self):
subs = list(Subscription.objects.filter(item=self.item)) subs = list(Subscription.objects.filter(item=self.item))
self.assertEqual(len(subs), 1) self.assertEqual(len(subs), 0)
self.assertEqual(subs[0].item, self.item)
self.assertEqual(subs[0].user, self.user1)
def test_str(self): def test_str(self):

View File

@ -49,6 +49,7 @@ class TestLibraryViewset(EndpointTester):
self.assertEqual(response.data['item_type'], LibraryItemType.RSFORM) self.assertEqual(response.data['item_type'], LibraryItemType.RSFORM)
self.assertEqual(response.data['title'], data['title']) self.assertEqual(response.data['title'], data['title'])
self.assertEqual(response.data['alias'], data['alias']) self.assertEqual(response.data['alias'], data['alias'])
self.assertTrue(Subscription.objects.filter(user=self.user, item_id=response.data['id']).exists())
data = { data = {
'item_type': LibraryItemType.OPERATION_SCHEMA, 'item_type': LibraryItemType.OPERATION_SCHEMA,
@ -74,7 +75,7 @@ class TestLibraryViewset(EndpointTester):
@decl_endpoint('/api/library/{item}', method='patch') @decl_endpoint('/api/library/{item}', method='patch')
def test_update(self): def test_update(self):
data = {'id': self.unowned.pk, 'title': 'New Title'} data = {'title': 'New Title'}
self.executeNotFound(data=data, item=self.invalid_item) self.executeNotFound(data=data, item=self.invalid_item)
self.executeForbidden(data=data, item=self.unowned.pk) self.executeForbidden(data=data, item=self.unowned.pk)
@ -86,13 +87,12 @@ class TestLibraryViewset(EndpointTester):
self.unowned.save() self.unowned.save()
self.executeForbidden(data=data, item=self.unowned.pk) self.executeForbidden(data=data, item=self.unowned.pk)
data = {'id': self.owned.pk, 'title': 'New Title'} data = {'title': 'New Title'}
response = self.executeOK(data=data, item=self.owned.pk) response = self.executeOK(data=data, item=self.owned.pk)
self.assertEqual(response.data['title'], data['title']) self.assertEqual(response.data['title'], data['title'])
self.assertEqual(response.data['alias'], self.owned.alias) self.assertEqual(response.data['alias'], self.owned.alias)
data = { data = {
'id': self.owned.pk,
'title': 'Another Title', 'title': 'Another Title',
'owner': self.user2.pk, 'owner': self.user2.pk,
'access_policy': AccessPolicy.PROTECTED, 'access_policy': AccessPolicy.PROTECTED,

View File

@ -13,7 +13,7 @@ from rest_framework.decorators import action
from rest_framework.request import Request from rest_framework.request import Request
from rest_framework.response import Response from rest_framework.response import Response
from apps.oss.models import OperationSchema from apps.oss.models import Operation, OperationSchema
from apps.rsform.models import RSForm from apps.rsform.models import RSForm
from apps.rsform.serializers import RSFormParseSerializer from apps.rsform.serializers import RSFormParseSerializer
from apps.users.models import User from apps.users.models import User
@ -37,11 +37,35 @@ class LibraryViewSet(viewsets.ModelViewSet):
return s.LibraryItemBaseSerializer return s.LibraryItemBaseSerializer
return s.LibraryItemSerializer return s.LibraryItemSerializer
def perform_create(self, serializer): def perform_create(self, serializer) -> None:
if not self.request.user.is_anonymous and 'owner' not in self.request.POST: if not self.request.user.is_anonymous and 'owner' not in self.request.POST:
return serializer.save(owner=self.request.user) instance = serializer.save(owner=self.request.user)
else: else:
return serializer.save() instance = serializer.save()
if instance.owner:
m.Subscription.subscribe(user=instance.owner_id, item=instance.pk)
def perform_update(self, serializer) -> None:
instance = serializer.save()
operations = Operation.objects.filter(result__pk=instance.pk)
if not operations.exists():
return
update_list: list[Operation] = []
for operation in operations:
changed = False
if operation.alias != instance.alias:
operation.alias = instance.alias
changed = True
if operation.title != instance.title:
operation.title = instance.title
changed = True
if operation.comment != instance.comment:
operation.comment = instance.comment
changed = True
if changed:
update_list.append(operation)
if update_list:
Operation.objects.bulk_update(update_list, ['alias', 'title', 'comment'])
def get_permissions(self): def get_permissions(self):
if self.action in ['update', 'partial_update']: if self.action in ['update', 'partial_update']:

View File

@ -4,7 +4,7 @@ from typing import Optional
from django.db.models import QuerySet from django.db.models import QuerySet
from apps.library.models import Editor, LibraryItem, LibraryItemType from apps.library.models import Editor, LibraryItem, LibraryItemType
from apps.rsform.models import RSForm from apps.rsform.models import Constituenta, RSForm
from .Argument import Argument from .Argument import Argument
from .Inheritance import Inheritance from .Inheritance import Inheritance
@ -186,10 +186,12 @@ class OperationSchema:
parents[cst.pk] = items[i] parents[cst.pk] = items[i]
children[items[i].pk] = cst children[items[i].pk] = cst
translated_substitutions: list[tuple[Constituenta, Constituenta]] = []
for sub in substitutions: for sub in substitutions:
original = children[sub.original.pk] original = children[sub.original.pk]
replacement = children[sub.substitution.pk] replacement = children[sub.substitution.pk]
receiver.substitute(original, replacement) translated_substitutions.append((original, replacement))
receiver.substitute(translated_substitutions)
# TODO: remove duplicates from diamond # TODO: remove duplicates from diamond

View File

@ -31,36 +31,3 @@ class TestOperation(TestCase):
self.assertEqual(self.operation.comment, '') self.assertEqual(self.operation.comment, '')
self.assertEqual(self.operation.position_x, 0) self.assertEqual(self.operation.position_x, 0)
self.assertEqual(self.operation.position_y, 0) self.assertEqual(self.operation.position_y, 0)
def test_sync_from_result(self):
schema = RSForm.create(alias=self.operation.alias)
self.operation.result = schema.model
self.operation.save()
schema.model.alias = 'KS2'
schema.model.comment = 'Comment'
schema.model.title = 'Title'
schema.save()
self.operation.refresh_from_db()
self.assertEqual(self.operation.result, schema.model)
self.assertEqual(self.operation.alias, schema.model.alias)
self.assertEqual(self.operation.title, schema.model.title)
self.assertEqual(self.operation.comment, schema.model.comment)
def test_sync_from_library_item(self):
schema = LibraryItem.objects.create(alias=self.operation.alias, item_type=LibraryItemType.RSFORM)
self.operation.result = schema
self.operation.save()
schema.alias = 'KS2'
schema.comment = 'Comment'
schema.title = 'Title'
schema.save()
self.operation.refresh_from_db()
self.assertEqual(self.operation.result, schema)
self.assertEqual(self.operation.alias, schema.alias)
self.assertEqual(self.operation.title, schema.title)
self.assertEqual(self.operation.comment, schema.comment)

View File

@ -123,3 +123,33 @@ class TestChangeAttributes(EndpointTester):
self.assertEqual(list(self.ks1.model.editors()), [self.user, self.user2]) self.assertEqual(list(self.ks1.model.editors()), [self.user, self.user2])
self.assertEqual(list(self.ks2.model.editors()), []) self.assertEqual(list(self.ks2.model.editors()), [])
self.assertEqual(set(self.ks3.editors()), set([self.user, self.user3])) self.assertEqual(set(self.ks3.editors()), set([self.user, self.user3]))
@decl_endpoint('/api/library/{item}', method='patch')
def test_sync_from_result(self):
data = {'alias': 'KS111', 'title': 'New Title', 'comment': 'New Comment'}
self.executeOK(data=data, item=self.ks1.model.pk)
self.operation1.refresh_from_db()
self.assertEqual(self.operation1.result, self.ks1.model)
self.assertEqual(self.operation1.alias, data['alias'])
self.assertEqual(self.operation1.title, data['title'])
self.assertEqual(self.operation1.comment, data['comment'])
@decl_endpoint('/api/oss/{item}/update-operation', method='patch')
def test_sync_from_operation(self):
data = {
'target': self.operation3.pk,
'item_data': {
'alias': 'Test3 mod',
'title': 'Test title mod',
'comment': 'Comment mod'
},
'positions': [],
}
response = self.executeOK(data=data, item=self.owned_id)
self.ks3.refresh_from_db()
self.assertEqual(self.ks3.alias, data['item_data']['alias'])
self.assertEqual(self.ks3.title, data['item_data']['title'])
self.assertEqual(self.ks3.comment, data['item_data']['comment'])

View File

@ -79,14 +79,14 @@ class RSForm:
def remove(self, target: Constituenta) -> None: def remove(self, target: Constituenta) -> None:
if self.is_loaded: if self.is_loaded:
self.constituents.remove(target) self.constituents.remove(self.by_id[target.pk])
del self.by_id[target.pk] del self.by_id[target.pk]
del self.by_alias[target.alias] del self.by_alias[target.alias]
def remove_multi(self, target: Iterable[Constituenta]) -> None: def remove_multi(self, target: Iterable[Constituenta]) -> None:
if self.is_loaded: if self.is_loaded:
for cst in target: for cst in target:
self.constituents.remove(cst) self.constituents.remove(self.by_id[cst.pk])
del self.by_id[cst.pk] del self.by_id[cst.pk]
del self.by_alias[cst.alias] del self.by_alias[cst.alias]
@ -306,18 +306,20 @@ class RSForm:
self.resolve_all_text() self.resolve_all_text()
self.save() self.save()
def substitute( def substitute(self, substitutions: list[tuple[Constituenta, Constituenta]]) -> None:
self,
original: Constituenta,
substitution: Constituenta
) -> None:
''' Execute constituenta substitution. ''' ''' Execute constituenta substitution. '''
mapping = {}
deleted: list[Constituenta] = []
replacements: list[Constituenta] = []
for original, substitution in substitutions:
assert original.pk != substitution.pk assert original.pk != substitution.pk
mapping = {original.alias: substitution.alias} mapping[original.alias] = substitution.alias
deleted.append(original)
replacements.append(substitution)
self.cache.remove_multi(deleted)
Constituenta.objects.filter(pk__in=[cst.pk for cst in deleted]).delete()
self.apply_mapping(mapping) self.apply_mapping(mapping)
self.cache.remove(self.cache.by_id[original.pk]) self.on_term_change([substitution.pk for substitution in replacements])
original.delete()
self.on_term_change([substitution.pk])
def restore_order(self) -> None: def restore_order(self) -> None:
''' Restore order based on types and term graph. ''' ''' Restore order based on types and term graph. '''

View File

@ -207,7 +207,7 @@ class TestRSForm(DBTester):
definition_formal=x1.alias definition_formal=x1.alias
) )
self.schema.substitute(x1, x2) self.schema.substitute([(x1, x2)])
x2.refresh_from_db() x2.refresh_from_db()
d1.refresh_from_db() d1.refresh_from_db()
self.assertEqual(self.schema.constituents().count(), 2) self.assertEqual(self.schema.constituents().count(), 2)

View File

@ -100,7 +100,7 @@ class TestRSFormViewset(EndpointTester):
self.assertEqual(response.data['items'][1]['id'], x2.pk) self.assertEqual(response.data['items'][1]['id'], x2.pk)
self.assertEqual(response.data['items'][1]['term_raw'], x2.term_raw) self.assertEqual(response.data['items'][1]['term_raw'], x2.term_raw)
self.assertEqual(response.data['items'][1]['term_resolved'], x2.term_resolved) self.assertEqual(response.data['items'][1]['term_resolved'], x2.term_resolved)
self.assertEqual(response.data['subscribers'], [self.user.pk]) self.assertEqual(response.data['subscribers'], [])
self.assertEqual(response.data['editors'], []) self.assertEqual(response.data['editors'], [])
self.assertEqual(response.data['inheritance'], []) self.assertEqual(response.data['inheritance'], [])
self.assertEqual(response.data['oss'], []) self.assertEqual(response.data['oss'], [])

View File

@ -226,10 +226,12 @@ class RSFormViewSet(viewsets.GenericViewSet, generics.ListAPIView, generics.Retr
serializer.is_valid(raise_exception=True) serializer.is_valid(raise_exception=True)
with transaction.atomic(): with transaction.atomic():
substitutions: list[tuple[m.Constituenta, m.Constituenta]] = []
for substitution in serializer.validated_data['substitutions']: for substitution in serializer.validated_data['substitutions']:
original = cast(m.Constituenta, substitution['original']) original = cast(m.Constituenta, substitution['original'])
replacement = cast(m.Constituenta, substitution['substitution']) replacement = cast(m.Constituenta, substitution['substitution'])
m.RSForm(schema).substitute(original, replacement) substitutions.append((original, replacement))
m.RSForm(schema).substitute(substitutions)
schema.refresh_from_db() schema.refresh_from_db()
return Response( return Response(
@ -574,6 +576,7 @@ def inline_synthesis(request: Request) -> HttpResponse:
with transaction.atomic(): with transaction.atomic():
new_items = receiver.insert_copy(items) new_items = receiver.insert_copy(items)
substitutions: list[tuple[m.Constituenta, m.Constituenta]] = []
for substitution in serializer.validated_data['substitutions']: for substitution in serializer.validated_data['substitutions']:
original = cast(m.Constituenta, substitution['original']) original = cast(m.Constituenta, substitution['original'])
replacement = cast(m.Constituenta, substitution['substitution']) replacement = cast(m.Constituenta, substitution['substitution'])
@ -583,7 +586,8 @@ def inline_synthesis(request: Request) -> HttpResponse:
else: else:
index = next(i for (i, cst) in enumerate(items) if cst.pk == replacement.pk) index = next(i for (i, cst) in enumerate(items) if cst.pk == replacement.pk)
replacement = new_items[index] replacement = new_items[index]
receiver.substitute(original, replacement) substitutions.append((original, replacement))
receiver.substitute(substitutions)
receiver.restore_order() receiver.restore_order()
return Response( return Response(

View File

@ -134,7 +134,6 @@ class TestUserUserProfileAPIView(EndpointTester):
def test_password_reset_request(self): def test_password_reset_request(self):
self.executeBadData({'email': 'invalid@mail.ru'}) self.executeBadData({'email': 'invalid@mail.ru'})
self.executeOK({'email': self.user.email}) self.executeOK({'email': self.user.email})
# TODO: check if mail server actually sent email and if reset procedure works
class TestSignupAPIView(EndpointTester): class TestSignupAPIView(EndpointTester):

View File

@ -1,7 +1,5 @@
import { FolderTree } from './FolderTree'; import { FolderTree } from './FolderTree';
// TODO: test FolderNode and FolderTree exhaustively
describe('Testing Tree construction', () => { describe('Testing Tree construction', () => {
test('empty Tree should be empty', () => { test('empty Tree should be empty', () => {
const tree = new FolderTree(); const tree = new FolderTree();