# Mirror of https://gitlab.crans.org/bde/nk20 (synced 2025-11-04 01:12:08 +01:00)
# 43 lines, 1.5 KiB, Python

import re
from functools import lru_cache
from rest_framework.filters import SearchFilter


class RegexSafeSearchFilter(SearchFilter):
    """
    ``SearchFilter`` that falls back to substring search on invalid regexes.

    In DRF, a search field prefixed with ``$`` is searched with a regex
    lookup. When at least one search term does not compile as a regex, this
    filter assumes the user is searching by substring instead. It removes the
    ``$`` prefix from the search fields, and strips a leading ``^`` and a
    trailing ``$`` from each search term.

    NOTE(review): validity is checked with Python's ``re`` module, while the
    regex lookup itself is evaluated by the database. The two regex dialects
    differ, so a term accepted here could still be rejected by the backend.
    """

    @staticmethod
    @lru_cache(maxsize=256)
    def _is_valid_regex(search_term) -> bool:
        """Return True if ``search_term`` compiles as a Python regex (memoized per term)."""
        try:
            re.compile(search_term)
        except (re.error, OverflowError, RecursionError):
            # re.error: malformed pattern.
            # OverflowError: a repeat count such as "a{99999999999}" exceeds
            # the engine limit.
            # RecursionError: very deeply nested groups can exhaust the parser.
            # Search terms are user input taken from the request, so none of
            # these may escape as a server error.
            return False
        return True

    def validate_regex(self, search_term) -> bool:
        """
        Return whether ``search_term`` is a valid Python regular expression.

        The result is cached per term through a static helper rather than by
        decorating this method with ``lru_cache``. Caching a method keys the
        cache on ``self`` and keeps filter instances alive inside it.
        """
        return self._is_valid_regex(search_term)

    def get_search_fields(self, view, request):
        """
        Return the view's search fields, downgrading regex fields when needed.

        If any search term is not a valid regex, the user is assumed to be
        searching by substring, so the ``$`` (regex) prefix is removed from
        every search field.
        """
        search_fields = super().get_search_fields(view, request)
        # A view may define no search fields (None): there is nothing to adapt,
        # and None cannot be iterated below.
        if not search_fields:
            return search_fields

        search_terms = self.get_search_terms(request)
        if not all(self.validate_regex(term) for term in search_terms):
            # Invalid regex: query by substring rather than by regex.
            search_fields = [field.replace('$', '') for field in search_fields]
        return search_fields

    def get_search_terms(self, request):
        """
        Return the search terms, made safe for substring search when needed.

        If any term is not a valid regex, each term loses a leading ``^`` and
        then a trailing ``$``. These are assumed to be regex anchors typed by
        the user, which a substring search would otherwise match literally.
        """
        terms = super().get_search_terms(request)
        if not all(self.validate_regex(term) for term in terms):
            # Use startswith/endswith rather than term[0] / term[-1]. A term can
            # be empty (e.g. "^" once its anchor is removed), and indexing an
            # empty string raises IndexError.
            terms = [term[1:] if term.startswith('^') else term for term in terms]
            terms = [term[:-1] if term.endswith('$') else term for term in terms]
        return terms