Implement backend for permission management

This commit is contained in:
IRBorisov 2024-05-26 00:46:58 +03:00
parent 4357fbf83f
commit 467cd3dcc9
18 changed files with 392 additions and 154 deletions

View File

@ -1,6 +1,7 @@
''' Models: Editor. ''' ''' Models: Editor. '''
from typing import TYPE_CHECKING from typing import TYPE_CHECKING
from django.db import transaction
from django.db.models import CASCADE, DateTimeField, ForeignKey, Model from django.db.models import CASCADE, DateTimeField, ForeignKey, Model
from apps.users.models import User from apps.users.models import User
@ -37,18 +38,34 @@ class Editor(Model):
return f'{self.item}: {self.editor}' return f'{self.item}: {self.editor}'
@staticmethod @staticmethod
def add(user: User, item: 'LibraryItem') -> bool: def add(item: 'LibraryItem', user: User) -> bool:
''' Add Editor for item. ''' ''' Add Editor for item. '''
if Editor.objects.filter(editor=user, item=item).exists(): if Editor.objects.filter(item=item, editor=user).exists():
return False return False
Editor.objects.create(editor=user, item=item) Editor.objects.create(item=item, editor=user)
return True return True
@staticmethod @staticmethod
def remove(user: User, item: 'LibraryItem') -> bool: def remove(item: 'LibraryItem', user: User) -> bool:
''' Remove Editor. ''' ''' Remove Editor. '''
editor = Editor.objects.filter(editor=user, item=item) editor = Editor.objects.filter(item=item, editor=user)
if not editor.exists(): if not editor.exists():
return False return False
editor.delete() editor.delete()
return True return True
@staticmethod
@transaction.atomic
def set(item: 'LibraryItem', users: list[User]):
''' Set editors for item. '''
processed: list[User] = []
for editor_item in Editor.objects.filter(item=item):
if not editor_item.editor in users:
editor_item.delete()
else:
processed.append(editor_item.editor)
for user in users:
if not user in processed:
processed.append(user)
Editor.objects.create(item=item, editor=user)

View File

@ -87,7 +87,7 @@ class LibraryItem(Model):
def editors(self) -> list[Editor]: def editors(self) -> list[Editor]:
''' Get all Editors of this item. ''' ''' Get all Editors of this item. '''
return [item.editor for item in Editor.objects.filter(item=self.pk).order_by('-time_create')] return [item.editor for item in Editor.objects.filter(item=self.pk)]
@transaction.atomic @transaction.atomic
def save(self, *args, **kwargs): def save(self, *args, **kwargs):

View File

@ -22,6 +22,8 @@ from .data_access import (
LibraryItemSerializer, LibraryItemSerializer,
RSFormParseSerializer, RSFormParseSerializer,
RSFormSerializer, RSFormSerializer,
UsersListSerializer,
UserTargetSerializer,
VersionCreateSerializer, VersionCreateSerializer,
VersionSerializer VersionSerializer
) )

View File

@ -272,6 +272,16 @@ class CstTargetSerializer(serializers.Serializer):
return attrs return attrs
class UserTargetSerializer(serializers.Serializer):
''' Serializer: Target single User. '''
user = PKField(many=False, queryset=User.objects.all())
class UsersListSerializer(serializers.Serializer):
''' Serializer: List of Users. '''
users = PKField(many=True, queryset=User.objects.all())
class CstRenameSerializer(serializers.Serializer): class CstRenameSerializer(serializers.Serializer):
''' Serializer: Constituenta renaming. ''' ''' Serializer: Constituenta renaming. '''
target = PKField(many=False, queryset=Constituenta.objects.all()) target = PKField(many=False, queryset=Constituenta.objects.all())

View File

@ -23,7 +23,7 @@ class TestConstituenta(TestCase):
def test_url(self): def test_url(self):
testStr = 'X1' testStr = 'X1'
cst = Constituenta.objects.create(alias=testStr, schema=self.schema1, order=1, convention='Test') cst = Constituenta.objects.create(alias=testStr, schema=self.schema1, order=1, convention='Test')
self.assertEqual(cst.get_absolute_url(), f'/api/constituents/{cst.id}') self.assertEqual(cst.get_absolute_url(), f'/api/constituents/{cst.pk}')
def test_order_not_null(self): def test_order_not_null(self):

View File

@ -33,14 +33,14 @@ class TestEditor(TestCase):
def test_add_editor(self): def test_add_editor(self):
self.assertTrue(Editor.add(self.user1, self.item)) self.assertTrue(Editor.add(self.item, self.user1))
self.assertEqual(len(self.item.editors()), 1) self.assertEqual(len(self.item.editors()), 1)
self.assertTrue(self.user1 in self.item.editors()) self.assertTrue(self.user1 in self.item.editors())
self.assertFalse(Editor.add(self.user1, self.item)) self.assertFalse(Editor.add(self.item, self.user1))
self.assertEqual(len(self.item.editors()), 1) self.assertEqual(len(self.item.editors()), 1)
self.assertTrue(Editor.add(self.user2, self.item)) self.assertTrue(Editor.add(self.item, self.user2))
self.assertEqual(len(self.item.editors()), 2) self.assertEqual(len(self.item.editors()), 2)
self.assertTrue(self.user1 in self.item.editors()) self.assertTrue(self.user1 in self.item.editors())
self.assertTrue(self.user2 in self.item.editors()) self.assertTrue(self.user2 in self.item.editors())
@ -50,13 +50,27 @@ class TestEditor(TestCase):
def test_remove_editor(self): def test_remove_editor(self):
self.assertFalse(Editor.remove(self.user1, self.item)) self.assertFalse(Editor.remove(self.item, self.user1))
Editor.add(self.user1, self.item) Editor.add(self.item, self.user1)
Editor.add(self.user2, self.item) Editor.add(self.item, self.user2)
self.assertEqual(len(self.item.editors()), 2) self.assertEqual(len(self.item.editors()), 2)
self.assertTrue(Editor.remove(self.user1, self.item)) self.assertTrue(Editor.remove(self.item, self.user1))
self.assertEqual(len(self.item.editors()), 1) self.assertEqual(len(self.item.editors()), 1)
self.assertTrue(self.user2 in self.item.editors()) self.assertTrue(self.user2 in self.item.editors())
self.assertFalse(Editor.remove(self.user1, self.item)) self.assertFalse(Editor.remove(self.item, self.user1))
def test_set_editors(self):
Editor.set(self.item, [self.user1])
self.assertEqual(self.item.editors(), [self.user1])
Editor.set(self.item, [self.user1, self.user1])
self.assertEqual(self.item.editors(), [self.user1])
Editor.set(self.item, [])
self.assertEqual(self.item.editors(), [])
Editor.set(self.item, [self.user1, self.user2])
self.assertEqual(set(self.item.editors()), set([self.user1, self.user2]))

View File

@ -31,7 +31,7 @@ class TestLibraryItem(TestCase):
owner=self.user1, owner=self.user1,
alias='КС1' alias='КС1'
) )
self.assertEqual(item.get_absolute_url(), f'/api/library/{item.id}') self.assertEqual(item.get_absolute_url(), f'/api/library/{item.pk}')
def test_create_default(self): def test_create_default(self):

