mirror of
				https://gitlab.crans.org/bde/nk20
				synced 2025-11-04 09:12:11 +01:00 
			
		
		
		
	
		
			
				
	
	
		
			242 lines
		
	
	
		
			9.4 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
			
		
		
	
	
			242 lines
		
	
	
		
			9.4 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
# Copyright (C) 2018-2025 by BDE ENS Paris-Saclay
 | 
						|
# SPDX-License-Identifier: GPL-3.0-or-later
 | 
						|
 | 
						|
import json
 | 
						|
from datetime import datetime, date
 | 
						|
from decimal import Decimal
 | 
						|
from urllib.parse import quote_plus
 | 
						|
from warnings import warn
 | 
						|
 | 
						|
from django.contrib.auth.models import User
 | 
						|
from django.contrib.contenttypes.models import ContentType
 | 
						|
from django.db.models.fields.files import ImageFieldFile
 | 
						|
from django.test import TestCase
 | 
						|
from django_filters.rest_framework import DjangoFilterBackend
 | 
						|
from phonenumbers import PhoneNumber
 | 
						|
from rest_framework.filters import OrderingFilter
 | 
						|
from api.filters import RegexSafeSearchFilter
 | 
						|
from member.models import Membership, Club
 | 
						|
from note.models import NoteClub, NoteUser, Alias, Note
 | 
						|
from permission.models import PermissionMask, Permission, Role
 | 
						|
 | 
						|
from .viewsets import ContentTypeViewSet, UserViewSet
 | 
						|
 | 
						|
 | 
						|
