R: Optimize database queries
Some checks failed
Backend CI / build (3.12) (push) Has been cancelled
Frontend CI / build (22.x) (push) Has been cancelled

This commit is contained in:
Ivan 2024-08-06 23:13:57 +03:00
parent 91642816b1
commit 513d6a5b71
9 changed files with 43 additions and 50 deletions

View File

@ -55,7 +55,7 @@ class Editor(Model):
def set(item: int, users: Iterable[int]): def set(item: int, users: Iterable[int]):
''' Set editors for item. ''' ''' Set editors for item. '''
processed: set[int] = set() processed: set[int] = set()
for editor_item in Editor.objects.filter(item_id=item).only('pk', 'editor_id'): for editor_item in Editor.objects.filter(item_id=item).only('editor_id'):
editor_id = editor_item.editor_id editor_id = editor_item.editor_id
if editor_id not in users: if editor_id not in users:
editor_item.delete() editor_item.delete()
@ -74,7 +74,7 @@ class Editor(Model):
processed: list[int] = [] processed: list[int] = []
deleted: list[int] = [] deleted: list[int] = []
added: list[int] = [] added: list[int] = []
for editor_item in Editor.objects.filter(item_id=item).only('pk', 'editor_id'): for editor_item in Editor.objects.filter(item_id=item).only('editor_id'):
editor_id = editor_item.editor_id editor_id = editor_item.editor_id
if editor_id not in users: if editor_id not in users:
deleted.append(editor_id) deleted.append(editor_id)

View File

@ -114,9 +114,9 @@ class LibraryItem(Model):
def get_absolute_url(self): def get_absolute_url(self):
return f'/api/library/{self.pk}' return f'/api/library/{self.pk}'
def subscribers(self) -> list[User]: def subscribers(self) -> QuerySet[User]:
''' Get all subscribers for this item. ''' ''' Get all subscribers for this item. '''
return [subscription.user for subscription in Subscription.objects.filter(item=self.pk).only('user')] return User.objects.filter(subscription__item=self.pk)
def editors(self) -> QuerySet[User]: def editors(self) -> QuerySet[User]:
''' Get all Editors of this item. ''' ''' Get all Editors of this item. '''
@ -126,6 +126,7 @@ class LibraryItem(Model):
''' Get all Versions of this item. ''' ''' Get all Versions of this item. '''
return Version.objects.filter(item=self.pk).order_by('-time_create') return Version.objects.filter(item=self.pk).order_by('-time_create')
# TODO: move to View layer
@transaction.atomic @transaction.atomic
def save(self, *args, **kwargs): def save(self, *args, **kwargs):
''' Save updating subscriptions and connected operations. ''' ''' Save updating subscriptions and connected operations. '''
@ -134,7 +135,7 @@ class LibraryItem(Model):
subscribe = self._state.adding and self.owner subscribe = self._state.adding and self.owner
super().save(*args, **kwargs) super().save(*args, **kwargs)
if subscribe: if subscribe:
Subscription.subscribe(user=self.owner, item=self) Subscription.subscribe(user=self.owner_id, item=self.pk)
def _update_connected_operations(self): def _update_connected_operations(self):
# using method level import to prevent circular dependency # using method level import to prevent circular dependency

View File

