pagination.py 8.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266
  1. import inspect
  2. from abc import abstractmethod
  3. from functools import partial, wraps
  4. from typing import Any, AsyncGenerator, Callable, List, Type, Union
  5. from urllib import parse
  6. from asgiref.sync import sync_to_async
  7. from django.db.models import QuerySet
  8. from django.http import HttpRequest, HttpResponse
  9. from django.utils.module_loading import import_string
  10. from ninja.conf import settings as ninja_settings
  11. from ninja.constants import NOT_SET
  12. from ninja.pagination import PaginationBase, make_response_paginated
  13. from ninja.utils import (
  14. contribute_operation_args,
  15. contribute_operation_callback,
  16. is_async_callable,
  17. )
  18. from .cursor_pagination import CursorPagination, _clamp, _reverse_order
  19. class AsyncPaginationBase(PaginationBase):
  20. @abstractmethod
  21. async def apaginate_queryset(
  22. self,
  23. queryset: QuerySet,
  24. pagination: Any,
  25. **params: Any,
  26. ) -> Any:
  27. pass # pragma: no cover
  28. async def _aitems_count(self, queryset: QuerySet) -> int:
  29. try:
  30. return await queryset.all().acount()
  31. except AttributeError:
  32. return len(queryset)
  33. class AsyncLinkHeaderPagination(CursorPagination):
  34. max_hits = 1000
  35. # Remove Output schema because we only want to return a list of items
  36. Output = None
  37. async def apaginate_queryset(
  38. self,
  39. queryset: QuerySet,
  40. pagination: CursorPagination.Input,
  41. request: HttpRequest,
  42. response: HttpResponse,
  43. **params,
  44. ) -> dict:
  45. limit = _clamp(
  46. pagination.limit or ninja_settings.PAGINATION_PER_PAGE,
  47. 0,
  48. self.max_page_size,
  49. )
  50. full_queryset = queryset
  51. if not queryset.query.order_by:
  52. queryset = queryset.order_by(*self.default_ordering)
  53. order = queryset.query.order_by
  54. base_url = request.build_absolute_uri()
  55. cursor = pagination.cursor
  56. if cursor.reverse:
  57. queryset = queryset.order_by(*_reverse_order(order))
  58. if cursor.position is not None:
  59. is_reversed = order[0].startswith("-")
  60. order_attr = order[0].lstrip("-")
  61. if cursor.reverse != is_reversed:
  62. queryset = queryset.filter(**{f"{order_attr}__lt": cursor.position})
  63. else:
  64. queryset = queryset.filter(**{f"{order_attr}__gt": cursor.position})
  65. @sync_to_async
  66. def get_results():
  67. return list(queryset[cursor.offset : cursor.offset + limit + 1])
  68. results = await get_results()
  69. page = list(results[:limit])
  70. if len(results) > len(page):
  71. has_following_position = True
  72. following_position = self._get_position_from_instance(results[-1], order)
  73. else:
  74. has_following_position = False
  75. following_position = None
  76. if cursor.reverse:
  77. page = list(reversed(page))
  78. has_next = (cursor.position is not None) or (cursor.offset > 0)
  79. has_previous = has_following_position
  80. next_position = cursor.position if has_next else None
  81. previous_position = following_position if has_previous else None
  82. else:
  83. has_next = has_following_position
  84. has_previous = (cursor.position is not None) or (cursor.offset > 0)
  85. next_position = following_position if has_next else None
  86. previous_position = cursor.position if has_previous else None
  87. next = (
  88. self.next_link(
  89. base_url,
  90. page,
  91. cursor,
  92. order,
  93. has_previous,
  94. limit,
  95. next_position,
  96. previous_position,
  97. )
  98. if has_next
  99. else None
  100. )
  101. previous = (
  102. self.previous_link(
  103. base_url,
  104. page,
  105. cursor,
  106. order,
  107. has_next,
  108. limit,
  109. next_position,
  110. previous_position,
  111. )
  112. if has_previous
  113. else None
  114. )
  115. total_count = 0
  116. if has_next or has_previous:
  117. total_count = await self._aitems_count(full_queryset)
  118. else:
  119. total_count = len(page)
  120. links = []
  121. for url, label in (
  122. (previous, "previous"),
  123. (next, "next"),
  124. ):
  125. if url is not None:
  126. parsed = parse.urlparse(url)
  127. cursor = parse.parse_qs(parsed.query).get("cursor", [""])[0]
  128. links.append(
  129. '<{}>; rel="{}"; results="true"; cursor="{}"'.format(
  130. url, label, cursor
  131. )
  132. )
  133. else:
  134. links.append('<{}>; rel="{}"; results="false"'.format(base_url, label))
  135. response["Link"] = {", ".join(links)} if links else {}
  136. response["X-Max-Hits"] = self.max_hits
  137. response["X-Hits"] = total_count
  138. return page
  139. async def _aitems_count(self, queryset: QuerySet) -> int:
  140. return await queryset.order_by()[: self.max_hits].acount() # type: ignore
  141. def _inject_pagination(
  142. func: Callable,
  143. paginator_class: Type[Union[PaginationBase, AsyncPaginationBase]],
  144. **paginator_params: Any,
  145. ) -> Callable:
  146. paginator = paginator_class(**paginator_params)
  147. if is_async_callable(func):
  148. @wraps(func)
  149. async def view_with_pagination(request: HttpRequest, **kwargs: Any) -> Any:
  150. pagination_params = kwargs.pop("ninja_pagination")
  151. if paginator.pass_parameter:
  152. kwargs[paginator.pass_parameter] = pagination_params
  153. items = await func(request, **kwargs)
  154. result = await paginator.apaginate_queryset(
  155. items, pagination=pagination_params, request=request, **kwargs
  156. )
  157. async def evaluate(results: Union[List, QuerySet]) -> AsyncGenerator:
  158. for result in results:
  159. yield result
  160. if paginator.Output: # type: ignore
  161. result[paginator.items_attribute] = [
  162. result
  163. async for result in evaluate(result[paginator.items_attribute])
  164. ]
  165. return result
  166. else:
  167. @wraps(func)
  168. def view_with_pagination(request: HttpRequest, **kwargs: Any) -> Any:
  169. pagination_params = kwargs.pop("ninja_pagination")
  170. if paginator.pass_parameter:
  171. kwargs[paginator.pass_parameter] = pagination_params
  172. items = func(request, **kwargs)
  173. result = paginator.paginate_queryset(
  174. items, pagination=pagination_params, request=request, **kwargs
  175. )
  176. if paginator.Output: # type: ignore
  177. result[paginator.items_attribute] = list(
  178. result[paginator.items_attribute]
  179. )
  180. # ^ forcing queryset evaluation #TODO: check why pydantic did not do it here
  181. return result
  182. contribute_operation_args(
  183. view_with_pagination,
  184. "ninja_pagination",
  185. paginator.Input,
  186. paginator.InputSource,
  187. )
  188. if paginator.Output: # type: ignore
  189. contribute_operation_callback(
  190. view_with_pagination,
  191. partial(make_response_paginated, paginator),
  192. )
  193. return view_with_pagination
  194. def paginate(func_or_pgn_class: Any = NOT_SET, **paginator_params: Any) -> Callable:
  195. """
  196. @api.get(...
  197. @paginate
  198. def my_view(request):
  199. or
  200. @api.get(...
  201. @paginate(PageNumberPagination)
  202. def my_view(request):
  203. """
  204. isfunction = inspect.isfunction(func_or_pgn_class)
  205. isnotset = func_or_pgn_class == NOT_SET
  206. pagination_class: Type[Union[PaginationBase, AsyncPaginationBase]] = import_string(
  207. ninja_settings.PAGINATION_CLASS
  208. )
  209. if isfunction:
  210. return _inject_pagination(func_or_pgn_class, pagination_class)
  211. if not isnotset:
  212. pagination_class = func_or_pgn_class
  213. def wrapper(func: Callable) -> Any:
  214. return _inject_pagination(func, pagination_class, **paginator_params)
  215. return wrapper