import inspect
from abc import abstractmethod
from functools import partial, wraps
from typing import TYPE_CHECKING, Any, AsyncGenerator, Callable, List, Type, Union
from urllib import parse

from asgiref.sync import sync_to_async
from django.http import HttpRequest, HttpResponse
from django.utils.module_loading import import_string

from ninja.conf import settings as ninja_settings
from ninja.constants import NOT_SET
from ninja.pagination import PaginationBase, make_response_paginated
from ninja.utils import (
    contribute_operation_args,
    contribute_operation_callback,
    is_async_callable,
)

from .cursor_pagination import CursorPagination, _clamp, _reverse_order

if TYPE_CHECKING:
    from django.db.models import QuerySet


class AsyncPaginationBase(PaginationBase):
    """Pagination base class whose paginators paginate asynchronously.

    Subclasses implement :meth:`apaginate_queryset`, which the wrapper built
    by :func:`paginate` awaits for ``async def`` views.
    """

    @abstractmethod
    async def apaginate_queryset(
        self,
        queryset: "QuerySet",
        pagination: Any,
        **params: Any,
    ) -> Any:
        pass  # pragma: no cover

    async def _aitems_count(self, queryset: "QuerySet") -> int:
        """Return the number of items in ``queryset``.

        Uses ``.all().acount()``; objects lacking those methods (e.g. plain
        lists) fall back to ``len()``.
        """
        try:
            acount = queryset.all().acount
        except AttributeError:
            return len(queryset)
        # Awaited outside the ``try`` so errors raised by the count query
        # itself are not mistaken for "not a queryset".
        return await acount()


class AsyncLinkHeaderPagination(CursorPagination):
    """Cursor pagination that returns a bare list and reports paging in headers.

    Instead of wrapping items in an output schema, navigation is exposed on
    the response through a ``Link`` header (``previous`` and ``next``
    entries) plus ``X-Max-Hits`` and ``X-Hits`` headers.
    """

    # Upper bound for the count query that feeds the ``X-Hits`` header.
    max_hits = 1000

    # Remove Output schema because we only want to return a list of items
    Output = None

    async def apaginate_queryset(
        self,
        queryset: "QuerySet",
        pagination: CursorPagination.Input,
        request: HttpRequest,
        response: HttpResponse,
        **params: Any,
    ) -> List[Any]:
        """Return one page of ``queryset`` and set paging headers on ``response``.

        The return value is a plain list of items (``Output`` is ``None``).
        Navigation is written to ``response``:

        * ``Link`` -- a ``previous`` and a ``next`` entry. An available page
          carries ``results="true"`` and its ``cursor``; a missing page points
          at the current URL with ``results="false"``.
        * ``X-Max-Hits`` -- ``max_hits``.
        * ``X-Hits`` -- a count query capped at ``max_hits`` when a previous
          or next page exists, otherwise the size of this page.
        """
        limit = _clamp(
            pagination.limit or ninja_settings.PAGINATION_PER_PAGE,
            0,
            self.max_page_size,
        )

        # Captured before default ordering / cursor filtering so the total
        # count covers every page.
        full_queryset = queryset
        if not queryset.query.order_by:
            queryset = queryset.order_by(*self.default_ordering)

        order = queryset.query.order_by
        base_url = request.build_absolute_uri()
        cursor = pagination.cursor

        if cursor.reverse:
            queryset = queryset.order_by(*_reverse_order(order))

        if cursor.position is not None:
            is_reversed = order[0].startswith("-")
            order_attr = order[0].lstrip("-")

            # Traversal direction differs from the field's sort direction:
            # take values below the cursor position; otherwise above it.
            if cursor.reverse != is_reversed:
                queryset = queryset.filter(**{f"{order_attr}__lt": cursor.position})
            else:
                queryset = queryset.filter(**{f"{order_attr}__gt": cursor.position})

        @sync_to_async
        def get_results():
            # One extra row tells us whether another page follows.
            return list(queryset[cursor.offset : cursor.offset + limit + 1])

        results = await get_results()
        page = list(results[:limit])

        if len(results) > len(page):
            has_following_position = True
            following_position = self._get_position_from_instance(results[-1], order)
        else:
            has_following_position = False
            following_position = None

        if cursor.reverse:
            # Rows were fetched with the ordering reversed; restore it.
            page = list(reversed(page))
            has_next = (cursor.position is not None) or (cursor.offset > 0)
            has_previous = has_following_position
            next_position = cursor.position if has_next else None
            previous_position = following_position if has_previous else None
        else:
            has_next = has_following_position
            has_previous = (cursor.position is not None) or (cursor.offset > 0)
            next_position = following_position if has_next else None
            previous_position = cursor.position if has_previous else None

        next_url = (
            self.next_link(
                base_url,
                page,
                cursor,
                order,
                has_previous,
                limit,
                next_position,
                previous_position,
            )
            if has_next
            else None
        )
        previous_url = (
            self.previous_link(
                base_url,
                page,
                cursor,
                order,
                has_next,
                limit,
                next_position,
                previous_position,
            )
            if has_previous
            else None
        )

        if has_next or has_previous:
            total_count = await self._aitems_count(full_queryset)
        else:
            # No neighbouring page: this page is the whole result set.
            total_count = len(page)

        # Header values must be strings; a non-str value would be str()-ed
        # by Django and leak Python repr syntax into the header.
        response["Link"] = self._build_link_header(base_url, previous_url, next_url)
        response["X-Max-Hits"] = str(self.max_hits)
        response["X-Hits"] = str(total_count)
        return page

    @staticmethod
    def _build_link_header(
        base_url: str,
        previous_url: Union[str, None],
        next_url: Union[str, None],
    ) -> str:
        """Format the ``Link`` header value for the previous/next pages.

        Both relations are always listed. An available page carries
        ``results="true"`` and the ``cursor`` query parameter taken from its
        URL; a missing one points at ``base_url`` with ``results="false"``.
        """
        links = []
        for url, label in ((previous_url, "previous"), (next_url, "next")):
            if url is None:
                links.append('<{}>; rel="{}"; results="false"'.format(base_url, label))
                continue
            query = parse.urlparse(url).query
            cursor_value = parse.parse_qs(query).get("cursor", [""])[0]
            links.append(
                '<{}>; rel="{}"; results="true"; cursor="{}"'.format(
                    url, label, cursor_value
                )
            )
        return ", ".join(links)

    async def _aitems_count(self, queryset: "QuerySet") -> int:
        """Count ``queryset``, stopping after ``max_hits`` rows.

        Ordering is cleared first because it does not affect the count.
        """
        return await queryset.order_by()[: self.max_hits].acount()  # type: ignore