@ -1,13 +1,8 @@
''' Models: Subscription. ''' ''' Models: Subscription. '''
from typing import TYPE_CHECKING
from django.db.models import CASCADE, ForeignKey, Model from django.db.models import CASCADE, ForeignKey, Model
from apps.users.models import User from apps.users.models import User
if TYPE_CHECKING:
from .LibraryItem import LibraryItem
class Subscription(Model): class Subscription(Model):
''' User subscription to library item. ''' ''' User subscription to library item. '''
@ -32,17 +27,17 @@ class Subscription(Model):
return f'{self.user} -> {self.item}' return f'{self.user} -> {self.item}'
@staticmethod @staticmethod
def subscribe(user: User, item: 'LibraryItem') -> bool: def subscribe(user: int, item: int) -> bool:
''' Add subscription. ''' ''' Add subscription. '''
if Subscription.objects.filter(user=user, item=item).exists(): if Subscription.objects.filter(user_id=user, item_id=item).exists():
return False return False
Subscription.objects.create(user=user, item=item) Subscription.objects.create(user_id=user, item_id=item)
return True return True
@staticmethod @staticmethod
def unsubscribe(user: User, item: 'LibraryItem') -> bool: def unsubscribe(user: int, item: int) -> bool:
''' Remove subscription. ''' ''' Remove subscription. '''
sub = Subscription.objects.filter(user=user, item=item) sub = Subscription.objects.filter(user_id=user, item_id=item).only('pk')
if not sub.exists(): if not sub.exists():
return False return False
sub.delete() sub.delete()

View File

@ -83,7 +83,7 @@ class LibraryItemDetailsSerializer(serializers.ModelSerializer):
read_only_fields = ('owner', 'id', 'item_type') read_only_fields = ('owner', 'id', 'item_type')
def get_subscribers(self, instance: LibraryItem) -> list[int]: def get_subscribers(self, instance: LibraryItem) -> list[int]:
return [item.pk for item in instance.subscribers()] return list(instance.subscribers().values_list('pk', flat=True))
def get_editors(self, instance: LibraryItem) -> list[int]: def get_editors(self, instance: LibraryItem) -> list[int]:
return list(instance.editors().values_list('pk', flat=True)) return list(instance.editors().values_list('pk', flat=True))

View File

@ -37,33 +37,33 @@ class TestSubscription(TestCase):
def test_subscribe(self): def test_subscribe(self):
item = LibraryItem.objects.create(item_type=LibraryItemType.RSFORM, title='Test') item = LibraryItem.objects.create(item_type=LibraryItemType.RSFORM, title='Test')
self.assertEqual(len(item.subscribers()), 0) self.assertEqual(item.subscribers().count(), 0)
self.assertTrue(Subscription.subscribe(self.user1, item)) self.assertTrue(Subscription.subscribe(self.user1.pk, item.pk))
self.assertEqual(len(item.subscribers()), 1) self.assertEqual(item.subscribers().count(), 1)
self.assertTrue(self.user1 in item.subscribers()) self.assertTrue(self.user1 in item.subscribers())
self.assertFalse(Subscription.subscribe(self.user1, item)) self.assertFalse(Subscription.subscribe(self.user1.pk, item.pk))
self.assertEqual(len(item.subscribers()), 1) self.assertEqual(item.subscribers().count(), 1)
self.assertTrue(Subscription.subscribe(self.user2, item)) self.assertTrue(Subscription.subscribe(self.user2.pk, item.pk))
self.assertEqual(len(item.subscribers()), 2) self.assertEqual(item.subscribers().count(), 2)
self.assertTrue(self.user1 in item.subscribers()) self.assertTrue(self.user1 in item.subscribers())
self.assertTrue(self.user2 in item.subscribers()) self.assertTrue(self.user2 in item.subscribers())
self.user1.delete() self.user1.delete()
self.assertEqual(len(item.subscribers()), 1) self.assertEqual(item.subscribers().count(), 1)
def test_unsubscribe(self): def test_unsubscribe(self):
item = LibraryItem.objects.create(item_type=LibraryItemType.RSFORM, title='Test') item = LibraryItem.objects.create(item_type=LibraryItemType.RSFORM, title='Test')
self.assertFalse(Subscription.unsubscribe(self.user1, item)) self.assertFalse(Subscription.unsubscribe(self.user1.pk, item.pk))
Subscription.subscribe(self.user1, item) Subscription.subscribe(self.user1.pk, item.pk)
Subscription.subscribe(self.user2, item) Subscription.subscribe(self.user2.pk, item.pk)
self.assertEqual(len(item.subscribers()), 2) self.assertEqual(item.subscribers().count(), 2)
self.assertTrue(Subscription.unsubscribe(self.user1, item)) self.assertTrue(Subscription.unsubscribe(self.user1.pk, item.pk))
self.assertEqual(len(item.subscribers()), 1) self.assertEqual(item.subscribers().count(), 1)
self.assertTrue(self.user2 in item.subscribers()) self.assertTrue(self.user2 in item.subscribers())
self.assertFalse(Subscription.unsubscribe(self.user1, item)) self.assertFalse(Subscription.unsubscribe(self.user1.pk, item.pk))