class TestAPI(TestCase):
    """
    Load API pages and check that filters are working.

    This is a base test case: ``setUp`` logs a freshly created superuser in,
    and ``check_viewset`` exercises the filters, the ordering, the search and
    the permissions of a viewset served at a given URL.
    """
    fixtures = ('initial', )

    def setUp(self) -> None:
        """
        Create a superuser, log it in and store a permission mask in the session.
        """
        self.user = User.objects.create_superuser(
            username="adminapi",
            password="adminapi",
            email="adminapi@example.com",
            last_name="Admin",
            first_name="Admin",
        )
        self.client.force_login(self.user)

        # Same mask value that check_permissions restores once it is done.
        # NOTE(review): 42 is presumably higher than any existing mask rank - confirm.
        sess = self.client.session
        sess["permission_mask"] = 42
        sess.save()

    def check_viewset(self, viewset, url):
        """
        This function should be called inside a unit test.
        This loads the viewset and, for each filter, ordering and search entry
        it declares, checks that the matching request succeeds. Filter and
        search requests must additionally return at least one result.
        The permissions are checked at the end through ``check_permissions``.

        :param viewset: the viewset class to test. The tested model is read
                        from ``viewset.serializer_class.Meta.model``.
        :param url: the list URL of the viewset. It is expected to end with a
                    slash, since detail URLs are built as ``url + pk + "/"``.
        """
        resp = self.client.get(url + "?format=json")
        self.assertEqual(resp.status_code, 200)

        model = viewset.serializer_class.Meta.model

        if not model.objects.exists():  # pragma: no cover
            warn(f"Warning: unable to test API filters for the model {model._meta.verbose_name} "
                 "since there is no instance of it.")
            return

        if hasattr(viewset, "filter_backends"):
            backends = viewset.filter_backends
            # This object may be swapped by fix_note_object in the loops below;
            # the last value is the one given to check_permissions.
            obj = model.objects.last()

            if DjangoFilterBackend in backends:
                # Specific search
                for field in viewset.filterset_fields:
                    obj = self.fix_note_object(obj, field)

                    value = self.get_value(obj, field)
                    if value is None:  # pragma: no cover
                        warn(f"Warning: the filter {field} for the model {model._meta.verbose_name} "
                             "has not been tested.")
                        continue
                    resp = self.client.get(url + f"?format=json&{field}={quote_plus(str(value))}")
                    failure = (f"The filter {field} for the model "
                               f"{model._meta.verbose_name} does not work. "
                               f"Given parameter: {value}")
                    self.assertEqual(resp.status_code, 200, failure)
                    content = json.loads(resp.content)
                    self.assertGreater(content["count"], 0, failure)

            if OrderingFilter in backends:
                # Ensure that ordering is working well, in both directions
                for field in viewset.ordering_fields:
                    for ordering in (field, f"-{field}"):
                        resp = self.client.get(url + f"?ordering={ordering}")
                        self.assertEqual(resp.status_code, 200)

            if RegexSafeSearchFilter in backends:
                # Basic search
                for field in viewset.search_fields:
                    obj = self.fix_note_object(obj, field)

                    # Search fields may carry a one-character lookup prefix
                    # ('^', '=', '@' or '$' in Django REST framework) which is
                    # not part of the attribute path.
                    if field.startswith(("^", "=", "@", "$")):
                        field = field[1:]
                    value = self.get_value(obj, field)
                    if value is None:  # pragma: no cover
                        warn(f"Warning: the filter {field} for the model {model._meta.verbose_name} "
                             "has not been tested.")
                        continue
                    resp = self.client.get(url + f"?format=json&search={quote_plus(str(value))}")
                    failure = (f"The filter {field} for the model "
                               f"{model._meta.verbose_name} does not work. "
                               f"Given parameter: {value}")
                    self.assertEqual(resp.status_code, 200, failure)
                    content = json.loads(resp.content)
                    self.assertGreater(content["count"], 0, failure)

            self.check_permissions(url, obj)

    def check_permissions(self, url, obj):
        """
        Check that permissions are working.

        The logged user loses its superuser status and its roles, then gets a
        single role in the BDE club. Without any permission, the detail page of
        the object must answer 404. Then, for each field of the object, a
        "view" permission whose query matches the value of this field is given
        to the role, and the detail page must answer 200.
        Superuser rights and the session mask are restored at the end.

        :param url: the list URL of the viewset, ending with a slash.
        :param obj: the model instance whose detail page is requested.
        """
        # Drop rights
        self.user.is_superuser = False
        self.user.save()
        sess = self.client.session
        sess["permission_mask"] = 0
        sess.save()

        # Delete user permissions
        for m in Membership.objects.filter(user=self.user):
            m.roles.clear()
            m.save()

        # Create a new role, which will have the checking permission
        role = Role.objects.get_or_create(name="β-tester")[0]
        role.permissions.clear()
        role.save()
        membership = Membership.objects.get_or_create(user=self.user, club=Club.objects.get(name="BDE"))[0]
        membership.roles.set([role])
        membership.save()

        # Ensure that the access to the object is forbidden without permission
        resp = self.client.get(url + f"{obj.pk}/")
        self.assertEqual(resp.status_code, 404, f"Mysterious access to {url}{obj.pk}/ for {obj}")

        obj.refresh_from_db()

        # There are problems with polymorphism
        if isinstance(obj, Note) and hasattr(obj, "note_ptr"):
            obj = obj.note_ptr

        mask = PermissionMask.objects.get(rank=0)

        for field in obj._meta.fields:
            # Build permission query, with a JSON-serializable value
            value = self.get_value(obj, field.name)
            if isinstance(value, (date, datetime)):
                value = value.isoformat()
            elif isinstance(value, ImageFieldFile):
                value = value.name
            elif isinstance(value, Decimal):
                value = str(value)
            query = json.dumps({field.name: value})

            # Create sample permission
            permission = Permission.objects.get_or_create(
                model=ContentType.objects.get_for_model(obj._meta.model),
                query=query,
                mask=mask,
                type="view",
                permanent=False,
                description=f"Can view {obj._meta.verbose_name}",
            )[0]
            # The role only holds the permission that is currently tested
            role.permissions.set([permission])
            role.save()

            # Check that the access is possible
            resp = self.client.get(url + f"{obj.pk}/")
            self.assertEqual(resp.status_code, 200, f"Permission {permission.query} is not working "
                                                    f"for the model {obj._meta.verbose_name}")

        # Restore rights
        self.user.is_superuser = True
        self.user.save()
        sess = self.client.session
        sess["permission_mask"] = 42
        sess.save()

    @staticmethod
    def get_value(obj, key: str):
        """
        Resolve the queryset filter to get the Python value of an object.

        :param obj: a model instance, or a manager of which the last object is
                    taken.
        :param key: an attribute path, where ``__`` separates the successive
                    attributes to follow.
        :return: the resolved value, or ``None`` if an object of the path is
                 missing. Related objects are replaced by their primary key,
                 booleans by integers, datetimes by their ISO format and phone
                 numbers by their raw input.
        """
        if hasattr(obj, "all"):
            # obj is a RelatedManager
            obj = obj.last()

        if obj is None:  # pragma: no cover
            return None

        if '__' in key:
            # Follow the first attribute, then resolve the rest of the path
            key, remaining = key.split('__', 1)
            return TestAPI.get_value(getattr(obj, key), remaining)

        obj = getattr(obj, key)
        if hasattr(obj, "pk"):
            return obj.pk
        if hasattr(obj, "all"):
            if not obj.exists():  # pragma: no cover
                return None
            return obj.last().pk
        if isinstance(obj, bool):
            return int(obj)
        if isinstance(obj, datetime):
            return obj.isoformat()
        if isinstance(obj, PhoneNumber):
            return obj.raw_input
        return obj

    @staticmethod
    def fix_note_object(obj, field):
        """
        When querying an object that has a noteclub or a noteuser field,
        ensure that the object has a good value.

        :param obj: the object that is currently used to compute filter values.
        :param field: the name of the filter that will be tested.
        :return: for an alias or a note, when the filter name contains
                 "noteuser" or "noteclub", the last user note or club note
                 (or the last alias of it if ``obj`` is an alias).
                 In any other case, ``obj`` is returned unchanged.
        """
        if isinstance(obj, Alias):
            if "noteuser" in field:
                return NoteUser.objects.last().alias.last()
            elif "noteclub" in field:
                return NoteClub.objects.last().alias.last()
        elif isinstance(obj, Note):
            if "noteuser" in field:
                return NoteUser.objects.last()
            elif "noteclub" in field:
                return NoteClub.objects.last()
        return obj
 | 
						|
 | 
						|
 | 
						|
class TestBasicAPI(TestAPI):
    """
    Run the generic viewset checks on the viewsets of this application.
    """
    def test_user_api(self):
        """
        Check the content type API, then the user API.
        """
        targets = (
            (ContentTypeViewSet, "/api/models/"),
            (UserViewSet, "/api/user/"),
        )
        for viewset, url in targets:
            self.check_viewset(viewset, url)
 |