View File

@ -350,7 +350,7 @@ class TestRSForm(TestCase):
x1.term_resolved = 'слон' x1.term_resolved = 'слон'
x1.save() x1.save()
self.schema.on_term_change([x1.id]) self.schema.on_term_change([x1.pk])
x1.refresh_from_db() x1.refresh_from_db()
x2.refresh_from_db() x2.refresh_from_db()
x3.refresh_from_db() x3.refresh_from_db()

View File

@ -1,6 +1,6 @@
''' Utils: base tester class for endpoints. ''' ''' Utils: base tester class for endpoints. '''
from rest_framework.test import APITestCase, APIRequestFactory, APIClient
from rest_framework import status from rest_framework import status
from rest_framework.test import APIClient, APIRequestFactory, APITestCase
from apps.users.models import User from apps.users.models import User
@ -23,6 +23,7 @@ def decl_endpoint(endpoint: str, method: str):
class EndpointTester(APITestCase): class EndpointTester(APITestCase):
''' Abstract base class for Testing endpoints. ''' ''' Abstract base class for Testing endpoints. '''
def setUp(self): def setUp(self):
self.factory = APIRequestFactory() self.factory = APIRequestFactory()
self.user = User.objects.create(username='UserTest') self.user = User.objects.create(username='UserTest')
@ -129,15 +130,15 @@ def _resolve_url(url: str, **kwargs) -> str:
pos_end = url.find('}', pos_start) pos_end = url.find('}', pos_start)
if pos_end == -1: if pos_end == -1:
break break
name = url[(pos_start + 1) : pos_end] name = url[(pos_start + 1): pos_end]
arg_names.add(name) arg_names.add(name)
if not name in kwargs: if not name in kwargs:
raise KeyError(f'Missing argument: {name} | Mask: {url}') raise KeyError(f'Missing argument: {name} | Mask: {url}')
output += url[pos_input : pos_start] output += url[pos_input: pos_start]
output += str(kwargs[name]) output += str(kwargs[name])
pos_input = pos_end + 1 pos_input = pos_end + 1
if pos_input < len(url): if pos_input < len(url):
output += url[pos_input : len(url)] output += url[pos_input: len(url)]
for (key, _) in kwargs.items(): for (key, _) in kwargs.items():
if key not in arg_names: if key not in arg_names:
raise KeyError(f'Unused argument: {name} | Mask: {url}') raise KeyError(f'Unused argument: {name} | Mask: {url}')

View File

@ -1,13 +1,13 @@
''' Testing views ''' ''' Testing views '''
from cctext import split_grams
from rest_framework import status from rest_framework import status
from cctext import split_grams from .EndpointTester import EndpointTester, decl_endpoint
from .EndpointTester import decl_endpoint, EndpointTester
class TestNaturalLanguageViews(EndpointTester): class TestNaturalLanguageViews(EndpointTester):
''' Test natural language endpoints. ''' ''' Test natural language endpoints. '''
def _assert_tags(self, actual: str, expected: str): def _assert_tags(self, actual: str, expected: str):
self.assertEqual(set(split_grams(actual)), set(split_grams(expected))) self.assertEqual(set(split_grams(actual)), set(split_grams(expected)))

View File

@ -1,13 +1,14 @@
''' Testing API: Constituents. ''' ''' Testing API: Constituents. '''
from rest_framework import status from rest_framework import status
from apps.rsform.models import RSForm, Constituenta, CstType from apps.rsform.models import Constituenta, CstType, RSForm
from .EndpointTester import decl_endpoint, EndpointTester from .EndpointTester import EndpointTester, decl_endpoint
class TestConstituentaAPI(EndpointTester): class TestConstituentaAPI(EndpointTester):
''' Testing Constituenta view. ''' ''' Testing Constituenta view. '''
def setUp(self): def setUp(self):
super().setUp() super().setUp()
self.rsform_owned = RSForm.create(title='Test', alias='T1', owner=self.user) self.rsform_owned = RSForm.create(title='Test', alias='T1', owner=self.user)
@ -20,7 +21,7 @@ class TestConstituentaAPI(EndpointTester):
convention='Test', convention='Test',
term_raw='Test1', term_raw='Test1',
term_resolved='Test1R', term_resolved='Test1R',
term_forms=[{'text':'form1', 'tags':'sing,datv'}]) term_forms=[{'text': 'form1', 'tags': 'sing,datv'}])
self.cst2 = Constituenta.objects.create( self.cst2 = Constituenta.objects.create(
alias='X2', alias='X2',
cst_type=CstType.BASE, cst_type=CstType.BASE,
@ -45,7 +46,7 @@ class TestConstituentaAPI(EndpointTester):
@decl_endpoint('/api/constituents/{item}', method='get') @decl_endpoint('/api/constituents/{item}', method='get')
def test_retrieve(self): def test_retrieve(self):
self.assertNotFound(item=self.invalid_cst) self.assertNotFound(item=self.invalid_cst)
response = self.execute(item=self.cst1.id) response = self.execute(item=self.cst1.pk)
self.assertEqual(response.status_code, status.HTTP_200_OK) self.assertEqual(response.status_code, status.HTTP_200_OK)
self.assertEqual(response.data['alias'], self.cst1.alias) self.assertEqual(response.data['alias'], self.cst1.alias)
self.assertEqual(response.data['convention'], self.cst1.convention) self.assertEqual(response.data['convention'], self.cst1.convention)
@ -54,19 +55,19 @@ class TestConstituentaAPI(EndpointTester):
@decl_endpoint('/api/constituents/{item}', method='patch') @decl_endpoint('/api/constituents/{item}', method='patch')
def test_partial_update(self): def test_partial_update(self):
data = {'convention': 'tt'} data = {'convention': 'tt'}
self.assertForbidden(data, item=self.cst2.id) self.assertForbidden(data, item=self.cst2.pk)
self.logout() self.logout()
self.assertForbidden(data, item=self.cst1.id) self.assertForbidden(data, item=self.cst1.pk)
self.login() self.login()
response = self.execute(data, item=self.cst1.id) response = self.execute(data, item=self.cst1.pk)
self.assertEqual(response.status_code, status.HTTP_200_OK) self.assertEqual(response.status_code, status.HTTP_200_OK)
self.cst1.refresh_from_db() self.cst1.refresh_from_db()
self.assertEqual(response.data['convention'], 'tt') self.assertEqual(response.data['convention'], 'tt')
self.assertEqual(self.cst1.convention, 'tt') self.assertEqual(self.cst1.convention, 'tt')
self.assertOK(data, item=self.cst1.id) self.assertOK(data, item=self.cst1.pk)
@decl_endpoint('/api/constituents/{item}', method='patch') @decl_endpoint('/api/constituents/{item}', method='patch')
@ -75,7 +76,7 @@ class TestConstituentaAPI(EndpointTester):
'term_raw': 'New term', 'term_raw': 'New term',
'definition_raw': 'New def' 'definition_raw': 'New def'
} }
response = self.execute(data, item=self.cst3.id) response = self.execute(data, item=self.cst3.pk)
self.assertEqual(response.status_code, status.HTTP_200_OK) self.assertEqual(response.status_code, status.HTTP_200_OK)
self.cst3.refresh_from_db() self.cst3.refresh_from_db()
self.assertEqual(response.data['term_resolved'], 'New term') self.assertEqual(response.data['term_resolved'], 'New term')
@ -90,7 +91,7 @@ class TestConstituentaAPI(EndpointTester):
'term_raw': '@{X1|nomn,sing}', 'term_raw': '@{X1|nomn,sing}',
'definition_raw': '@{X1|nomn,sing} @{X1|sing,datv}' 'definition_raw': '@{X1|nomn,sing} @{X1|sing,datv}'
} }
response = self.execute(data, item=self.cst3.id) response = self.execute(data, item=self.cst3.pk)
self.assertEqual(response.status_code, status.HTTP_200_OK) self.assertEqual(response.status_code, status.HTTP_200_OK)
self.cst3.refresh_from_db() self.cst3.refresh_from_db()
self.assertEqual(self.cst3.term_resolved, self.cst1.term_resolved) self.assertEqual(self.cst3.term_resolved, self.cst1.term_resolved)
@ -102,7 +103,7 @@ class TestConstituentaAPI(EndpointTester):
@decl_endpoint('/api/constituents/{item}', method='patch') @decl_endpoint('/api/constituents/{item}', method='patch')
def test_readonly_cst_fields(self): def test_readonly_cst_fields(self):
data = {'alias': 'X33', 'order': 10} data = {'alias': 'X33', 'order': 10}
response = self.execute(data, item=self.cst1.id) response = self.execute(data, item=self.cst1.pk)
self.assertEqual(response.status_code, status.HTTP_200_OK) self.assertEqual(response.status_code, status.HTTP_200_OK)
self.assertEqual(response.data['alias'], 'X1') self.assertEqual(response.data['alias'], 'X1')
self.assertEqual(response.data['alias'], self.cst1.alias) self.assertEqual(response.data['alias'], self.cst1.alias)