View File

@ -269,9 +269,9 @@ class TestLibraryViewset(EndpointTester):
response = self.executeOK() response = self.executeOK()
self.assertFalse(response_contains(response, self.unowned)) self.assertFalse(response_contains(response, self.unowned))
Subscription.subscribe(user=self.user, item=self.unowned) Subscription.subscribe(user=self.user.pk, item=self.unowned.pk)
Subscription.subscribe(user=self.user2, item=self.unowned) Subscription.subscribe(user=self.user2.pk, item=self.unowned.pk)
Subscription.subscribe(user=self.user2, item=self.owned) Subscription.subscribe(user=self.user2.pk, item=self.owned.pk)
response = self.executeOK() response = self.executeOK()
self.assertTrue(response_contains(response, self.unowned)) self.assertTrue(response_contains(response, self.unowned))

View File

@ -4,9 +4,8 @@ from typing import cast
from django.db import transaction from django.db import transaction
from django.db.models import Q from django.db.models import Q
from django_filters.rest_framework import DjangoFilterBackend
from drf_spectacular.utils import extend_schema, extend_schema_view from drf_spectacular.utils import extend_schema, extend_schema_view
from rest_framework import filters, generics from rest_framework import generics
from rest_framework import status as c from rest_framework import status as c
from rest_framework import viewsets from rest_framework import viewsets
from rest_framework.decorators import action from rest_framework.decorators import action
@ -28,10 +27,8 @@ from .. import serializers as s
class LibraryViewSet(viewsets.ModelViewSet): class LibraryViewSet(viewsets.ModelViewSet):
''' Endpoint: Library operations. ''' ''' Endpoint: Library operations. '''
queryset = m.LibraryItem.objects.all() queryset = m.LibraryItem.objects.all()
# TODO: consider using .only() for performance
filter_backends = (DjangoFilterBackend, filters.OrderingFilter)
filterset_fields = ['item_type', 'owner']
ordering_fields = ('item_type', 'owner', 'alias', 'title', 'time_update')
ordering = '-time_update' ordering = '-time_update'
def get_serializer_class(self): def get_serializer_class(self):
@ -128,7 +125,7 @@ class LibraryViewSet(viewsets.ModelViewSet):
def subscribe(self, request: Request, pk): def subscribe(self, request: Request, pk):
''' Endpoint: Subscribe current user to item. ''' ''' Endpoint: Subscribe current user to item. '''
item = self._get_item() item = self._get_item()
m.Subscription.subscribe(user=cast(User, self.request.user), item=item) m.Subscription.subscribe(user=cast(int, self.request.user.pk), item=item.pk)
return Response(status=c.HTTP_200_OK) return Response(status=c.HTTP_200_OK)
@extend_schema( @extend_schema(
@ -145,7 +142,7 @@ class LibraryViewSet(viewsets.ModelViewSet):
def unsubscribe(self, request: Request, pk): def unsubscribe(self, request: Request, pk):
''' Endpoint: Unsubscribe current user from item. ''' ''' Endpoint: Unsubscribe current user from item. '''
item = self._get_item() item = self._get_item()
m.Subscription.unsubscribe(user=cast(User, self.request.user), item=item) m.Subscription.unsubscribe(user=cast(int, self.request.user.pk), item=item.pk)
return Response(status=c.HTTP_200_OK) return Response(status=c.HTTP_200_OK)
@extend_schema( @extend_schema(

View File

@ -48,7 +48,7 @@ class OperationCreateSerializer(serializers.Serializer):
create_schema = serializers.BooleanField(default=False, required=False) create_schema = serializers.BooleanField(default=False, required=False)
item_data = OperationCreateData() item_data = OperationCreateData()
arguments = PKField(many=True, queryset=Operation.objects.all(), required=False) arguments = PKField(many=True, queryset=Operation.objects.all().only('pk'), required=False)
positions = serializers.ListField( positions = serializers.ListField(
child=OperationPositionSerializer(), child=OperationPositionSerializer(),
@ -67,7 +67,7 @@ class OperationUpdateSerializer(serializers.Serializer):
target = PKField(many=False, queryset=Operation.objects.all()) target = PKField(many=False, queryset=Operation.objects.all())
item_data = OperationUpdateData() item_data = OperationUpdateData()
arguments = PKField(many=True, queryset=Operation.objects.all(), required=False) arguments = PKField(many=True, queryset=Operation.objects.all().only('oss_id', 'result_id'), required=False)
substitutions = serializers.ListField( substitutions = serializers.ListField(
child=SubstitutionSerializerBase(), child=SubstitutionSerializerBase(),
required=False required=False
@ -121,8 +121,8 @@ class OperationUpdateSerializer(serializers.Serializer):
class OperationTargetSerializer(serializers.Serializer): class OperationTargetSerializer(serializers.Serializer):
''' Serializer: Delete operation. ''' ''' Serializer: Target single operation. '''
target = PKField(many=False, queryset=Operation.objects.all()) target = PKField(many=False, queryset=Operation.objects.all().only('oss_id', 'result_id'))
positions = serializers.ListField( positions = serializers.ListField(
child=OperationPositionSerializer(), child=OperationPositionSerializer(),
default=[] default=[]

View File

@ -269,7 +269,7 @@ class CstRenameSerializer(serializers.Serializer):
class CstListSerializer(serializers.Serializer): class CstListSerializer(serializers.Serializer):
''' Serializer: List of constituents from one origin. ''' ''' Serializer: List of constituents from one origin. '''
items = PKField(many=True, queryset=Constituenta.objects.all()) items = PKField(many=True, queryset=Constituenta.objects.all().only('schema_id'))
def validate(self, attrs): def validate(self, attrs):
schema = cast(LibraryItem, self.context['schema']) schema = cast(LibraryItem, self.context['schema'])
@ -291,8 +291,8 @@ class CstMoveSerializer(CstListSerializer):
class SubstitutionSerializerBase(serializers.Serializer): class SubstitutionSerializerBase(serializers.Serializer):
''' Serializer: Basic substitution. ''' ''' Serializer: Basic substitution. '''
original = PKField(many=False, queryset=Constituenta.objects.only('alias', 'schema')) original = PKField(many=False, queryset=Constituenta.objects.only('alias', 'schema_id'))
substitution = PKField(many=False, queryset=Constituenta.objects.only('alias', 'schema')) substitution = PKField(many=False, queryset=Constituenta.objects.only('alias', 'schema_id'))
class CstSubstituteSerializer(serializers.Serializer): class CstSubstituteSerializer(serializers.Serializer):
@ -330,8 +330,8 @@ class CstSubstituteSerializer(serializers.Serializer):
class InlineSynthesisSerializer(serializers.Serializer): class InlineSynthesisSerializer(serializers.Serializer):
''' Serializer: Inline synthesis operation input. ''' ''' Serializer: Inline synthesis operation input. '''
receiver = PKField(many=False, queryset=LibraryItem.objects.all()) receiver = PKField(many=False, queryset=LibraryItem.objects.all().only('owner_id'))
source = PKField(many=False, queryset=LibraryItem.objects.all()) # type: ignore source = PKField(many=False, queryset=LibraryItem.objects.all().only('owner_id')) # type: ignore
items = PKField(many=True, queryset=Constituenta.objects.all()) items = PKField(many=True, queryset=Constituenta.objects.all())
substitutions = serializers.ListField( substitutions = serializers.ListField(
child=SubstitutionSerializerBase() child=SubstitutionSerializerBase()