pagination.py 8.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268
  1. import inspect
  2. from abc import abstractmethod
  3. from functools import partial, wraps
  4. from typing import TYPE_CHECKING, Any, AsyncGenerator, Callable, List, Type, Union
  5. from urllib import parse
  6. from asgiref.sync import sync_to_async
  7. from django.http import HttpRequest, HttpResponse
  8. from django.utils.module_loading import import_string
  9. from ninja.conf import settings as ninja_settings
  10. from ninja.constants import NOT_SET
  11. from ninja.pagination import PaginationBase, make_response_paginated
  12. from ninja.utils import (
  13. contribute_operation_args,
  14. contribute_operation_callback,
  15. is_async_callable,
  16. )
  17. from .cursor_pagination import CursorPagination, _clamp, _reverse_order
  18. if TYPE_CHECKING:
  19. from django.db.models import QuerySet
  20. class AsyncPaginationBase(PaginationBase):
  21. @abstractmethod
  22. async def apaginate_queryset(
  23. self,
  24. queryset: "QuerySet",
  25. pagination: Any,
  26. **params: Any,
  27. ) -> Any:
  28. pass # pragma: no cover
  29. async def _aitems_count(self, queryset: "QuerySet") -> int:
  30. try:
  31. return await queryset.all().acount()
  32. except AttributeError:
  33. return len(queryset)
  34. class AsyncLinkHeaderPagination(CursorPagination):
  35. max_hits = 1000
  36. # Remove Output schema because we only want to return a list of items
  37. Output = None
  38. async def apaginate_queryset(
  39. self,
  40. queryset: "QuerySet",
  41. pagination: CursorPagination.Input,
  42. request: HttpRequest,
  43. response: HttpResponse,
  44. **params,
  45. ) -> dict:
  46. limit = _clamp(
  47. pagination.limit or ninja_settings.PAGINATION_PER_PAGE,
  48. 0,
  49. self.max_page_size,
  50. )
  51. full_queryset = queryset
  52. if not queryset.query.order_by:
  53. queryset = queryset.order_by(*self.default_ordering)
  54. order = queryset.query.order_by
  55. base_url = request.build_absolute_uri()
  56. cursor = pagination.cursor
  57. if cursor.reverse:
  58. queryset = queryset.order_by(*_reverse_order(order))
  59. if cursor.position is not None:
  60. is_reversed = order[0].startswith("-")
  61. order_attr = order[0].lstrip("-")
  62. if cursor.reverse != is_reversed:
  63. queryset = queryset.filter(**{f"{order_attr}__lt": cursor.position})
  64. else:
  65. queryset = queryset.filter(**{f"{order_attr}__gt": cursor.position})
  66. @sync_to_async
  67. def get_results():
  68. return list(queryset[cursor.offset : cursor.offset + limit + 1])
  69. results = await get_results()
  70. page = list(results[:limit])
  71. if len(results) > len(page):
  72. has_following_position = True
  73. following_position = self._get_position_from_instance(results[-1], order)
  74. else:
  75. has_following_position = False
  76. following_position = None
  77. if cursor.reverse:
  78. page = list(reversed(page))
  79. has_next = (cursor.position is not None) or (cursor.offset > 0)
  80. has_previous = has_following_position
  81. next_position = cursor.position if has_next else None
  82. previous_position = following_position if has_previous else None
  83. else:
  84. has_next = has_following_position
  85. has_previous = (cursor.position is not None) or (cursor.offset > 0)
  86. next_position = following_position if has_next else None
  87. previous_position = cursor.position if has_previous else None
  88. next = (
  89. self.next_link(
  90. base_url,
  91. page,
  92. cursor,
  93. order,
  94. has_previous,
  95. limit,
  96. next_position,
  97. previous_position,
  98. )
  99. if has_next
  100. else None
  101. )
  102. previous = (
  103. self.previous_link(
  104. base_url,
  105. page,
  106. cursor,
  107. order,
  108. has_next,
  109. limit,
  110. next_position,
  111. previous_position,
  112. )
  113. if has_previous
  114. else None
  115. )
  116. total_count = 0
  117. if has_next or has_previous:
  118. total_count = await self._aitems_count(full_queryset)
  119. else:
  120. total_count = len(page)
  121. links = []
  122. for url, label in (
  123. (previous, "previous"),
  124. (next, "next"),
  125. ):
  126. if url is not None:
  127. parsed = parse.urlparse(url)
  128. cursor = parse.parse_qs(parsed.query).get("cursor", [""])[0]
  129. links.append(
  130. '<{}>; rel="{}"; results="true"; cursor="{}"'.format(
  131. url, label, cursor
  132. )
  133. )
  134. else:
  135. links.append('<{}>; rel="{}"; results="false"'.format(base_url, label))
  136. response["Link"] = {", ".join(links)} if links else {}
  137. response["X-Max-Hits"] = self.max_hits
  138. response["X-Hits"] = total_count
  139. return page
  140. async def _aitems_count(self, queryset: "QuerySet") -> int:
  141. return await queryset.order_by()[: self.max_hits].acount() # type: ignore
  142. def _inject_pagination(
  143. func: Callable,
  144. paginator_class: Type[Union[PaginationBase, AsyncPaginationBase]],
  145. **paginator_params: Any,
  146. ) -> Callable:
  147. paginator = paginator_class(**paginator_params)
  148. if is_async_callable(func):
  149. @wraps(func)
  150. async def view_with_pagination(request: HttpRequest, **kwargs: Any) -> Any:
  151. pagination_params = kwargs.pop("ninja_pagination")
  152. if paginator.pass_parameter:
  153. kwargs[paginator.pass_parameter] = pagination_params
  154. items = await func(request, **kwargs)
  155. result = await paginator.apaginate_queryset(
  156. items, pagination=pagination_params, request=request, **kwargs
  157. )
  158. async def evaluate(results: Union[List, "QuerySet"]) -> AsyncGenerator:
  159. for result in results:
  160. yield result
  161. if paginator.Output: # type: ignore
  162. result[paginator.items_attribute] = [
  163. result
  164. async for result in evaluate(result[paginator.items_attribute])
  165. ]
  166. return result
  167. else:
  168. @wraps(func)
  169. def view_with_pagination(request: HttpRequest, **kwargs: Any) -> Any:
  170. pagination_params = kwargs.pop("ninja_pagination")
  171. if paginator.pass_parameter:
  172. kwargs[paginator.pass_parameter] = pagination_params
  173. items = func(request, **kwargs)
  174. result = paginator.paginate_queryset(
  175. items, pagination=pagination_params, request=request, **kwargs
  176. )
  177. if paginator.Output: # type: ignore
  178. result[paginator.items_attribute] = list(
  179. result[paginator.items_attribute]
  180. )
  181. # ^ forcing queryset evaluation #TODO: check why pydantic did not do it here
  182. return result
  183. contribute_operation_args(
  184. view_with_pagination,
  185. "ninja_pagination",
  186. paginator.Input,
  187. paginator.InputSource,
  188. )
  189. if paginator.Output: # type: ignore
  190. contribute_operation_callback(
  191. view_with_pagination,
  192. partial(make_response_paginated, paginator),
  193. )
  194. return view_with_pagination
  195. def paginate(func_or_pgn_class: Any = NOT_SET, **paginator_params: Any) -> Callable:
  196. """
  197. @api.get(...
  198. @paginate
  199. def my_view(request):
  200. or
  201. @api.get(...
  202. @paginate(PageNumberPagination)
  203. def my_view(request):
  204. """
  205. isfunction = inspect.isfunction(func_or_pgn_class)
  206. isnotset = func_or_pgn_class == NOT_SET
  207. pagination_class: Type[Union[PaginationBase, AsyncPaginationBase]] = import_string(
  208. ninja_settings.PAGINATION_CLASS
  209. )
  210. if isfunction:
  211. return _inject_pagination(func_or_pgn_class, pagination_class)
  212. if not isnotset:
  213. pagination_class = func_or_pgn_class
  214. def wrapper(func: Callable) -> Any:
  215. return _inject_pagination(func, pagination_class, **paginator_params)
  216. return wrapper