View File

@ -1,7 +1,14 @@
''' Testing API: Library. ''' ''' Testing API: Library. '''
from rest_framework import status from rest_framework import status
from apps.rsform.models import LibraryItem, LibraryItemType, LibraryTemplate, RSForm, Subscription from apps.rsform.models import (
Editor,
LibraryItem,
LibraryItemType,
LibraryTemplate,
RSForm,
Subscription
)
from apps.users.models import User from apps.users.models import User
from ..testing_utils import response_contains from ..testing_utils import response_contains
@ -13,6 +20,7 @@ class TestLibraryViewset(EndpointTester):
def setUp(self): def setUp(self):
super().setUp() super().setUp()
self.user2 = User.objects.create(username='UserTest2')
self.owned = LibraryItem.objects.create( self.owned = LibraryItem.objects.create(
item_type=LibraryItemType.RSFORM, item_type=LibraryItemType.RSFORM,
title='Test', title='Test',
@ -31,6 +39,7 @@ class TestLibraryViewset(EndpointTester):
alias='T3', alias='T3',
is_common=True is_common=True
) )
self.invalid_user = 1337 + self.user2.pk
self.invalid_item = 1337 + self.common.pk self.invalid_item = 1337 + self.common.pk
@ -40,7 +49,7 @@ class TestLibraryViewset(EndpointTester):
response = self.post(data=data) response = self.post(data=data)
self.assertEqual(response.status_code, status.HTTP_201_CREATED) self.assertEqual(response.status_code, status.HTTP_201_CREATED)
self.assertEqual(response.data['title'], 'Title') self.assertEqual(response.data['title'], 'Title')
self.assertEqual(response.data['owner'], self.user.id) self.assertEqual(response.data['owner'], self.user.pk)
self.logout() self.logout()
data = {'title': 'Title2'} data = {'title': 'Title2'}
@ -49,25 +58,132 @@ class TestLibraryViewset(EndpointTester):
@decl_endpoint('/api/library/{item}', method='patch') @decl_endpoint('/api/library/{item}', method='patch')
def test_update(self): def test_update(self):
data = {'id': self.unowned.id, 'title': 'New title'} data = {'id': self.unowned.pk, 'title': 'New title'}
self.assertNotFound(data, item=self.invalid_item) self.assertNotFound(data, item=self.invalid_item)
self.assertForbidden(data, item=self.unowned.id) self.assertForbidden(data, item=self.unowned.pk)
data = {'id': self.owned.id, 'title': 'New title'} data = {'id': self.owned.pk, 'title': 'New title'}
response = self.execute(data, item=self.owned.id) response = self.execute(data, item=self.owned.pk)
self.assertEqual(response.status_code, status.HTTP_200_OK) self.assertEqual(response.status_code, status.HTTP_200_OK)
self.assertEqual(response.data['title'], 'New title') self.assertEqual(response.data['title'], 'New title')
self.assertEqual(response.data['alias'], self.owned.alias) self.assertEqual(response.data['alias'], self.owned.alias)
@decl_endpoint('/api/library/{item}/set-owner', method='patch')
def test_set_owner(self):
time_update = self.owned.time_update
data = {'user': self.user.pk}
self.assertNotFound(data, item=self.invalid_item)
self.assertForbidden(data, item=self.unowned.pk)
self.assertOK(data, item=self.owned.pk)
self.owned.refresh_from_db()
self.assertEqual(self.owned.owner, self.user)
data = {'user': self.user2.pk}
self.assertOK(data, item=self.owned.pk)
self.owned.refresh_from_db()
self.assertEqual(self.owned.owner, self.user2)
self.assertEqual(self.owned.time_update, time_update)
self.assertForbidden(data, item=self.owned.pk)
self.toggle_admin(True)
data = {'user': self.user.pk}
self.assertOK(data, item=self.owned.pk)
self.owned.refresh_from_db()
self.assertEqual(self.owned.owner, self.user)
@decl_endpoint('/api/library/{item}/editors-add', method='patch')
def test_add_editor(self):
time_update = self.owned.time_update
data = {'user': self.invalid_user}
self.assertBadData(data, item=self.owned.pk)
data = {'user': self.user.pk}
self.assertNotFound(data, item=self.invalid_item)
self.assertForbidden(data, item=self.unowned.pk)
self.assertOK(data, item=self.owned.pk)
self.owned.refresh_from_db()
self.assertEqual(self.owned.time_update, time_update)
self.assertEqual(self.owned.editors(), [self.user])
self.assertOK(data)
self.assertEqual(self.owned.editors(), [self.user])
data = {'user': self.user2.pk}
self.assertOK(data)
self.assertEqual(set(self.owned.editors()), set([self.user, self.user2]))
@decl_endpoint('/api/library/{item}/editors-remove', method='patch')
def test_remove_editor(self):
time_update = self.owned.time_update
data = {'user': self.invalid_user}
self.assertBadData(data, item=self.owned.pk)
data = {'user': self.user.pk}
self.assertNotFound(data, item=self.invalid_item)
self.assertForbidden(data, item=self.unowned.pk)
self.assertOK(data, item=self.owned.pk)
self.owned.refresh_from_db()
self.assertEqual(self.owned.time_update, time_update)
self.assertEqual(self.owned.editors(), [])
Editor.add(item=self.owned, user=self.user)
self.assertOK(data)
self.assertEqual(self.owned.editors(), [])
Editor.add(item=self.owned, user=self.user)
Editor.add(item=self.owned, user=self.user2)
data = {'user': self.user2.pk}
self.assertOK(data)
self.assertEqual(self.owned.editors(), [self.user])
@decl_endpoint('/api/library/{item}/editors-set', method='patch')
def test_set_editors(self):
time_update = self.owned.time_update
data = {'users': [self.invalid_user]}
self.assertBadData(data, item=self.owned.pk)
data = {'users': [self.user.pk]}
self.assertNotFound(data, item=self.invalid_item)
self.assertForbidden(data, item=self.unowned.pk)
self.assertOK(data, item=self.owned.pk)
self.owned.refresh_from_db()
self.assertEqual(self.owned.time_update, time_update)
self.assertEqual(self.owned.editors(), [self.user])
self.assertOK(data)
self.assertEqual(self.owned.editors(), [self.user])
data = {'users': [self.user2.pk]}
self.assertOK(data)
self.assertEqual(self.owned.editors(), [self.user2])
data = {'users': []}
self.assertOK(data)
self.assertEqual(self.owned.editors(), [])
data = {'users': [self.user2.pk, self.user.pk]}
self.assertOK(data)
self.assertEqual(set(self.owned.editors()), set([self.user2, self.user]))
@decl_endpoint('/api/library/{item}', method='delete') @decl_endpoint('/api/library/{item}', method='delete')
def test_destroy(self): def test_destroy(self):
response = self.execute(item=self.owned.id) response = self.execute(item=self.owned.pk)
self.assertTrue(response.status_code in [status.HTTP_202_ACCEPTED, status.HTTP_204_NO_CONTENT]) self.assertTrue(response.status_code in [status.HTTP_202_ACCEPTED, status.HTTP_204_NO_CONTENT])
self.assertForbidden(item=self.unowned.id) self.assertForbidden(item=self.unowned.pk)
self.toggle_admin(True) self.toggle_admin(True)
response = self.execute(item=self.unowned.id) response = self.execute(item=self.unowned.pk)
self.assertTrue(response.status_code in [status.HTTP_202_ACCEPTED, status.HTTP_204_NO_CONTENT]) self.assertTrue(response.status_code in [status.HTTP_202_ACCEPTED, status.HTTP_204_NO_CONTENT])
@ -108,10 +224,9 @@ class TestLibraryViewset(EndpointTester):
self.assertEqual(response.status_code, status.HTTP_200_OK) self.assertEqual(response.status_code, status.HTTP_200_OK)
self.assertFalse(response_contains(response, self.unowned)) self.assertFalse(response_contains(response, self.unowned))
user2 = User.objects.create(username='UserTest2')
Subscription.subscribe(user=self.user, item=self.unowned) Subscription.subscribe(user=self.user, item=self.unowned)
Subscription.subscribe(user=user2, item=self.unowned) Subscription.subscribe(user=self.user2, item=self.unowned)
Subscription.subscribe(user=user2, item=self.owned) Subscription.subscribe(user=self.user2, item=self.owned)
response = self.execute() response = self.execute()
self.assertEqual(response.status_code, status.HTTP_200_OK) self.assertEqual(response.status_code, status.HTTP_200_OK)
@ -122,20 +237,20 @@ class TestLibraryViewset(EndpointTester):
@decl_endpoint('/api/library/{item}/subscribe', method='post') @decl_endpoint('/api/library/{item}/subscribe', method='post')
def test_subscriptions(self): def test_subscriptions(self):
self.assertNotFound(item=self.invalid_item) self.assertNotFound(item=self.invalid_item)
response = self.client.delete(f'/api/library/{self.unowned.id}/unsubscribe') response = self.client.delete(f'/api/library/{self.unowned.pk}/unsubscribe')
self.assertEqual(response.status_code, status.HTTP_204_NO_CONTENT) self.assertEqual(response.status_code, status.HTTP_200_OK)
self.assertFalse(self.user in self.unowned.subscribers()) self.assertFalse(self.user in self.unowned.subscribers())
response = self.execute(item=self.unowned.id) response = self.execute(item=self.unowned.pk)
self.assertEqual(response.status_code, status.HTTP_204_NO_CONTENT) self.assertEqual(response.status_code, status.HTTP_200_OK)
self.assertTrue(self.user in self.unowned.subscribers()) self.assertTrue(self.user in self.unowned.subscribers())
response = self.execute(item=self.unowned.id) response = self.execute(item=self.unowned.pk)
self.assertEqual(response.status_code, status.HTTP_204_NO_CONTENT) self.assertEqual(response.status_code, status.HTTP_200_OK)
self.assertTrue(self.user in self.unowned.subscribers()) self.assertTrue(self.user in self.unowned.subscribers())
response = self.client.delete(f'/api/library/{self.unowned.id}/unsubscribe') response = self.client.delete(f'/api/library/{self.unowned.pk}/unsubscribe')
self.assertEqual(response.status_code, status.HTTP_204_NO_CONTENT) self.assertEqual(response.status_code, status.HTTP_200_OK)
self.assertFalse(self.user in self.unowned.subscribers()) self.assertFalse(self.user in self.unowned.subscribers())
@ -170,10 +285,10 @@ class TestLibraryViewset(EndpointTester):
data = {'title': 'Title1337'} data = {'title': 'Title1337'}
self.assertNotFound(data, item=self.invalid_item) self.assertNotFound(data, item=self.invalid_item)
self.assertCreated(data, item=self.unowned.id) self.assertCreated(data, item=self.unowned.pk)
data = {'title': 'Title1338'} data = {'title': 'Title1338'}
response = self.execute(data, item=self.owned.id) response = self.execute(data, item=self.owned.pk)
self.assertEqual(response.status_code, status.HTTP_201_CREATED) self.assertEqual(response.status_code, status.HTTP_201_CREATED)
self.assertEqual(response.data['title'], data['title']) self.assertEqual(response.data['title'], data['title'])
self.assertEqual(len(response.data['items']), 2) self.assertEqual(len(response.data['items']), 2)
@ -184,13 +299,13 @@ class TestLibraryViewset(EndpointTester):
self.assertEqual(response.data['items'][1]['term_resolved'], d2.term_resolved) self.assertEqual(response.data['items'][1]['term_resolved'], d2.term_resolved)
data = {'title': 'Title1340', 'items': []} data = {'title': 'Title1340', 'items': []}
response = self.execute(data, item=self.owned.id) response = self.execute(data, item=self.owned.pk)
self.assertEqual(response.status_code, status.HTTP_201_CREATED) self.assertEqual(response.status_code, status.HTTP_201_CREATED)
self.assertEqual(response.data['title'], data['title']) self.assertEqual(response.data['title'], data['title'])
self.assertEqual(len(response.data['items']), 0) self.assertEqual(len(response.data['items']), 0)
data = {'title': 'Title1341', 'items': [x12.pk]} data = {'title': 'Title1341', 'items': [x12.pk]}
response = self.execute(data, item=self.owned.id) response = self.execute(data, item=self.owned.pk)
self.assertEqual(response.status_code, status.HTTP_201_CREATED) self.assertEqual(response.status_code, status.HTTP_201_CREATED)
self.assertEqual(response.data['title'], data['title']) self.assertEqual(response.data['title'], data['title'])
self.assertEqual(len(response.data['items']), 1) self.assertEqual(len(response.data['items']), 1)