def _inject_pagination(
    func: Callable,
    paginator_class: Type[Union[PaginationBase, AsyncPaginationBase]],
    **paginator_params: Any,
) -> Callable:
    """Wrap ``func`` so that its return value is paginated.

    The wrapper pops the ``ninja_pagination`` argument (registered on the
    operation below via ``contribute_operation_args``). It optionally forwards
    that argument to the view under ``paginator.pass_parameter``. It then
    passes the view's result through ``apaginate_queryset`` (async views) or
    ``paginate_queryset`` (sync views).
    """
    paginator = paginator_class(**paginator_params)

    if is_async_callable(func):

        @wraps(func)
        async def view_with_pagination(request: HttpRequest, **kwargs: Any) -> Any:
            pagination_params = kwargs.pop("ninja_pagination")
            if paginator.pass_parameter:
                kwargs[paginator.pass_parameter] = pagination_params

            items = await func(request, **kwargs)

            result = await paginator.apaginate_queryset(
                items, pagination=pagination_params, request=request, **kwargs
            )

            async def evaluate(results: Union[List, "QuerySet"]) -> AsyncGenerator:
                for item in results:
                    yield item

            if paginator.Output:  # type: ignore
                # Materialise the items so the response holds a concrete list.
                result[paginator.items_attribute] = [
                    item async for item in evaluate(result[paginator.items_attribute])
                ]
            return result

    else:

        @wraps(func)
        def view_with_pagination(request: HttpRequest, **kwargs: Any) -> Any:
            pagination_params = kwargs.pop("ninja_pagination")
            if paginator.pass_parameter:
                kwargs[paginator.pass_parameter] = pagination_params

            items = func(request, **kwargs)

            result = paginator.paginate_queryset(
                items, pagination=pagination_params, request=request, **kwargs
            )
            if paginator.Output:  # type: ignore
                # Force queryset evaluation.
                # TODO: check why pydantic did not do it here
                result[paginator.items_attribute] = list(
                    result[paginator.items_attribute]
                )
            return result

    contribute_operation_args(
        view_with_pagination,
        "ninja_pagination",
        paginator.Input,
        paginator.InputSource,
    )

    if paginator.Output:  # type: ignore
        contribute_operation_callback(
            view_with_pagination,
            partial(make_response_paginated, paginator),
        )

    return view_with_pagination


def paginate(func_or_pgn_class: Any = NOT_SET, **paginator_params: Any) -> Callable:
    """Decorate a view so that its result is paginated.

    Usage::

        @api.get(...)
        @paginate
        def my_view(request): ...

    or::

        @api.get(...)
        @paginate(PageNumberPagination)
        def my_view(request): ...

    When no paginator class is given, the class named by
    ``ninja_settings.PAGINATION_CLASS`` is used. Extra keyword arguments are
    passed to the paginator's constructor.
    """
    pagination_class: Type[Union[PaginationBase, AsyncPaginationBase]] = import_string(
        ninja_settings.PAGINATION_CLASS
    )

    if inspect.isfunction(func_or_pgn_class):
        # Bare ``@paginate``: the argument is the view itself.
        return _inject_pagination(func_or_pgn_class, pagination_class, **paginator_params)

    if func_or_pgn_class is not NOT_SET:
        pagination_class = func_or_pgn_class

    def wrapper(func: Callable) -> Any:
        return _inject_pagination(func, pagination_class, **paginator_params)

    return wrapper