_multipart.py 8.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266
  1. import io
  2. import os
  3. import typing
  4. from pathlib import Path
  5. from ._types import (
  6. AsyncByteStream,
  7. FileContent,
  8. FileTypes,
  9. RequestData,
  10. RequestFiles,
  11. SyncByteStream,
  12. )
  13. from ._utils import (
  14. format_form_param,
  15. guess_content_type,
  16. peek_filelike_length,
  17. primitive_value_to_str,
  18. to_bytes,
  19. )
  20. def get_multipart_boundary_from_content_type(
  21. content_type: typing.Optional[bytes],
  22. ) -> typing.Optional[bytes]:
  23. if not content_type or not content_type.startswith(b"multipart/form-data"):
  24. return None
  25. # parse boundary according to
  26. # https://www.rfc-editor.org/rfc/rfc2046#section-5.1.1
  27. if b";" in content_type:
  28. for section in content_type.split(b";"):
  29. if section.strip().lower().startswith(b"boundary="):
  30. return section.strip()[len(b"boundary=") :].strip(b'"')
  31. return None
  32. class DataField:
  33. """
  34. A single form field item, within a multipart form field.
  35. """
  36. def __init__(
  37. self, name: str, value: typing.Union[str, bytes, int, float, None]
  38. ) -> None:
  39. if not isinstance(name, str):
  40. raise TypeError(
  41. f"Invalid type for name. Expected str, got {type(name)}: {name!r}"
  42. )
  43. if value is not None and not isinstance(value, (str, bytes, int, float)):
  44. raise TypeError(
  45. f"Invalid type for value. Expected primitive type, got {type(value)}: {value!r}"
  46. )
  47. self.name = name
  48. self.value: typing.Union[str, bytes] = (
  49. value if isinstance(value, bytes) else primitive_value_to_str(value)
  50. )
  51. def render_headers(self) -> bytes:
  52. if not hasattr(self, "_headers"):
  53. name = format_form_param("name", self.name)
  54. self._headers = b"".join(
  55. [b"Content-Disposition: form-data; ", name, b"\r\n\r\n"]
  56. )
  57. return self._headers
  58. def render_data(self) -> bytes:
  59. if not hasattr(self, "_data"):
  60. self._data = to_bytes(self.value)
  61. return self._data
  62. def get_length(self) -> int:
  63. headers = self.render_headers()
  64. data = self.render_data()
  65. return len(headers) + len(data)
  66. def render(self) -> typing.Iterator[bytes]:
  67. yield self.render_headers()
  68. yield self.render_data()
  69. class FileField:
  70. """
  71. A single file field item, within a multipart form field.
  72. """
  73. CHUNK_SIZE = 64 * 1024
  74. def __init__(self, name: str, value: FileTypes) -> None:
  75. self.name = name
  76. fileobj: FileContent
  77. headers: typing.Dict[str, str] = {}
  78. content_type: typing.Optional[str] = None
  79. # This large tuple based API largely mirror's requests' API
  80. # It would be good to think of better APIs for this that we could include in httpx 2.0
  81. # since variable length tuples (especially of 4 elements) are quite unwieldly
  82. if isinstance(value, tuple):
  83. if len(value) == 2:
  84. # neither the 3rd parameter (content_type) nor the 4th (headers) was included
  85. filename, fileobj = value # type: ignore
  86. elif len(value) == 3:
  87. filename, fileobj, content_type = value # type: ignore
  88. else:
  89. # all 4 parameters included
  90. filename, fileobj, content_type, headers = value # type: ignore
  91. else:
  92. filename = Path(str(getattr(value, "name", "upload"))).name
  93. fileobj = value
  94. if content_type is None:
  95. content_type = guess_content_type(filename)
  96. has_content_type_header = any("content-type" in key.lower() for key in headers)
  97. if content_type is not None and not has_content_type_header:
  98. # note that unlike requests, we ignore the content_type
  99. # provided in the 3rd tuple element if it is also included in the headers
  100. # requests does the opposite (it overwrites the header with the 3rd tuple element)
  101. headers["Content-Type"] = content_type
  102. if isinstance(fileobj, io.StringIO):
  103. raise TypeError(
  104. "Multipart file uploads require 'io.BytesIO', not 'io.StringIO'."
  105. )
  106. if isinstance(fileobj, io.TextIOBase):
  107. raise TypeError(
  108. "Multipart file uploads must be opened in binary mode, not text mode."
  109. )
  110. self.filename = filename
  111. self.file = fileobj
  112. self.headers = headers
  113. def get_length(self) -> typing.Optional[int]:
  114. headers = self.render_headers()
  115. if isinstance(self.file, (str, bytes)):
  116. return len(headers) + len(to_bytes(self.file))
  117. file_length = peek_filelike_length(self.file)
  118. # If we can't determine the filesize without reading it into memory,
  119. # then return `None` here, to indicate an unknown file length.
  120. if file_length is None:
  121. return None
  122. return len(headers) + file_length
  123. def render_headers(self) -> bytes:
  124. if not hasattr(self, "_headers"):
  125. parts = [
  126. b"Content-Disposition: form-data; ",
  127. format_form_param("name", self.name),
  128. ]
  129. if self.filename:
  130. filename = format_form_param("filename", self.filename)
  131. parts.extend([b"; ", filename])
  132. for header_name, header_value in self.headers.items():
  133. key, val = f"\r\n{header_name}: ".encode(), header_value.encode()
  134. parts.extend([key, val])
  135. parts.append(b"\r\n\r\n")
  136. self._headers = b"".join(parts)
  137. return self._headers
  138. def render_data(self) -> typing.Iterator[bytes]:
  139. if isinstance(self.file, (str, bytes)):
  140. yield to_bytes(self.file)
  141. return
  142. if hasattr(self.file, "seek"):
  143. try:
  144. self.file.seek(0)
  145. except io.UnsupportedOperation:
  146. pass
  147. chunk = self.file.read(self.CHUNK_SIZE)
  148. while chunk:
  149. yield to_bytes(chunk)
  150. chunk = self.file.read(self.CHUNK_SIZE)
  151. def render(self) -> typing.Iterator[bytes]:
  152. yield self.render_headers()
  153. yield from self.render_data()
  154. class MultipartStream(SyncByteStream, AsyncByteStream):
  155. """
  156. Request content as streaming multipart encoded form data.
  157. """
  158. def __init__(
  159. self,
  160. data: RequestData,
  161. files: RequestFiles,
  162. boundary: typing.Optional[bytes] = None,
  163. ) -> None:
  164. if boundary is None:
  165. boundary = os.urandom(16).hex().encode("ascii")
  166. self.boundary = boundary
  167. self.content_type = "multipart/form-data; boundary=%s" % boundary.decode(
  168. "ascii"
  169. )
  170. self.fields = list(self._iter_fields(data, files))
  171. def _iter_fields(
  172. self, data: RequestData, files: RequestFiles
  173. ) -> typing.Iterator[typing.Union[FileField, DataField]]:
  174. for name, value in data.items():
  175. if isinstance(value, (tuple, list)):
  176. for item in value:
  177. yield DataField(name=name, value=item)
  178. else:
  179. yield DataField(name=name, value=value)
  180. file_items = files.items() if isinstance(files, typing.Mapping) else files
  181. for name, value in file_items:
  182. yield FileField(name=name, value=value)
  183. def iter_chunks(self) -> typing.Iterator[bytes]:
  184. for field in self.fields:
  185. yield b"--%s\r\n" % self.boundary
  186. yield from field.render()
  187. yield b"\r\n"
  188. yield b"--%s--\r\n" % self.boundary
  189. def get_content_length(self) -> typing.Optional[int]:
  190. """
  191. Return the length of the multipart encoded content, or `None` if
  192. any of the files have a length that cannot be determined upfront.
  193. """
  194. boundary_length = len(self.boundary)
  195. length = 0
  196. for field in self.fields:
  197. field_length = field.get_length()
  198. if field_length is None:
  199. return None
  200. length += 2 + boundary_length + 2 # b"--{boundary}\r\n"
  201. length += field_length
  202. length += 2 # b"\r\n"
  203. length += 2 + boundary_length + 4 # b"--{boundary}--\r\n"
  204. return length
  205. # Content stream interface.
  206. def get_headers(self) -> typing.Dict[str, str]:
  207. content_length = self.get_content_length()
  208. content_type = self.content_type
  209. if content_length is None:
  210. return {"Transfer-Encoding": "chunked", "Content-Type": content_type}
  211. return {"Content-Length": str(content_length), "Content-Type": content_type}
  212. def __iter__(self) -> typing.Iterator[bytes]:
  213. for chunk in self.iter_chunks():
  214. yield chunk
  215. async def __aiter__(self) -> typing.AsyncIterator[bytes]:
  216. for chunk in self.iter_chunks():
  217. yield chunk