View File

@ -1,12 +1,9 @@
''' Testing API: Operations. ''' ''' Testing API: Operations. '''
from rest_framework import status from rest_framework import status
from .EndpointTester import decl_endpoint, EndpointTester
from apps.rsform.models import ( from apps.rsform.models import Constituenta, CstType, RSForm
RSForm,
Constituenta, from .EndpointTester import EndpointTester, decl_endpoint
CstType
)
class TestInlineSynthesis(EndpointTester): class TestInlineSynthesis(EndpointTester):
@ -24,8 +21,8 @@ class TestInlineSynthesis(EndpointTester):
def test_inline_synthesis_inputs(self): def test_inline_synthesis_inputs(self):
invalid_id = 1338 invalid_id = 1338
data = { data = {
'receiver': self.unowned.item.id, 'receiver': self.unowned.item.pk,
'source': self.schema1.item.id, 'source': self.schema1.item.pk,
'items': [], 'items': [],
'substitutions': [] 'substitutions': []
} }
@ -34,11 +31,11 @@ class TestInlineSynthesis(EndpointTester):
data['receiver'] = invalid_id data['receiver'] = invalid_id
self.assertBadData(data) self.assertBadData(data)
data['receiver'] = self.schema1.item.id data['receiver'] = self.schema1.item.pk
data['source'] = invalid_id data['source'] = invalid_id
self.assertBadData(data) self.assertBadData(data)
data['source'] = self.schema1.item.id data['source'] = self.schema1.item.pk
self.assertOK(data) self.assertOK(data)
data['items'] = [invalid_id] data['items'] = [invalid_id]
@ -57,8 +54,8 @@ class TestInlineSynthesis(EndpointTester):
ks2_a1 = self.schema2.insert_new('A1', definition_formal='1=1') # -> not included in items ks2_a1 = self.schema2.insert_new('A1', definition_formal='1=1') # -> not included in items
data = { data = {
'receiver': self.schema1.item.id, 'receiver': self.schema1.item.pk,
'source': self.schema2.item.id, 'source': self.schema2.item.pk,
'items': [ks2_x1.pk, ks2_x2.pk, ks2_s1.pk, ks2_d1.pk], 'items': [ks2_x1.pk, ks2_x2.pk, ks2_s1.pk, ks2_d1.pk],
'substitutions': [ 'substitutions': [
{ {

View File

@ -1,31 +1,26 @@
''' Testing API: RSForms. ''' ''' Testing API: RSForms. '''
import os
import io import io
import os
from zipfile import ZipFile from zipfile import ZipFile
from rest_framework import status
from apps.rsform.models import (
RSForm,
Constituenta,
CstType,
LibraryItem,
LibraryItemType
)
from cctext import ReferenceType from cctext import ReferenceType
from ..testing_utils import response_contains from rest_framework import status
from .EndpointTester import decl_endpoint, EndpointTester from apps.rsform.models import Constituenta, CstType, LibraryItem, LibraryItemType, RSForm
from ..testing_utils import response_contains
from .EndpointTester import EndpointTester, decl_endpoint
class TestRSFormViewset(EndpointTester): class TestRSFormViewset(EndpointTester):
''' Testing RSForm view. ''' ''' Testing RSForm view. '''
def setUp(self): def setUp(self):
super().setUp() super().setUp()
self.schema = RSForm.create(title='Test', alias='T1', owner=self.user) self.schema = RSForm.create(title='Test', alias='T1', owner=self.user)
self.schema_id = self.schema.item.id self.schema_id = self.schema.item.pk
self.unowned = RSForm.create(title='Test2', alias='T2') self.unowned = RSForm.create(title='Test2', alias='T2')
self.unowned_id = self.unowned.item.id self.unowned_id = self.unowned.item.pk
@decl_endpoint('/api/rsforms/create-detailed', method='post') @decl_endpoint('/api/rsforms/create-detailed', method='post')
@ -75,7 +70,7 @@ class TestRSFormViewset(EndpointTester):
def test_contents(self): def test_contents(self):
schema = RSForm.create(title='Title1') schema = RSForm.create(title='Title1')
schema.insert_new('X1') schema.insert_new('X1')
self.assertOK(item=schema.item.id) self.assertOK(item=schema.item.pk)
@decl_endpoint('/api/rsforms/{item}/details', method='get') @decl_endpoint('/api/rsforms/{item}/details', method='get')
@ -84,23 +79,23 @@ class TestRSFormViewset(EndpointTester):
x1 = schema.insert_new( x1 = schema.insert_new(
alias='X1', alias='X1',
term_raw='человек', term_raw='человек',
term_resolved = 'человек' term_resolved='человек'
) )
x2 = schema.insert_new( x2 = schema.insert_new(
alias='X2', alias='X2',
term_raw='@{X1|plur}', term_raw='@{X1|plur}',
term_resolved = 'люди' term_resolved='люди'
) )
response = self.execute(item=schema.item.id) response = self.execute(item=schema.item.pk)
self.assertEqual(response.status_code, status.HTTP_200_OK) self.assertEqual(response.status_code, status.HTTP_200_OK)
self.assertEqual(response.data['title'], 'Test') self.assertEqual(response.data['title'], 'Test')
self.assertEqual(len(response.data['items']), 2) self.assertEqual(len(response.data['items']), 2)
self.assertEqual(response.data['items'][0]['id'], x1.id) self.assertEqual(response.data['items'][0]['id'], x1.pk)
self.assertEqual(response.data['items'][0]['parse']['status'], 'verified') self.assertEqual(response.data['items'][0]['parse']['status'], 'verified')
self.assertEqual(response.data['items'][0]['term_raw'], x1.term_raw) self.assertEqual(response.data['items'][0]['term_raw'], x1.term_raw)
self.assertEqual(response.data['items'][0]['term_resolved'], x1.term_resolved) self.assertEqual(response.data['items'][0]['term_resolved'], x1.term_resolved)
self.assertEqual(response.data['items'][1]['id'], x2.id) self.assertEqual(response.data['items'][1]['id'], x2.pk)
self.assertEqual(response.data['items'][1]['term_raw'], x2.term_raw) self.assertEqual(response.data['items'][1]['term_raw'], x2.term_raw)
self.assertEqual(response.data['items'][1]['term_resolved'], x2.term_resolved) self.assertEqual(response.data['items'][1]['term_resolved'], x2.term_resolved)
self.assertEqual(response.data['subscribers'], [self.user.pk]) self.assertEqual(response.data['subscribers'], [self.user.pk])
@ -112,7 +107,7 @@ class TestRSFormViewset(EndpointTester):
schema = RSForm.create(title='Test') schema = RSForm.create(title='Test')
schema.insert_new('X1') schema.insert_new('X1')
data = {'expression': 'X1=X1'} data = {'expression': 'X1=X1'}
response = self.execute(data, item=schema.item.id) response = self.execute(data, item=schema.item.pk)
self.assertEqual(response.status_code, status.HTTP_200_OK) self.assertEqual(response.status_code, status.HTTP_200_OK)
self.assertEqual(response.data['parseResult'], True) self.assertEqual(response.data['parseResult'], True)
self.assertEqual(response.data['syntax'], 'math') self.assertEqual(response.data['syntax'], 'math')
@ -132,7 +127,7 @@ class TestRSFormViewset(EndpointTester):
) )
data = {'text': '@{1|редкий} @{X1|plur,datv}'} data = {'text': '@{1|редкий} @{X1|plur,datv}'}
response = self.execute(data, item=schema.item.id) response = self.execute(data, item=schema.item.pk)
self.assertEqual(response.status_code, status.HTTP_200_OK) self.assertEqual(response.status_code, status.HTTP_200_OK)
self.assertEqual(response.data['input'], '@{1|редкий} @{X1|plur,datv}') self.assertEqual(response.data['input'], '@{1|редкий} @{X1|plur,datv}')
self.assertEqual(response.data['output'], 'редким синим слонам') self.assertEqual(response.data['output'], 'редким синим слонам')
@ -170,7 +165,7 @@ class TestRSFormViewset(EndpointTester):
def test_export_trs(self): def test_export_trs(self):
schema = RSForm.create(title='Test') schema = RSForm.create(title='Test')
schema.insert_new('X1') schema.insert_new('X1')
response = self.execute(item=schema.item.id) response = self.execute(item=schema.item.pk)
self.assertEqual(response.status_code, status.HTTP_200_OK) self.assertEqual(response.status_code, status.HTTP_200_OK)
self.assertEqual(response.headers['Content-Disposition'], 'attachment; filename=Schema.trs') self.assertEqual(response.headers['Content-Disposition'], 'attachment; filename=Schema.trs')
with io.BytesIO(response.content) as stream: with io.BytesIO(response.content) as stream:
@ -196,9 +191,9 @@ class TestRSFormViewset(EndpointTester):
data = { data = {
'alias': 'X4', 'alias': 'X4',
'cst_type': CstType.BASE, 'cst_type': CstType.BASE,
'insert_after': x2.id, 'insert_after': x2.pk,
'term_raw': 'test', 'term_raw': 'test',
'term_forms': [{'text':'form1', 'tags':'sing,datv'}] 'term_forms': [{'text': 'form1', 'tags': 'sing,datv'}]
} }
response = self.execute(data, item=self.schema_id) response = self.execute(data, item=self.schema_id)
self.assertEqual(response.status_code, status.HTTP_201_CREATED) self.assertEqual(response.status_code, status.HTTP_201_CREATED)
@ -216,7 +211,7 @@ class TestRSFormViewset(EndpointTester):
convention='Test', convention='Test',
term_raw='Test1', term_raw='Test1',
term_resolved='Test1', term_resolved='Test1',
term_forms=[{'text':'form1', 'tags':'sing,datv'}] term_forms=[{'text': 'form1', 'tags': 'sing,datv'}]
) )
x2_2 = self.unowned.insert_new('X2') x2_2 = self.unowned.insert_new('X2')
x3 = self.schema.insert_new( x3 = self.schema.insert_new(
@ -239,8 +234,8 @@ class TestRSFormViewset(EndpointTester):
d1 = self.schema.insert_new( d1 = self.schema.insert_new(
alias='D1', alias='D1',
term_raw = '@{X1|plur}', term_raw='@{X1|plur}',
definition_formal = 'X1' definition_formal='X1'
) )
self.assertEqual(x1.order, 1) self.assertEqual(x1.order, 1)
self.assertEqual(x1.alias, 'X1') self.assertEqual(x1.alias, 'X1')
@ -266,7 +261,7 @@ class TestRSFormViewset(EndpointTester):
alias='X1', alias='X1',
term_raw='Test1', term_raw='Test1',
term_resolved='Test1', term_resolved='Test1',
term_forms=[{'text':'form1', 'tags':'sing,datv'}] term_forms=[{'text': 'form1', 'tags': 'sing,datv'}]
) )
x2 = self.schema.insert_new( x2 = self.schema.insert_new(
alias='X2', alias='X2',
@ -379,7 +374,7 @@ class TestRSFormViewset(EndpointTester):
x1 = self.schema.insert_new('X1') x1 = self.schema.insert_new('X1')
x2 = self.schema.insert_new('X2') x2 = self.schema.insert_new('X2')
data = {'items': [x1.id]} data = {'items': [x1.pk]}
response = self.execute(data) response = self.execute(data)
x2.refresh_from_db() x2.refresh_from_db()
self.schema.item.refresh_from_db() self.schema.item.refresh_from_db()
@ -390,7 +385,7 @@ class TestRSFormViewset(EndpointTester):
self.assertEqual(x2.order, 1) self.assertEqual(x2.order, 1)
x3 = self.unowned.insert_new('X1') x3 = self.unowned.insert_new('X1')
data = {'items': [x3.id]} data = {'items': [x3.pk]}
self.assertBadData(data, item=self.schema_id) self.assertBadData(data, item=self.schema_id)
@ -404,7 +399,7 @@ class TestRSFormViewset(EndpointTester):
x1 = self.schema.insert_new('X1') x1 = self.schema.insert_new('X1')
x2 = self.schema.insert_new('X2') x2 = self.schema.insert_new('X2')
data = {'items': [x2.id], 'move_to': 1} data = {'items': [x2.pk], 'move_to': 1}
response = self.execute(data) response = self.execute(data)
x1.refresh_from_db() x1.refresh_from_db()
x2.refresh_from_db() x2.refresh_from_db()
@ -414,7 +409,7 @@ class TestRSFormViewset(EndpointTester):
self.assertEqual(x2.order, 1) self.assertEqual(x2.order, 1)
x3 = self.unowned.insert_new('X1') x3 = self.unowned.insert_new('X1')
data = {'items': [x3.id], 'move_to': 1} data = {'items': [x3.pk], 'move_to': 1}
self.assertBadData(data) self.assertBadData(data)
@ -460,7 +455,7 @@ class TestRSFormViewset(EndpointTester):
self.assertEqual(self.schema.item.title, 'Test11') self.assertEqual(self.schema.item.title, 'Test11')
self.assertEqual(len(response.data['items']), 25) self.assertEqual(len(response.data['items']), 25)
self.assertEqual(self.schema.constituents().count(), 25) self.assertEqual(self.schema.constituents().count(), 25)
self.assertFalse(Constituenta.objects.filter(pk=x1.id).exists()) self.assertFalse(Constituenta.objects.filter(pk=x1.pk).exists())
@decl_endpoint('/api/rsforms/{item}/cst-produce-structure', method='patch') @decl_endpoint('/api/rsforms/{item}/cst-produce-structure', method='patch')
@ -490,11 +485,11 @@ class TestRSFormViewset(EndpointTester):
invalid_id = f1.pk + 1337 invalid_id = f1.pk + 1337
self.assertBadData({'target': invalid_id}) self.assertBadData({'target': invalid_id})
self.assertBadData({'target': x1.id}) self.assertBadData({'target': x1.pk})
self.assertBadData({'target': s2.id}) self.assertBadData({'target': s2.pk})
# Testing simple structure # Testing simple structure
response = self.execute({'target': s1.id}) response = self.execute({'target': s1.pk})
self.assertEqual(response.status_code, status.HTTP_200_OK) self.assertEqual(response.status_code, status.HTTP_200_OK)
result = response.data['schema'] result = response.data['schema']
items = [item for item in result['items'] if item['id'] in response.data['cst_list']] items = [item for item in result['items'] if item['id'] in response.data['cst_list']]
@ -506,7 +501,7 @@ class TestRSFormViewset(EndpointTester):
# Testing complex structure # Testing complex structure
s3.refresh_from_db() s3.refresh_from_db()
response = self.execute({'target': s3.id}) response = self.execute({'target': s3.pk})
self.assertEqual(response.status_code, status.HTTP_200_OK) self.assertEqual(response.status_code, status.HTTP_200_OK)
result = response.data['schema'] result = response.data['schema']
items = [item for item in result['items'] if item['id'] in response.data['cst_list']] items = [item for item in result['items'] if item['id'] in response.data['cst_list']]
@ -516,7 +511,7 @@ class TestRSFormViewset(EndpointTester):
# Testing function # Testing function
f1.refresh_from_db() f1.refresh_from_db()
response = self.execute({'target': f1.id}) response = self.execute({'target': f1.pk})
self.assertEqual(response.status_code, status.HTTP_200_OK) self.assertEqual(response.status_code, status.HTTP_200_OK)
result = response.data['schema'] result = response.data['schema']
items = [item for item in result['items'] if item['id'] in response.data['cst_list']] items = [item for item in result['items'] if item['id'] in response.data['cst_list']]

View File

@ -1,8 +1,8 @@
''' Testing views ''' ''' Testing views '''
from rest_framework.exceptions import ErrorDetail
from rest_framework import status from rest_framework import status
from rest_framework.exceptions import ErrorDetail
from .EndpointTester import decl_endpoint, EndpointTester from .EndpointTester import EndpointTester, decl_endpoint
class TestRSLanguageViews(EndpointTester): class TestRSLanguageViews(EndpointTester):
@ -42,4 +42,3 @@ class TestRSLanguageViews(EndpointTester):
self.assertEqual(response.data['parseResult'], True) self.assertEqual(response.data['parseResult'], True)
self.assertEqual(response.data['syntax'], 'math') self.assertEqual(response.data['syntax'], 'math')
self.assertEqual(response.data['astText'], '[=[1][1]]') self.assertEqual(response.data['astText'], '[=[1][1]]')

View File

@ -1,17 +1,19 @@
''' Testing API: Versions. ''' ''' Testing API: Versions. '''
import io import io
from typing import cast
from sys import version from sys import version
from typing import cast
from zipfile import ZipFile from zipfile import ZipFile
from rest_framework import status from rest_framework import status
from apps.rsform.models import RSForm, Constituenta from apps.rsform.models import Constituenta, RSForm
from .EndpointTester import decl_endpoint, EndpointTester from .EndpointTester import EndpointTester, decl_endpoint
class TestVersionViews(EndpointTester): class TestVersionViews(EndpointTester):
''' Testing versioning endpoints. ''' ''' Testing versioning endpoints. '''
def setUp(self): def setUp(self):
super().setUp() super().setUp()
self.owned = RSForm.create(title='Test', alias='T1', owner=self.user).item self.owned = RSForm.create(title='Test', alias='T1', owner=self.user).item
@ -30,10 +32,10 @@ class TestVersionViews(EndpointTester):
data = {'version': '1.0.0', 'description': 'test'} data = {'version': '1.0.0', 'description': 'test'}
self.assertNotFound(data, schema=invalid_id) self.assertNotFound(data, schema=invalid_id)
self.assertForbidden(data, schema=self.unowned.id) self.assertForbidden(data, schema=self.unowned.pk)
self.assertBadData(invalid_data, schema=self.owned.id) self.assertBadData(invalid_data, schema=self.owned.pk)
response = self.execute(data, schema=self.owned.id) response = self.execute(data, schema=self.owned.pk)
self.assertEqual(response.status_code, status.HTTP_201_CREATED) self.assertEqual(response.status_code, status.HTTP_201_CREATED)
self.assertTrue('version' in response.data) self.assertTrue('version' in response.data)
self.assertTrue('schema' in response.data) self.assertTrue('schema' in response.data)
@ -46,16 +48,16 @@ class TestVersionViews(EndpointTester):
invalid_id = version_id + 1337 invalid_id = version_id + 1337
self.assertNotFound(schema=invalid_id, version=invalid_id) self.assertNotFound(schema=invalid_id, version=invalid_id)
self.assertNotFound(schema=self.owned.id, version=invalid_id) self.assertNotFound(schema=self.owned.pk, version=invalid_id)
self.assertNotFound(schema=invalid_id, version=version_id) self.assertNotFound(schema=invalid_id, version=version_id)
self.assertNotFound(schema=self.unowned.id, version=version_id) self.assertNotFound(schema=self.unowned.pk, version=version_id)
self.owned.alias = 'NewName' self.owned.alias = 'NewName'
self.owned.save() self.owned.save()
self.x1.alias = 'X33' self.x1.alias = 'X33'
self.x1.save() self.x1.save()
response = self.execute(schema=self.owned.id, version=version_id) response = self.execute(schema=self.owned.pk, version=version_id)
self.assertEqual(response.status_code, status.HTTP_200_OK) self.assertEqual(response.status_code, status.HTTP_200_OK)
self.assertNotEqual(response.data['alias'], self.owned.alias) self.assertNotEqual(response.data['alias'], self.owned.alias)
self.assertNotEqual(response.data['items'][0]['alias'], self.x1.alias) self.assertNotEqual(response.data['items'][0]['alias'], self.x1.alias)
@ -76,7 +78,7 @@ class TestVersionViews(EndpointTester):
self.assertEqual(response.status_code, status.HTTP_200_OK) self.assertEqual(response.status_code, status.HTTP_200_OK)
self.assertEqual(response.data['version'], data['version']) self.assertEqual(response.data['version'], data['version'])
self.assertEqual(response.data['description'], data['description']) self.assertEqual(response.data['description'], data['description'])
self.assertEqual(response.data['item'], self.owned.id) self.assertEqual(response.data['item'], self.owned.pk)
data = {'version': '1.2.0', 'description': 'test1'} data = {'version': '1.2.0', 'description': 'test1'}
self.method = 'patch' self.method = 'patch'
@ -111,7 +113,7 @@ class TestVersionViews(EndpointTester):
a1.definition_formal = 'X1=X2' a1.definition_formal = 'X1=X2'
a1.save() a1.save()
response = self.get(schema=self.owned.id, version=version_id) response = self.get(schema=self.owned.pk, version=version_id)
self.assertEqual(response.status_code, status.HTTP_200_OK) self.assertEqual(response.status_code, status.HTTP_200_OK)
loaded_a1 = response.data['items'][1] loaded_a1 = response.data['items'][1]
self.assertEqual(loaded_a1['definition_formal'], 'X1=X1') self.assertEqual(loaded_a1['definition_formal'], 'X1=X1')
@ -171,7 +173,7 @@ class TestVersionViews(EndpointTester):
def _create_version(self, data) -> int: def _create_version(self, data) -> int:
response = self.client.post( response = self.client.post(
f'/api/rsforms/{self.owned.id}/versions/create', f'/api/rsforms/{self.owned.pk}/versions/create',
data=data, format='json' data=data, format='json'
) )
self.assertEqual(response.status_code, status.HTTP_201_CREATED) self.assertEqual(response.status_code, status.HTTP_201_CREATED)

View File

@ -79,10 +79,13 @@ class LibraryViewSet(viewsets.ModelViewSet):
return serializer.save() return serializer.save()
def get_permissions(self): def get_permissions(self):
if self.action in ['destroy']: if self.action in ['update', 'partial_update']:
permission_list = [permissions.ItemOwner]
elif self.action in ['update', 'partial_update']:
permission_list = [permissions.ItemEditor] permission_list = [permissions.ItemEditor]
elif self.action in [
'destroy', 'set_owner',
'editors_add', 'editors_remove', 'editors_set'
]:
permission_list = [permissions.ItemOwner]
elif self.action in ['create', 'clone', 'subscribe', 'unsubscribe']: elif self.action in ['create', 'clone', 'subscribe', 'unsubscribe']:
permission_list = [permissions.GlobalUser] permission_list = [permissions.GlobalUser]
else: else:
@ -139,7 +142,7 @@ class LibraryViewSet(viewsets.ModelViewSet):
tags=['Library'], tags=['Library'],
request=None, request=None,
responses={ responses={
c.HTTP_204_NO_CONTENT: None, c.HTTP_200_OK: None,
c.HTTP_403_FORBIDDEN: None, c.HTTP_403_FORBIDDEN: None,
c.HTTP_404_NOT_FOUND: None c.HTTP_404_NOT_FOUND: None
} }
@ -149,14 +152,14 @@ class LibraryViewSet(viewsets.ModelViewSet):
''' Endpoint: Subscribe current user to item. ''' ''' Endpoint: Subscribe current user to item. '''
item = self._get_item() item = self._get_item()
m.Subscription.subscribe(user=cast(m.User, self.request.user), item=item) m.Subscription.subscribe(user=cast(m.User, self.request.user), item=item)
return Response(status=c.HTTP_204_NO_CONTENT) return Response(status=c.HTTP_200_OK)
@extend_schema( @extend_schema(
summary='unsubscribe from item', summary='unsubscribe from item',
tags=['Library'], tags=['Library'],
request=None, request=None,
responses={ responses={
c.HTTP_204_NO_CONTENT: None, c.HTTP_200_OK: None,
c.HTTP_403_FORBIDDEN: None, c.HTTP_403_FORBIDDEN: None,
c.HTTP_404_NOT_FOUND: None c.HTTP_404_NOT_FOUND: None
}, },
@ -166,4 +169,84 @@ class LibraryViewSet(viewsets.ModelViewSet):
''' Endpoint: Unsubscribe current user from item. ''' ''' Endpoint: Unsubscribe current user from item. '''
item = self._get_item() item = self._get_item()
m.Subscription.unsubscribe(user=cast(m.User, self.request.user), item=item) m.Subscription.unsubscribe(user=cast(m.User, self.request.user), item=item)
return Response(status=c.HTTP_204_NO_CONTENT) return Response(status=c.HTTP_200_OK)
@extend_schema(
summary='set owner for item',
tags=['Library'],
request=s.UserTargetSerializer,
responses={
c.HTTP_200_OK: None,
c.HTTP_403_FORBIDDEN: None,
c.HTTP_404_NOT_FOUND: None
}
)
@action(detail=True, methods=['patch'], url_path='set-owner')
def set_owner(self, request: Request, pk):
''' Endpoint: Set item owner. '''
item = self._get_item()
serializer = s.UserTargetSerializer(data=request.data)
serializer.is_valid(raise_exception=True)
new_owner = serializer.validated_data['user']
m.LibraryItem.objects.filter(pk=item.pk).update(owner=new_owner)
return Response(status=c.HTTP_200_OK)
@extend_schema(
summary='add editor for item',
tags=['Library'],
request=s.UserTargetSerializer,
responses={
c.HTTP_200_OK: None,
c.HTTP_403_FORBIDDEN: None,
c.HTTP_404_NOT_FOUND: None
}
)
@action(detail=True, methods=['patch'], url_path='editors-add')
def editors_add(self, request: Request, pk):
''' Endpoint: Add editor for item. '''
item = self._get_item()
serializer = s.UserTargetSerializer(data=request.data)
serializer.is_valid(raise_exception=True)
new_editor = serializer.validated_data['user']
m.Editor.add(item=item, user=new_editor)
return Response(status=c.HTTP_200_OK)
@extend_schema(
summary='remove editor for item',
tags=['Library'],
request=s.UserTargetSerializer,
responses={
c.HTTP_200_OK: None,
c.HTTP_403_FORBIDDEN: None,
c.HTTP_404_NOT_FOUND: None
}
)
@action(detail=True, methods=['patch'], url_path='editors-remove')
def editors_remove(self, request: Request, pk):
''' Endpoint: Remove editor for item. '''
item = self._get_item()
serializer = s.UserTargetSerializer(data=request.data)
serializer.is_valid(raise_exception=True)
editor = serializer.validated_data['user']
m.Editor.remove(item=item, user=editor)
return Response(status=c.HTTP_200_OK)
@extend_schema(
summary='set list of editors for item',
tags=['Library'],
request=s.UsersListSerializer,
responses={
c.HTTP_200_OK: None,
c.HTTP_403_FORBIDDEN: None,
c.HTTP_404_NOT_FOUND: None
}
)
@action(detail=True, methods=['patch'], url_path='editors-set')
def editors_set(self, request: Request, pk):
''' Endpoint: Set list of editors for item. '''
item = self._get_item()
serializer = s.UsersListSerializer(data=request.data)
serializer.is_valid(raise_exception=True)
editors = serializer.validated_data['users']
m.Editor.set(item=item, users=editors)
return Response(status=c.HTTP_200_OK)

View File

@ -32,8 +32,10 @@ class RSFormViewSet(viewsets.GenericViewSet, generics.ListAPIView, generics.Retr
def get_permissions(self): def get_permissions(self):
''' Determine permission class. ''' ''' Determine permission class. '''
if self.action in ['load_trs', 'cst_create', 'cst_delete_multiple', if self.action in [
'reset_aliases', 'cst_rename', 'cst_substitute']: 'load_trs', 'cst_create', 'cst_delete_multiple',
'reset_aliases', 'cst_rename', 'cst_substitute'
]:
permission_list = [permissions.ItemOwner] permission_list = [permissions.ItemOwner]
else: else:
permission_list = [permissions.Anyone] permission_list = [permissions.Anyone]