# payload.py (13 KB)
  1. import asyncio
  2. import enum
  3. import io
  4. import json
  5. import mimetypes
  6. import os
  7. import warnings
  8. from abc import ABC, abstractmethod
  9. from itertools import chain
  10. from typing import (
  11. IO,
  12. TYPE_CHECKING,
  13. Any,
  14. ByteString,
  15. Dict,
  16. Final,
  17. Iterable,
  18. Optional,
  19. TextIO,
  20. Tuple,
  21. Type,
  22. Union,
  23. )
  24. from multidict import CIMultiDict
  25. from . import hdrs
  26. from .abc import AbstractStreamWriter
  27. from .helpers import (
  28. _SENTINEL,
  29. content_disposition_header,
  30. guess_filename,
  31. parse_mimetype,
  32. sentinel,
  33. )
  34. from .streams import StreamReader
  35. from .typedefs import JSONEncoder, _CIMultiDict
  36. __all__ = (
  37. "PAYLOAD_REGISTRY",
  38. "get_payload",
  39. "payload_type",
  40. "Payload",
  41. "BytesPayload",
  42. "StringPayload",
  43. "IOBasePayload",
  44. "BytesIOPayload",
  45. "BufferedReaderPayload",
  46. "TextIOPayload",
  47. "StringIOPayload",
  48. "JsonPayload",
  49. "AsyncIterablePayload",
  50. )
  51. TOO_LARGE_BYTES_BODY: Final[int] = 2**20 # 1 MB
  52. if TYPE_CHECKING:
  53. from typing import List
  54. class LookupError(Exception):
  55. pass
  56. class Order(str, enum.Enum):
  57. normal = "normal"
  58. try_first = "try_first"
  59. try_last = "try_last"
  60. def get_payload(data: Any, *args: Any, **kwargs: Any) -> "Payload":
  61. return PAYLOAD_REGISTRY.get(data, *args, **kwargs)
  62. def register_payload(
  63. factory: Type["Payload"], type: Any, *, order: Order = Order.normal
  64. ) -> None:
  65. PAYLOAD_REGISTRY.register(factory, type, order=order)
  66. class payload_type:
  67. def __init__(self, type: Any, *, order: Order = Order.normal) -> None:
  68. self.type = type
  69. self.order = order
  70. def __call__(self, factory: Type["Payload"]) -> Type["Payload"]:
  71. register_payload(factory, self.type, order=self.order)
  72. return factory
  73. PayloadType = Type["Payload"]
  74. _PayloadRegistryItem = Tuple[PayloadType, Any]
  75. class PayloadRegistry:
  76. """Payload registry.
  77. note: we need zope.interface for more efficient adapter search
  78. """
  79. def __init__(self) -> None:
  80. self._first: List[_PayloadRegistryItem] = []
  81. self._normal: List[_PayloadRegistryItem] = []
  82. self._last: List[_PayloadRegistryItem] = []
  83. def get(
  84. self,
  85. data: Any,
  86. *args: Any,
  87. _CHAIN: "Type[chain[_PayloadRegistryItem]]" = chain,
  88. **kwargs: Any,
  89. ) -> "Payload":
  90. if isinstance(data, Payload):
  91. return data
  92. for factory, type in _CHAIN(self._first, self._normal, self._last):
  93. if isinstance(data, type):
  94. return factory(data, *args, **kwargs)
  95. raise LookupError()
  96. def register(
  97. self, factory: PayloadType, type: Any, *, order: Order = Order.normal
  98. ) -> None:
  99. if order is Order.try_first:
  100. self._first.append((factory, type))
  101. elif order is Order.normal:
  102. self._normal.append((factory, type))
  103. elif order is Order.try_last:
  104. self._last.append((factory, type))
  105. else:
  106. raise ValueError(f"Unsupported order {order!r}")
  107. class Payload(ABC):
  108. _default_content_type: str = "application/octet-stream"
  109. _size: Optional[int] = None
  110. def __init__(
  111. self,
  112. value: Any,
  113. headers: Optional[
  114. Union[_CIMultiDict, Dict[str, str], Iterable[Tuple[str, str]]]
  115. ] = None,
  116. content_type: Union[str, None, _SENTINEL] = sentinel,
  117. filename: Optional[str] = None,
  118. encoding: Optional[str] = None,
  119. **kwargs: Any,
  120. ) -> None:
  121. self._encoding = encoding
  122. self._filename = filename
  123. self._headers: _CIMultiDict = CIMultiDict()
  124. self._value = value
  125. if content_type is not sentinel and content_type is not None:
  126. self._headers[hdrs.CONTENT_TYPE] = content_type
  127. elif self._filename is not None:
  128. content_type = mimetypes.guess_type(self._filename)[0]
  129. if content_type is None:
  130. content_type = self._default_content_type
  131. self._headers[hdrs.CONTENT_TYPE] = content_type
  132. else:
  133. self._headers[hdrs.CONTENT_TYPE] = self._default_content_type
  134. self._headers.update(headers or {})
  135. @property
  136. def size(self) -> Optional[int]:
  137. """Size of the payload."""
  138. return self._size
  139. @property
  140. def filename(self) -> Optional[str]:
  141. """Filename of the payload."""
  142. return self._filename
  143. @property
  144. def headers(self) -> _CIMultiDict:
  145. """Custom item headers"""
  146. return self._headers
  147. @property
  148. def _binary_headers(self) -> bytes:
  149. return (
  150. "".join([k + ": " + v + "\r\n" for k, v in self.headers.items()]).encode(
  151. "utf-8"
  152. )
  153. + b"\r\n"
  154. )
  155. @property
  156. def encoding(self) -> Optional[str]:
  157. """Payload encoding"""
  158. return self._encoding
  159. @property
  160. def content_type(self) -> str:
  161. """Content type"""
  162. return self._headers[hdrs.CONTENT_TYPE]
  163. def set_content_disposition(
  164. self,
  165. disptype: str,
  166. quote_fields: bool = True,
  167. _charset: str = "utf-8",
  168. **params: Any,
  169. ) -> None:
  170. """Sets ``Content-Disposition`` header."""
  171. self._headers[hdrs.CONTENT_DISPOSITION] = content_disposition_header(
  172. disptype, quote_fields=quote_fields, _charset=_charset, **params
  173. )
  174. @abstractmethod
  175. async def write(self, writer: AbstractStreamWriter) -> None:
  176. """Write payload.
  177. writer is an AbstractStreamWriter instance:
  178. """
  179. class BytesPayload(Payload):
  180. def __init__(self, value: ByteString, *args: Any, **kwargs: Any) -> None:
  181. if not isinstance(value, (bytes, bytearray, memoryview)):
  182. raise TypeError(f"value argument must be byte-ish, not {type(value)!r}")
  183. if "content_type" not in kwargs:
  184. kwargs["content_type"] = "application/octet-stream"
  185. super().__init__(value, *args, **kwargs)
  186. if isinstance(value, memoryview):
  187. self._size = value.nbytes
  188. else:
  189. self._size = len(value)
  190. if self._size > TOO_LARGE_BYTES_BODY:
  191. kwargs = {"source": self}
  192. warnings.warn(
  193. "Sending a large body directly with raw bytes might"
  194. " lock the event loop. You should probably pass an "
  195. "io.BytesIO object instead",
  196. ResourceWarning,
  197. **kwargs,
  198. )
  199. async def write(self, writer: AbstractStreamWriter) -> None:
  200. await writer.write(self._value)
  201. class StringPayload(BytesPayload):
  202. def __init__(
  203. self,
  204. value: str,
  205. *args: Any,
  206. encoding: Optional[str] = None,
  207. content_type: Optional[str] = None,
  208. **kwargs: Any,
  209. ) -> None:
  210. if encoding is None:
  211. if content_type is None:
  212. real_encoding = "utf-8"
  213. content_type = "text/plain; charset=utf-8"
  214. else:
  215. mimetype = parse_mimetype(content_type)
  216. real_encoding = mimetype.parameters.get("charset", "utf-8")
  217. else:
  218. if content_type is None:
  219. content_type = "text/plain; charset=%s" % encoding
  220. real_encoding = encoding
  221. super().__init__(
  222. value.encode(real_encoding),
  223. encoding=real_encoding,
  224. content_type=content_type,
  225. *args,
  226. **kwargs,
  227. )
  228. class StringIOPayload(StringPayload):
  229. def __init__(self, value: IO[str], *args: Any, **kwargs: Any) -> None:
  230. super().__init__(value.read(), *args, **kwargs)
  231. class IOBasePayload(Payload):
  232. _value: IO[Any]
  233. def __init__(
  234. self, value: IO[Any], disposition: str = "attachment", *args: Any, **kwargs: Any
  235. ) -> None:
  236. if "filename" not in kwargs:
  237. kwargs["filename"] = guess_filename(value)
  238. super().__init__(value, *args, **kwargs)
  239. if self._filename is not None and disposition is not None:
  240. if hdrs.CONTENT_DISPOSITION not in self.headers:
  241. self.set_content_disposition(disposition, filename=self._filename)
  242. async def write(self, writer: AbstractStreamWriter) -> None:
  243. loop = asyncio.get_event_loop()
  244. try:
  245. chunk = await loop.run_in_executor(None, self._value.read, 2**16)
  246. while chunk:
  247. await writer.write(chunk)
  248. chunk = await loop.run_in_executor(None, self._value.read, 2**16)
  249. finally:
  250. await loop.run_in_executor(None, self._value.close)
  251. class TextIOPayload(IOBasePayload):
  252. _value: TextIO
  253. def __init__(
  254. self,
  255. value: TextIO,
  256. *args: Any,
  257. encoding: Optional[str] = None,
  258. content_type: Optional[str] = None,
  259. **kwargs: Any,
  260. ) -> None:
  261. if encoding is None:
  262. if content_type is None:
  263. encoding = "utf-8"
  264. content_type = "text/plain; charset=utf-8"
  265. else:
  266. mimetype = parse_mimetype(content_type)
  267. encoding = mimetype.parameters.get("charset", "utf-8")
  268. else:
  269. if content_type is None:
  270. content_type = "text/plain; charset=%s" % encoding
  271. super().__init__(
  272. value,
  273. content_type=content_type,
  274. encoding=encoding,
  275. *args,
  276. **kwargs,
  277. )
  278. @property
  279. def size(self) -> Optional[int]:
  280. try:
  281. return os.fstat(self._value.fileno()).st_size - self._value.tell()
  282. except OSError:
  283. return None
  284. async def write(self, writer: AbstractStreamWriter) -> None:
  285. loop = asyncio.get_event_loop()
  286. try:
  287. chunk = await loop.run_in_executor(None, self._value.read, 2**16)
  288. while chunk:
  289. data = (
  290. chunk.encode(encoding=self._encoding)
  291. if self._encoding
  292. else chunk.encode()
  293. )
  294. await writer.write(data)
  295. chunk = await loop.run_in_executor(None, self._value.read, 2**16)
  296. finally:
  297. await loop.run_in_executor(None, self._value.close)
  298. class BytesIOPayload(IOBasePayload):
  299. @property
  300. def size(self) -> int:
  301. position = self._value.tell()
  302. end = self._value.seek(0, os.SEEK_END)
  303. self._value.seek(position)
  304. return end - position
  305. class BufferedReaderPayload(IOBasePayload):
  306. @property
  307. def size(self) -> Optional[int]:
  308. try:
  309. return os.fstat(self._value.fileno()).st_size - self._value.tell()
  310. except OSError:
  311. # data.fileno() is not supported, e.g.
  312. # io.BufferedReader(io.BytesIO(b'data'))
  313. return None
  314. class JsonPayload(BytesPayload):
  315. def __init__(
  316. self,
  317. value: Any,
  318. encoding: str = "utf-8",
  319. content_type: str = "application/json",
  320. dumps: JSONEncoder = json.dumps,
  321. *args: Any,
  322. **kwargs: Any,
  323. ) -> None:
  324. super().__init__(
  325. dumps(value).encode(encoding),
  326. content_type=content_type,
  327. encoding=encoding,
  328. *args,
  329. **kwargs,
  330. )
  331. if TYPE_CHECKING:
  332. from typing import AsyncIterable, AsyncIterator
  333. _AsyncIterator = AsyncIterator[bytes]
  334. _AsyncIterable = AsyncIterable[bytes]
  335. else:
  336. from collections.abc import AsyncIterable, AsyncIterator
  337. _AsyncIterator = AsyncIterator
  338. _AsyncIterable = AsyncIterable
  339. class AsyncIterablePayload(Payload):
  340. _iter: Optional[_AsyncIterator] = None
  341. def __init__(self, value: _AsyncIterable, *args: Any, **kwargs: Any) -> None:
  342. if not isinstance(value, AsyncIterable):
  343. raise TypeError(
  344. "value argument must support "
  345. "collections.abc.AsyncIterable interface, "
  346. "got {!r}".format(type(value))
  347. )
  348. if "content_type" not in kwargs:
  349. kwargs["content_type"] = "application/octet-stream"
  350. super().__init__(value, *args, **kwargs)
  351. self._iter = value.__aiter__()
  352. async def write(self, writer: AbstractStreamWriter) -> None:
  353. if self._iter:
  354. try:
  355. # iter is not None check prevents rare cases
  356. # when the case iterable is used twice
  357. while True:
  358. chunk = await self._iter.__anext__()
  359. await writer.write(chunk)
  360. except StopAsyncIteration:
  361. self._iter = None
  362. class StreamReaderPayload(AsyncIterablePayload):
  363. def __init__(self, value: StreamReader, *args: Any, **kwargs: Any) -> None:
  364. super().__init__(value.iter_any(), *args, **kwargs)
  365. PAYLOAD_REGISTRY = PayloadRegistry()
  366. PAYLOAD_REGISTRY.register(BytesPayload, (bytes, bytearray, memoryview))
  367. PAYLOAD_REGISTRY.register(StringPayload, str)
  368. PAYLOAD_REGISTRY.register(StringIOPayload, io.StringIO)
  369. PAYLOAD_REGISTRY.register(TextIOPayload, io.TextIOBase)
  370. PAYLOAD_REGISTRY.register(BytesIOPayload, io.BytesIO)
  371. PAYLOAD_REGISTRY.register(BufferedReaderPayload, (io.BufferedReader, io.BufferedRandom))
  372. PAYLOAD_REGISTRY.register(IOBasePayload, io.IOBase)
  373. PAYLOAD_REGISTRY.register(StreamReaderPayload, StreamReader)
  374. # try_last for giving a chance to more specialized async interables like
  375. # multidict.BodyPartReaderPayload override the default
  376. PAYLOAD_REGISTRY.register(AsyncIterablePayload, AsyncIterable, order=Order.try_last)