main.py 43 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
370470570670770870971071171271371471571671771871972072172272372472572672772872973073173273373473573673773873974074174274374474574674774874975075175275375475575675775875976076176276376476576676776876977077177277377477577677777877978078178278378478578678778878979079179279379479579679779879980080180280380480580680780880981081181281381481581681781881982082182282382482582682782882983083183283383483583683783883984084184284384484584684784884985085185285385485585685785885986086186286386486586686786886987087187287387487587687787887988088188288388488588688788888989089189289389489589689789889990090190290390490590690790890991091191291391491591691791891992092192292392492592692792892993093193293393493593693793893994094194294394494594694794894995095195295395495595695795895996096196296396496596696796896997097197297397497597697797897998098198298398498598698798898999099199299399499599699799899910001001100210031004100510061007100810091010101110121013101410151016101710181019102010211022102310241025102610271028102910301031103210331034103510361037103810391040104110421043104410451046104710481049105010511052105310541055105610571058105910601061106210631064106510661067106810691070107110721073107410751076107710781079108010811082108310841085
  1. #!/usr/bin/env python
  2. """Protoc Plugin to generate mypy stubs."""
  3. from __future__ import annotations
  4. import sys
  5. from collections import defaultdict
  6. from contextlib import contextmanager
  7. from typing import (
  8. Any,
  9. Callable,
  10. Dict,
  11. Iterable,
  12. Iterator,
  13. List,
  14. Optional,
  15. Set,
  16. Sequence,
  17. Tuple,
  18. )
  19. import google.protobuf.descriptor_pb2 as d
  20. from google.protobuf.compiler import plugin_pb2 as plugin_pb2
  21. from google.protobuf.internal.containers import RepeatedCompositeFieldContainer
  22. from google.protobuf.internal.well_known_types import WKTBASES
  23. from . import extensions_pb2
  24. __version__ = "3.6.0"
# SourceCodeLocation is defined by `message Location` here
# https://github.com/protocolbuffers/protobuf/blob/master/src/google/protobuf/descriptor.proto
# (a list of ints forming a path into the proto file's AST)
SourceCodeLocation = List[int]

# So phabricator doesn't think mypy_protobuf.py is generated
GENERATED = "@ge" + "nerated"
# Header docstring placed at the top of every generated stub file.
HEADER = f"""
{GENERATED} by mypy-protobuf. Do not edit manually!
isort:skip_file
"""

# See https://github.com/nipunn1313/mypy-protobuf/issues/73 for details
# Python keywords: proto identifiers that collide with these are emitted
# with an "_r_" prefix (see _import_message / write_enums / write_messages).
PYTHON_RESERVED = {
    "False",
    "None",
    "True",
    "and",
    "as",
    "async",
    "await",
    "assert",
    "break",
    "class",
    "continue",
    "def",
    "del",
    "elif",
    "else",
    "except",
    "finally",
    "for",
    "from",
    "global",
    "if",
    "import",
    "in",
    "is",
    "lambda",
    "nonlocal",
    "not",
    "or",
    "pass",
    "raise",
    "return",
    "try",
    "while",
    "with",
    "yield",
}

# Enum value names skipped when writing values onto the EnumTypeWrapper
# helper class (see write_enums), where they would clash with the wrapper API.
PROTO_ENUM_RESERVED = {
    "Name",
    "Value",
    "keys",
    "values",
    "items",
}
  79. def _mangle_global_identifier(name: str) -> str:
  80. """
  81. Module level identifiers are mangled and aliased so that they can be disambiguated
  82. from fields/enum variants with the same name within the file.
  83. Eg:
  84. Enum variant `Name` or message field `Name` might conflict with a top level
  85. message or enum named `Name`, so mangle it with a global___ prefix for
  86. internal references. Note that this doesn't affect inner enums/messages
  87. because they get fuly qualified when referenced within a file"""
  88. return f"global___{name}"
class Descriptors(object):
    """Indexes the messages and enums of a CodeGeneratorRequest for lookup."""

    def __init__(self, request: plugin_pb2.CodeGeneratorRequest) -> None:
        files = {f.name: f for f in request.proto_file}
        to_generate = {n: files[n] for n in request.file_to_generate}
        # All .proto files known to the request, keyed by file name.
        self.files: Dict[str, d.FileDescriptorProto] = files
        # Subset of `files` that stubs were requested for.
        self.to_generate: Dict[str, d.FileDescriptorProto] = to_generate
        # Fully-qualified name (".pkg.Msg") -> message descriptor.
        self.messages: Dict[str, d.DescriptorProto] = {}
        # Fully-qualified message/enum name -> file it is declared in.
        self.message_to_fd: Dict[str, d.FileDescriptorProto] = {}

        def _add_enums(
            enums: "RepeatedCompositeFieldContainer[d.EnumDescriptorProto]",
            prefix: str,
            _fd: d.FileDescriptorProto,
        ) -> None:
            # Record the enum itself and its generated `.ValueType` alias.
            for enum in enums:
                self.message_to_fd[prefix + enum.name] = _fd
                self.message_to_fd[prefix + enum.name + ".ValueType"] = _fd

        def _add_messages(
            messages: "RepeatedCompositeFieldContainer[d.DescriptorProto]",
            prefix: str,
            _fd: d.FileDescriptorProto,
        ) -> None:
            # Recursively record messages and their nested messages/enums.
            for message in messages:
                self.messages[prefix + message.name] = message
                self.message_to_fd[prefix + message.name] = _fd
                sub_prefix = prefix + message.name + "."
                _add_messages(message.nested_type, sub_prefix, _fd)
                _add_enums(message.enum_type, sub_prefix, _fd)

        for fd in request.proto_file:
            start_prefix = "." + fd.package + "." if fd.package else "."
            _add_messages(fd.message_type, start_prefix, fd)
            _add_enums(fd.enum_type, start_prefix, fd)
  120. class PkgWriter(object):
  121. """Writes a single pyi file"""
    def __init__(
        self,
        fd: d.FileDescriptorProto,
        descriptors: Descriptors,
        readable_stubs: bool,
        relax_strict_optional_primitives: bool,
        grpc: bool,
    ) -> None:
        """Collect per-file state for writing one .pyi stub for `fd`."""
        self.fd = fd
        self.descriptors = descriptors
        self.readable_stubs = readable_stubs
        self.relax_strict_optional_primitives = relax_strict_optional_primitives
        self.grpc = grpc
        # Accumulated output lines; joined into the stub at the end.
        self.lines: List[str] = []
        # Current indentation prefix, managed by the _indent() context manager.
        self.indent = ""
        # Set of {x}, where {x} corresponds to `import {x}`
        self.imports: Set[str] = set()
        # dictionary of x->(y,z) for `from {x} import {y} as {z}`
        # if {z} is None, then it shortens to `from {x} import {y}`
        self.from_imports: Dict[str, Set[Tuple[str, str | None]]] = defaultdict(set)
        # Minimum typing_extensions version needed by names used so far.
        self.typing_extensions_min: Optional[Tuple[int, int]] = None
        # Comments, indexed by their source-code-location path tuple.
        self.source_code_info_by_scl = {tuple(location.path): location for location in fd.source_code_info.location}
  145. def _import(self, path: str, name: str) -> str:
  146. """Imports a stdlib path and returns a handle to it
  147. eg. self._import("typing", "Literal") -> "Literal"
  148. """
  149. if path == "typing_extensions":
  150. stabilization = {
  151. "TypeAlias": (3, 10),
  152. }
  153. assert name in stabilization
  154. if not self.typing_extensions_min or self.typing_extensions_min < stabilization[name]:
  155. self.typing_extensions_min = stabilization[name]
  156. return "typing_extensions." + name
  157. imp = path.replace("/", ".")
  158. if self.readable_stubs:
  159. self.from_imports[imp].add((name, None))
  160. return name
  161. else:
  162. self.imports.add(imp)
  163. return imp + "." + name
  164. def _import_message(self, name: str) -> str:
  165. """Import a referenced message and return a handle"""
  166. message_fd = self.descriptors.message_to_fd[name]
  167. assert message_fd.name.endswith(".proto")
  168. # Strip off package name
  169. if message_fd.package:
  170. assert name.startswith("." + message_fd.package + ".")
  171. name = name[len("." + message_fd.package + ".") :]
  172. else:
  173. assert name.startswith(".")
  174. name = name[1:]
  175. # Use prepended "_r_" to disambiguate message names that alias python reserved keywords
  176. split = name.split(".")
  177. for i, part in enumerate(split):
  178. if part in PYTHON_RESERVED:
  179. split[i] = "_r_" + part
  180. name = ".".join(split)
  181. # Message defined in this file. Note: GRPC stubs in same .proto are generated into separate files
  182. if not self.grpc and message_fd.name == self.fd.name:
  183. return name if self.readable_stubs else _mangle_global_identifier(name)
  184. # Not in file. Must import
  185. # Python generated code ignores proto packages, so the only relevant factor is
  186. # whether it is in the file or not.
  187. import_name = self._import(message_fd.name[:-6].replace("-", "_") + "_pb2", split[0])
  188. remains = ".".join(split[1:])
  189. if not remains:
  190. return import_name
  191. # remains could either be a direct import of a nested enum or message
  192. # from another package.
  193. return import_name + "." + remains
  194. def _builtin(self, name: str) -> str:
  195. return self._import("builtins", name)
  196. @contextmanager
  197. def _indent(self) -> Iterator[None]:
  198. self.indent = self.indent + " "
  199. yield
  200. self.indent = self.indent[:-4]
  201. def _write_line(self, line: str, *args: Any) -> None:
  202. if args:
  203. line = line.format(*args)
  204. if line == "":
  205. self.lines.append(line)
  206. else:
  207. self.lines.append(self.indent + line)
  208. def _break_text(self, text_block: str) -> List[str]:
  209. if text_block == "":
  210. return []
  211. return [line[1:] if line.startswith(" ") else line for line in text_block.rstrip().split("\n")]
  212. def _has_comments(self, scl: SourceCodeLocation) -> bool:
  213. sci_loc = self.source_code_info_by_scl.get(tuple(scl))
  214. return sci_loc is not None and bool(sci_loc.leading_detached_comments or sci_loc.leading_comments or sci_loc.trailing_comments)
  215. def _write_comments(self, scl: SourceCodeLocation) -> bool:
  216. """Return true if any comments were written"""
  217. if not self._has_comments(scl):
  218. return False
  219. sci_loc = self.source_code_info_by_scl.get(tuple(scl))
  220. assert sci_loc is not None
  221. leading_detached_lines = []
  222. leading_lines = []
  223. trailing_lines = []
  224. for leading_detached_comment in sci_loc.leading_detached_comments:
  225. leading_detached_lines = self._break_text(leading_detached_comment)
  226. if sci_loc.leading_comments is not None:
  227. leading_lines = self._break_text(sci_loc.leading_comments)
  228. # Trailing comments also go in the header - to make sure it gets into the docstring
  229. if sci_loc.trailing_comments is not None:
  230. trailing_lines = self._break_text(sci_loc.trailing_comments)
  231. lines = leading_detached_lines
  232. if leading_detached_lines and (leading_lines or trailing_lines):
  233. lines.append("")
  234. lines.extend(leading_lines)
  235. lines.extend(trailing_lines)
  236. lines = [
  237. # Escape triple-quotes that would otherwise end the docstring early.
  238. line.replace("\\", "\\\\").replace('"""', r"\"\"\"")
  239. for line in lines
  240. ]
  241. if len(lines) == 1:
  242. line = lines[0]
  243. if line.endswith(('"', "\\")):
  244. # Docstrings are terminated with triple-quotes, so if the documentation itself ends in a quote,
  245. # insert some whitespace to separate it from the closing quotes.
  246. # This is not necessary with multiline comments
  247. # because in that case we always insert a newline before the trailing triple-quotes.
  248. line = line + " "
  249. self._write_line(f'"""{line}"""')
  250. else:
  251. for i, line in enumerate(lines):
  252. if i == 0:
  253. self._write_line(f'"""{line}')
  254. else:
  255. self._write_line(f"{line}")
  256. self._write_line('"""')
  257. return True
  258. def write_enum_values(
  259. self,
  260. values: Iterable[Tuple[int, d.EnumValueDescriptorProto]],
  261. value_type: str,
  262. scl_prefix: SourceCodeLocation,
  263. ) -> None:
  264. for i, val in values:
  265. if val.name in PYTHON_RESERVED:
  266. continue
  267. scl = scl_prefix + [i]
  268. self._write_line(
  269. f"{val.name}: {value_type} # {val.number}",
  270. )
  271. self._write_comments(scl)
  272. def write_module_attributes(self) -> None:
  273. wl = self._write_line
  274. fd_type = self._import("google.protobuf.descriptor", "FileDescriptor")
  275. wl(f"DESCRIPTOR: {fd_type}")
  276. wl("")
    def write_enums(
        self,
        enums: Iterable[d.EnumDescriptorProto],
        prefix: str,
        scl_prefix: SourceCodeLocation,
    ) -> None:
        """Write stub classes for each enum.

        For enum ``Foo`` this emits a `_Foo` helper holding the NewType
        ``ValueType``, a `_FooEnumTypeWrapper` metaclass helper, and the
        public `Foo` class, followed by module/class-level value constants.
        """
        wl = self._write_line
        for i, enum in enumerate(enums):
            # "_r_" prefix disambiguates names that collide with python keywords.
            class_name = enum.name if enum.name not in PYTHON_RESERVED else "_r_" + enum.name
            value_type_fq = prefix + class_name + ".ValueType"
            enum_helper_class = "_" + enum.name
            value_type_helper_fq = prefix + enum_helper_class + ".ValueType"
            etw_helper_class = "_" + enum.name + "EnumTypeWrapper"
            scl = scl_prefix + [i]
            # Helper class carrying the NewType used for value typing.
            wl(f"class {enum_helper_class}:")
            with self._indent():
                wl(
                    'ValueType = {}("ValueType", {})',
                    self._import("typing", "NewType"),
                    self._builtin("int"),
                )
                # Alias to the classic shorter definition "V"
                wl("V: {} = ValueType", self._import("typing_extensions", "TypeAlias"))
            wl("")
            # EnumTypeWrapper metaclass helper; values reserved by the wrapper
            # API (PROTO_ENUM_RESERVED) are filtered out here.
            wl(
                "class {}({}[{}], {}):",
                etw_helper_class,
                self._import("google.protobuf.internal.enum_type_wrapper", "_EnumTypeWrapper"),
                value_type_helper_fq,
                self._builtin("type"),
            )
            with self._indent():
                ed = self._import("google.protobuf.descriptor", "EnumDescriptor")
                wl(f"DESCRIPTOR: {ed}")
                self.write_enum_values(
                    [(i, v) for i, v in enumerate(enum.value) if v.name not in PROTO_ENUM_RESERVED],
                    value_type_helper_fq,
                    scl + [d.EnumDescriptorProto.VALUE_FIELD_NUMBER],
                )
            wl("")
            # The public enum class, with its comments as docstring if any.
            if self._has_comments(scl):
                wl(f"class {class_name}({enum_helper_class}, metaclass={etw_helper_class}):")
                with self._indent():
                    self._write_comments(scl)
                    wl("")
            else:
                wl(f"class {class_name}({enum_helper_class}, metaclass={etw_helper_class}): ...")
            if prefix == "":
                wl("")
            # Value constants, typed with the public class's ValueType.
            self.write_enum_values(
                enumerate(enum.value),
                value_type_fq,
                scl + [d.EnumDescriptorProto.VALUE_FIELD_NUMBER],
            )
            # Top-level enums also get a global___ alias (see _mangle_global_identifier).
            if prefix == "" and not self.readable_stubs:
                wl(f"{_mangle_global_identifier(class_name)} = {class_name}")
            wl("")
    def write_messages(
        self,
        messages: Iterable[d.DescriptorProto],
        prefix: str,
        scl_prefix: SourceCodeLocation,
    ) -> None:
        """Write a stub class for each message, recursing into nested types.

        Emits: class header (with well-known-type extra base when applicable),
        docstring, DESCRIPTOR, nested enums/messages, field-number constants,
        field attributes/properties, extensions, __init__, and the
        stringly-typed helper methods.
        """
        wl = self._write_line
        for i, desc in enumerate(messages):
            qualified_name = prefix + desc.name

            # Reproduce some hardcoded logic from the protobuf implementation - where
            # some specific "well_known_types" generated protos to have additional
            # base classes
            addl_base = ""
            if self.fd.package + "." + desc.name in WKTBASES:
                # chop off the .proto - and import the well known type
                # eg `from google.protobuf.duration import Duration`
                well_known_type = WKTBASES[self.fd.package + "." + desc.name]
                addl_base = ", " + self._import(
                    "google.protobuf.internal.well_known_types",
                    well_known_type.__name__,
                )

            class_name = desc.name if desc.name not in PYTHON_RESERVED else "_r_" + desc.name
            message_class = self._import("google.protobuf.message", "Message")
            wl("@{}", self._import("typing", "final"))
            wl(f"class {class_name}({message_class}{addl_base}):")
            with self._indent():
                scl = scl_prefix + [i]
                if self._write_comments(scl):
                    wl("")

                desc_type = self._import("google.protobuf.descriptor", "Descriptor")
                wl(f"DESCRIPTOR: {desc_type}")
                wl("")

                # Nested enums/messages
                self.write_enums(
                    desc.enum_type,
                    qualified_name + ".",
                    scl + [d.DescriptorProto.ENUM_TYPE_FIELD_NUMBER],
                )
                self.write_messages(
                    desc.nested_type,
                    qualified_name + ".",
                    scl + [d.DescriptorProto.NESTED_TYPE_FIELD_NUMBER],
                )

                # integer constants for field numbers
                for f in desc.field:
                    wl(f"{f.name.upper()}_FIELD_NUMBER: {self._builtin('int')}")

                for idx, field in enumerate(desc.field):
                    if field.name in PYTHON_RESERVED:
                        continue
                    field_type = self.python_type(field)
                    if is_scalar(field) and field.label != d.FieldDescriptorProto.LABEL_REPEATED:
                        # Scalar non repeated fields are r/w
                        wl(f"{field.name}: {field_type}")
                        self._write_comments(scl + [d.DescriptorProto.FIELD_FIELD_NUMBER, idx])

                for idx, field in enumerate(desc.field):
                    if field.name in PYTHON_RESERVED:
                        continue
                    field_type = self.python_type(field)
                    if not (is_scalar(field) and field.label != d.FieldDescriptorProto.LABEL_REPEATED):
                        # r/o Getters for non-scalar fields and scalar-repeated fields
                        scl_field = scl + [d.DescriptorProto.FIELD_FIELD_NUMBER, idx]
                        wl("@property")
                        body = " ..." if not self._has_comments(scl_field) else ""
                        wl(f"def {field.name}(self) -> {field_type}:{body}")
                        if self._has_comments(scl_field):
                            with self._indent():
                                self._write_comments(scl_field)
                wl("")

                self.write_extensions(desc.extension, scl + [d.DescriptorProto.EXTENSION_FIELD_NUMBER])

                # Constructor
                wl("def __init__(")
                with self._indent():
                    # A field literally named "self" would shadow the receiver.
                    if any(f.name == "self" for f in desc.field):
                        wl("self_, # pyright: ignore[reportSelfClsParameterName]")
                    else:
                        wl("self,")
                with self._indent():
                    constructor_fields = [f for f in desc.field if f.name not in PYTHON_RESERVED]
                    if len(constructor_fields) > 0:
                        # Only positional args allowed
                        # See https://github.com/nipunn1313/mypy-protobuf/issues/71
                        wl("*,")
                    for field in constructor_fields:
                        field_type = self.python_type(field, generic_container=True)
                        # Proto3 non-optional scalars cannot be None-initialized.
                        if self.fd.syntax == "proto3" and is_scalar(field) and field.label != d.FieldDescriptorProto.LABEL_REPEATED and not self.relax_strict_optional_primitives and not field.proto3_optional:
                            wl(f"{field.name}: {field_type} = ...,")
                        else:
                            wl(f"{field.name}: {field_type} | None = ...,")
                wl(") -> None: ...")

                self.write_stringly_typed_fields(desc)

            # Top-level messages also get a global___ alias.
            if prefix == "" and not self.readable_stubs:
                wl("")
                wl(f"{_mangle_global_identifier(class_name)} = {class_name}")
            wl("")
    def write_stringly_typed_fields(self, desc: d.DescriptorProto) -> None:
        """Type the stringly-typed methods as a Union[Literal, Literal ...]"""
        wl = self._write_line
        # HasField, ClearField, WhichOneof accepts both bytes/str
        # HasField only supports singular. ClearField supports repeated as well
        # In proto3, HasField only supports message fields and optional fields
        # HasField always supports oneof fields
        hf_fields = [f.name for f in desc.field if f.HasField("oneof_index") or (f.label != d.FieldDescriptorProto.LABEL_REPEATED and (self.fd.syntax != "proto3" or f.type == d.FieldDescriptorProto.TYPE_MESSAGE or f.proto3_optional))]
        cf_fields = [f.name for f in desc.field]
        # oneof group name -> names of the fields inside it
        wo_fields = {oneof.name: [f.name for f in desc.field if f.HasField("oneof_index") and f.oneof_index == idx] for idx, oneof in enumerate(desc.oneof_decl)}
        # Oneof group names themselves are also valid for HasField/ClearField.
        hf_fields.extend(wo_fields.keys())
        cf_fields.extend(wo_fields.keys())

        hf_fields_text = ", ".join(sorted(f'"{name}", b"{name}"' for name in hf_fields))
        cf_fields_text = ", ".join(sorted(f'"{name}", b"{name}"' for name in cf_fields))

        if not hf_fields and not cf_fields and not wo_fields:
            return

        if hf_fields:
            wl(
                "def HasField(self, field_name: {}[{}]) -> {}: ...",
                self._import("typing", "Literal"),
                hf_fields_text,
                self._builtin("bool"),
            )
        if cf_fields:
            wl(
                "def ClearField(self, field_name: {}[{}]) -> None: ...",
                self._import("typing", "Literal"),
                cf_fields_text,
            )

        # One WhichOneof overload per oneof group.
        for wo_field, members in sorted(wo_fields.items()):
            if len(wo_fields) > 1:
                wl("@{}", self._import("typing", "overload"))
            wl(
                "def WhichOneof(self, oneof_group: {}[{}]) -> {}[{}] | None: ...",
                self._import("typing", "Literal"),
                # Accepts both str and bytes
                f'"{wo_field}", b"{wo_field}"',
                self._import("typing", "Literal"),
                # Returns `str`
                ", ".join(f'"{m}"' for m in members),
            )
  469. def write_extensions(
  470. self,
  471. extensions: Sequence[d.FieldDescriptorProto],
  472. scl_prefix: SourceCodeLocation,
  473. ) -> None:
  474. wl = self._write_line
  475. for ext in extensions:
  476. wl(f"{ext.name.upper()}_FIELD_NUMBER: {self._builtin('int')}")
  477. for i, ext in enumerate(extensions):
  478. scl = scl_prefix + [i]
  479. wl(
  480. "{}: {}[{}, {}]",
  481. ext.name,
  482. self._import(
  483. "google.protobuf.internal.extension_dict",
  484. "_ExtensionFieldDescriptor",
  485. ),
  486. self._import_message(ext.extendee),
  487. self.python_type(ext),
  488. )
  489. self._write_comments(scl)
    def write_methods(
        self,
        service: d.ServiceDescriptorProto,
        class_name: str,
        is_abstract: bool,
        scl_prefix: SourceCodeLocation,
    ) -> None:
        """Write the method stubs for a (non-grpc) protobuf service class.

        ``is_abstract`` selects between the abstract service interface
        (decorated with @abstractmethod, no argument defaults) and the
        concrete stub (defaults of ``...``).
        """
        wl = self._write_line
        wl(
            "DESCRIPTOR: {}",
            self._import("google.protobuf.descriptor", "ServiceDescriptor"),
        )
        methods = [(i, m) for i, m in enumerate(service.method) if m.name not in PYTHON_RESERVED]
        if not methods:
            # Class body must not be empty.
            wl("...")
        for i, method in methods:
            if is_abstract:
                wl("@{}", self._import("abc", "abstractmethod"))
            wl(f"def {method.name}(")
            with self._indent():
                # First parameter is named `inst` rather than `self` (pyright note).
                wl(f"inst: {class_name}, # pyright: ignore[reportSelfClsParameterName]")
                wl(
                    "rpc_controller: {},",
                    self._import("google.protobuf.service", "RpcController"),
                )
                wl("request: {},", self._import_message(method.input_type))
                wl(
                    "callback: {}[[{}], None] | None{},",
                    self._import("collections.abc", "Callable"),
                    self._import_message(method.output_type),
                    "" if is_abstract else " = ...",
                )
            scl_method = scl_prefix + [d.ServiceDescriptorProto.METHOD_FIELD_NUMBER, i]
            wl(
                ") -> {}[{}]:{}",
                self._import("concurrent.futures", "Future"),
                self._import_message(method.output_type),
                " ..." if not self._has_comments(scl_method) else "",
            )
            # Method comments become the stub method's docstring.
            if self._has_comments(scl_method):
                with self._indent():
                    if not self._write_comments(scl_method):
                        wl("...")
        wl("")
    def write_services(
        self,
        services: Iterable[d.ServiceDescriptorProto],
        scl_prefix: SourceCodeLocation,
    ) -> None:
        """Write the abstract service interface and the `_Stub` client class
        for each (non-grpc) protobuf service."""
        wl = self._write_line
        for i, service in enumerate(services):
            scl = scl_prefix + [i]
            # "_r_" prefix disambiguates names colliding with python keywords.
            class_name = service.name if service.name not in PYTHON_RESERVED else "_r_" + service.name
            # The service definition interface
            wl(
                "class {}({}, metaclass={}):",
                class_name,
                self._import("google.protobuf.service", "Service"),
                self._import("abc", "ABCMeta"),
            )
            with self._indent():
                if self._write_comments(scl):
                    wl("")
                self.write_methods(service, class_name, is_abstract=True, scl_prefix=scl)
            # The stub client
            stub_class_name = service.name + "_Stub"
            wl("class {}({}):", stub_class_name, class_name)
            with self._indent():
                if self._write_comments(scl):
                    wl("")
                wl(
                    "def __init__(self, rpc_channel: {}) -> None: ...",
                    self._import("google.protobuf.service", "RpcChannel"),
                )
                self.write_methods(service, stub_class_name, is_abstract=False, scl_prefix=scl)
  565. def _import_casttype(self, casttype: str) -> str:
  566. split = casttype.split(".")
  567. assert len(split) == 2, "mypy_protobuf.[casttype,keytype,valuetype] is expected to be of format path/to/file.TypeInFile"
  568. pkg = split[0].replace("/", ".")
  569. return self._import(pkg, split[1])
  570. def _map_key_value_types(
  571. self,
  572. map_field: d.FieldDescriptorProto,
  573. key_field: d.FieldDescriptorProto,
  574. value_field: d.FieldDescriptorProto,
  575. ) -> Tuple[str, str]:
  576. oldstyle_keytype = map_field.options.Extensions[extensions_pb2.keytype]
  577. if oldstyle_keytype:
  578. print(f"Warning: Map Field {map_field.name}: (mypy_protobuf.keytype) is deprecated. Prefer (mypy_protobuf.options).keytype", file=sys.stderr)
  579. key_casttype = map_field.options.Extensions[extensions_pb2.options].keytype or oldstyle_keytype
  580. ktype = self._import_casttype(key_casttype) if key_casttype else self.python_type(key_field)
  581. oldstyle_valuetype = map_field.options.Extensions[extensions_pb2.valuetype]
  582. if oldstyle_valuetype:
  583. print(f"Warning: Map Field {map_field.name}: (mypy_protobuf.valuetype) is deprecated. Prefer (mypy_protobuf.options).valuetype", file=sys.stderr)
  584. value_casttype = map_field.options.Extensions[extensions_pb2.options].valuetype or map_field.options.Extensions[extensions_pb2.valuetype]
  585. vtype = self._import_casttype(value_casttype) if value_casttype else self.python_type(value_field)
  586. return ktype, vtype
  587. def _callable_type(self, method: d.MethodDescriptorProto, is_async: bool = False) -> str:
  588. module = "grpc.aio" if is_async else "grpc"
  589. if method.client_streaming:
  590. if method.server_streaming:
  591. return self._import(module, "StreamStreamMultiCallable")
  592. else:
  593. return self._import(module, "StreamUnaryMultiCallable")
  594. else:
  595. if method.server_streaming:
  596. return self._import(module, "UnaryStreamMultiCallable")
  597. else:
  598. return self._import(module, "UnaryUnaryMultiCallable")
  599. def _input_type(self, method: d.MethodDescriptorProto) -> str:
  600. result = self._import_message(method.input_type)
  601. return result
  602. def _servicer_input_type(self, method: d.MethodDescriptorProto) -> str:
  603. result = self._import_message(method.input_type)
  604. if method.client_streaming:
  605. # See write_grpc_async_hacks().
  606. result = f"_MaybeAsyncIterator[{result}]"
  607. return result
  608. def _output_type(self, method: d.MethodDescriptorProto) -> str:
  609. result = self._import_message(method.output_type)
  610. return result
  611. def _servicer_output_type(self, method: d.MethodDescriptorProto) -> str:
  612. result = self._import_message(method.output_type)
  613. if method.server_streaming:
  614. # Union[Iterator[Resp], AsyncIterator[Resp]] is subtyped by Iterator[Resp] and AsyncIterator[Resp].
  615. # So both can be used in the covariant function return position.
  616. iterator = f"{self._import('collections.abc', 'Iterator')}[{result}]"
  617. aiterator = f"{self._import('collections.abc', 'AsyncIterator')}[{result}]"
  618. result = f"{self._import('typing', 'Union')}[{iterator}, {aiterator}]"
  619. else:
  620. # Union[Resp, Awaitable[Resp]] is subtyped by Resp and Awaitable[Resp].
  621. # So both can be used in the covariant function return position.
  622. # Awaitable[Resp] is equivalent to async def.
  623. awaitable = f"{self._import('collections.abc', 'Awaitable')}[{result}]"
  624. result = f"{self._import('typing', 'Union')}[{result}, {awaitable}]"
  625. return result
  626. def write_grpc_async_hacks(self) -> None:
  627. wl = self._write_line
  628. # _MaybeAsyncIterator[Req] is supertyped by Iterator[Req] and AsyncIterator[Req].
  629. # So both can be used in the contravariant function parameter position.
  630. wl('_T = {}("_T")', self._import("typing", "TypeVar"))
  631. wl("")
  632. wl(
  633. "class _MaybeAsyncIterator({}[_T], {}[_T], metaclass={}): ...",
  634. self._import("collections.abc", "AsyncIterator"),
  635. self._import("collections.abc", "Iterator"),
  636. self._import("abc", "ABCMeta"),
  637. )
  638. wl("")
  639. # _ServicerContext is supertyped by grpc.ServicerContext and grpc.aio.ServicerContext
  640. # So both can be used in the contravariant function parameter position.
  641. wl(
  642. "class _ServicerContext({}, {}): # type: ignore[misc, type-arg]",
  643. self._import("grpc", "ServicerContext"),
  644. self._import("grpc.aio", "ServicerContext"),
  645. )
  646. with self._indent():
  647. wl("...")
  648. wl("")
    def write_grpc_methods(self, service: d.ServiceDescriptorProto, scl_prefix: SourceCodeLocation) -> None:
        """Write the abstract method stubs for a grpc servicer class."""
        wl = self._write_line
        methods = [(i, m) for i, m in enumerate(service.method) if m.name not in PYTHON_RESERVED]
        if not methods:
            # Class body must not be empty.
            wl("...")
            wl("")
        for i, method in methods:
            scl = scl_prefix + [d.ServiceDescriptorProto.METHOD_FIELD_NUMBER, i]
            wl("@{}", self._import("abc", "abstractmethod"))
            wl("def {}(", method.name)
            with self._indent():
                wl("self,")
                # Client-streaming methods receive an iterator of requests.
                input_name = "request_iterator" if method.client_streaming else "request"
                input_type = self._servicer_input_type(method)
                wl(f"{input_name}: {input_type},")
                wl("context: _ServicerContext,")
            wl(
                ") -> {}:{}",
                self._servicer_output_type(method),
                " ..." if not self._has_comments(scl) else "",
            )
            # Method comments become the stub method's docstring.
            if self._has_comments(scl):
                with self._indent():
                    if not self._write_comments(scl):
                        wl("...")
            wl("")
    def write_grpc_stub_methods(self, service: d.ServiceDescriptorProto, scl_prefix: SourceCodeLocation, is_async: bool = False) -> None:
        """Write the typed MultiCallable attributes of a grpc stub class."""
        wl = self._write_line
        methods = [(i, m) for i, m in enumerate(service.method) if m.name not in PYTHON_RESERVED]
        if not methods:
            # Class body must not be empty.
            wl("...")
            wl("")
        for i, method in methods:
            scl = scl_prefix + [d.ServiceDescriptorProto.METHOD_FIELD_NUMBER, i]
            # e.g. `Method: grpc.UnaryUnaryMultiCallable[Req, Resp]`
            wl("{}: {}[", method.name, self._callable_type(method, is_async=is_async))
            with self._indent():
                wl("{},", self._input_type(method))
                wl("{},", self._output_type(method))
            wl("]")
            self._write_comments(scl)
            wl("")
    def write_grpc_services(
        self,
        services: Iterable[d.ServiceDescriptorProto],
        scl_prefix: SourceCodeLocation,
    ) -> None:
        """Write grpc stubs for each service: the `Stub` client, the
        `AsyncStub` casting target, the abstract `Servicer` interface, and the
        `add_*Servicer_to_server` registration function."""
        wl = self._write_line
        for i, service in enumerate(services):
            # Unlike elsewhere, reserved-keyword service names are skipped
            # entirely rather than mangled with "_r_".
            if service.name in PYTHON_RESERVED:
                continue
            scl = scl_prefix + [i]
            # The stub client
            wl(
                "class {}Stub:",
                service.name,
            )
            with self._indent():
                if self._write_comments(scl):
                    wl("")
                # To support casting into FooAsyncStub, allow both Channel and aio.Channel here.
                channel = f"{self._import('typing', 'Union')}[{self._import('grpc', 'Channel')}, {self._import('grpc.aio', 'Channel')}]"
                wl("def __init__(self, channel: {}) -> None: ...", channel)
                self.write_grpc_stub_methods(service, scl)
            # The (fake) async stub client
            wl(
                "class {}AsyncStub:",
                service.name,
            )
            with self._indent():
                if self._write_comments(scl):
                    wl("")
                # No __init__ since this isn't a real class (yet), and requires manual casting to work.
                self.write_grpc_stub_methods(service, scl, is_async=True)
            # The service definition interface
            wl(
                "class {}Servicer(metaclass={}):",
                service.name,
                self._import("abc", "ABCMeta"),
            )
            with self._indent():
                if self._write_comments(scl):
                    wl("")
                self.write_grpc_methods(service, scl)
            # Registration helper accepts both sync and async servers.
            server = self._import("grpc", "Server")
            aserver = self._import("grpc.aio", "Server")
            wl(
                "def add_{}Servicer_to_server(servicer: {}Servicer, server: {}) -> None: ...",
                service.name,
                service.name,
                f"{self._import('typing', 'Union')}[{server}, {aserver}]",
            )
            wl("")
def python_type(self, field: d.FieldDescriptorProto, generic_container: bool = False) -> str:
    """Return the Python type annotation string for *field*.

    generic_container
        if set, type the field with generic interfaces. Eg.
        - Iterable[int] rather than RepeatedScalarFieldContainer[int]
        - Mapping[k, v] rather than MessageMap[k, v]
        Can be useful for input types (eg constructor)
    """
    # A user-specified casttype overrides everything; the new-style option
    # takes precedence over the deprecated top-level extension.
    oldstyle_casttype = field.options.Extensions[extensions_pb2.casttype]
    if oldstyle_casttype:
        print(f"Warning: Field {field.name}: (mypy_protobuf.casttype) is deprecated. Prefer (mypy_protobuf.options).casttype", file=sys.stderr)
    casttype = field.options.Extensions[extensions_pb2.options].casttype or oldstyle_casttype
    if casttype:
        return self._import_casttype(casttype)

    # Lazy lambdas so imports are only registered for the type actually used.
    mapping: Dict[d.FieldDescriptorProto.Type.V, Callable[[], str]] = {
        d.FieldDescriptorProto.TYPE_DOUBLE: lambda: self._builtin("float"),
        d.FieldDescriptorProto.TYPE_FLOAT: lambda: self._builtin("float"),
        d.FieldDescriptorProto.TYPE_INT64: lambda: self._builtin("int"),
        d.FieldDescriptorProto.TYPE_UINT64: lambda: self._builtin("int"),
        d.FieldDescriptorProto.TYPE_FIXED64: lambda: self._builtin("int"),
        d.FieldDescriptorProto.TYPE_SFIXED64: lambda: self._builtin("int"),
        d.FieldDescriptorProto.TYPE_SINT64: lambda: self._builtin("int"),
        d.FieldDescriptorProto.TYPE_INT32: lambda: self._builtin("int"),
        d.FieldDescriptorProto.TYPE_UINT32: lambda: self._builtin("int"),
        d.FieldDescriptorProto.TYPE_FIXED32: lambda: self._builtin("int"),
        d.FieldDescriptorProto.TYPE_SFIXED32: lambda: self._builtin("int"),
        d.FieldDescriptorProto.TYPE_SINT32: lambda: self._builtin("int"),
        d.FieldDescriptorProto.TYPE_BOOL: lambda: self._builtin("bool"),
        d.FieldDescriptorProto.TYPE_STRING: lambda: self._builtin("str"),
        d.FieldDescriptorProto.TYPE_BYTES: lambda: self._builtin("bytes"),
        d.FieldDescriptorProto.TYPE_ENUM: lambda: self._import_message(field.type_name + ".ValueType"),
        d.FieldDescriptorProto.TYPE_MESSAGE: lambda: self._import_message(field.type_name),
        d.FieldDescriptorProto.TYPE_GROUP: lambda: self._import_message(field.type_name),
    }

    assert field.type in mapping, "Unrecognized type: " + repr(field.type)
    field_type = mapping[field.type]()

    # For non-repeated fields, we're done!
    if field.label != d.FieldDescriptorProto.LABEL_REPEATED:
        return field_type

    # Scalar repeated fields go in RepeatedScalarFieldContainer
    if is_scalar(field):
        container = (
            self._import("collections.abc", "Iterable")
            if generic_container
            else self._import(
                "google.protobuf.internal.containers",
                "RepeatedScalarFieldContainer",
            )
        )
        return f"{container}[{field_type}]"

    # non-scalar repeated map fields go in ScalarMap/MessageMap
    msg = self.descriptors.messages[field.type_name]
    if msg.options.map_entry:
        # map generates a special Entry wrapper message; field[0]/field[1]
        # of that synthetic message are the key and value descriptors.
        if generic_container:
            container = self._import("collections.abc", "Mapping")
        elif is_scalar(msg.field[1]):
            container = self._import("google.protobuf.internal.containers", "ScalarMap")
        else:
            container = self._import("google.protobuf.internal.containers", "MessageMap")
        ktype, vtype = self._map_key_value_types(field, msg.field[0], msg.field[1])
        return f"{container}[{ktype}, {vtype}]"

    # non-scalar repeated fields go in RepeatedCompositeFieldContainer
    container = (
        self._import("collections.abc", "Iterable")
        if generic_container
        else self._import(
            "google.protobuf.internal.containers",
            "RepeatedCompositeFieldContainer",
        )
    )
    return f"{container}[{field_type}]"
def write(self) -> str:
    """Render the complete stub file: module docstring, imports collected
    during generation, then the previously written module body.

    Returns the final file content, guaranteed to end with a newline.
    """
    # save current module content, so that imports and module docstring can be inserted
    saved_lines = self.lines
    self.lines = []

    # module docstring may exist as comment before syntax (optional) or package name
    if not self._write_comments([d.FileDescriptorProto.PACKAGE_FIELD_NUMBER]):
        self._write_comments([d.FileDescriptorProto.SYNTAX_FIELD_NUMBER])

    if self.lines:
        # _write_comments emitted a docstring; splice the header into it.
        assert self.lines[0].startswith('"""')
        self.lines[0] = f'"""{HEADER}{self.lines[0][3:]}'
        self._write_line("")
    else:
        self._write_line(f'"""{HEADER}"""\n')

    # Re-export names from public dependencies (proto `import public`).
    for reexport_idx in self.fd.public_dependency:
        reexport_file = self.fd.dependency[reexport_idx]
        reexport_fd = self.descriptors.files[reexport_file]
        reexport_imp = reexport_file[:-6].replace("-", "_").replace("/", ".") + "_pb2"
        names = [m.name for m in reexport_fd.message_type] + [m.name for m in reexport_fd.enum_type] + [v.name for m in reexport_fd.enum_type for v in m.value] + [m.name for m in reexport_fd.extension]
        if reexport_fd.options.py_generic_services:
            names.extend(m.name for m in reexport_fd.service)
        if names:
            # n,n to force a reexport (from x import y as y)
            self.from_imports[reexport_imp].update((n, n) for n in names)

    if self.typing_extensions_min:
        # Needed for the version check emitted below.
        self.imports.add("sys")
    for pkg in sorted(self.imports):
        self._write_line(f"import {pkg}")
    if self.typing_extensions_min:
        # On new-enough Pythons alias `typing` so typing_extensions isn't required.
        self._write_line("")
        self._write_line(f"if sys.version_info >= {self.typing_extensions_min}:")
        self._write_line("    import typing as typing_extensions")
        self._write_line("else:")
        self._write_line("    import typing_extensions")
    for pkg, items in sorted(self.from_imports.items()):
        self._write_line(f"from {pkg} import (")
        for name, reexport_name in sorted(items):
            if reexport_name is None:
                self._write_line(f"    {name},")
            else:
                self._write_line(f"    {name} as {reexport_name},")
        self._write_line(")")
    self._write_line("")

    # restore module content
    self.lines += saved_lines

    content = "\n".join(self.lines)
    if not content.endswith("\n"):
        content = content + "\n"
    return content
  861. def is_scalar(fd: d.FieldDescriptorProto) -> bool:
  862. return not (fd.type == d.FieldDescriptorProto.TYPE_MESSAGE or fd.type == d.FieldDescriptorProto.TYPE_GROUP)
  863. def generate_mypy_stubs(
  864. descriptors: Descriptors,
  865. response: plugin_pb2.CodeGeneratorResponse,
  866. quiet: bool,
  867. readable_stubs: bool,
  868. relax_strict_optional_primitives: bool,
  869. ) -> None:
  870. for name, fd in descriptors.to_generate.items():
  871. pkg_writer = PkgWriter(
  872. fd,
  873. descriptors,
  874. readable_stubs,
  875. relax_strict_optional_primitives,
  876. grpc=False,
  877. )
  878. pkg_writer.write_module_attributes()
  879. pkg_writer.write_enums(fd.enum_type, "", [d.FileDescriptorProto.ENUM_TYPE_FIELD_NUMBER])
  880. pkg_writer.write_messages(fd.message_type, "", [d.FileDescriptorProto.MESSAGE_TYPE_FIELD_NUMBER])
  881. pkg_writer.write_extensions(fd.extension, [d.FileDescriptorProto.EXTENSION_FIELD_NUMBER])
  882. if fd.options.py_generic_services:
  883. pkg_writer.write_services(fd.service, [d.FileDescriptorProto.SERVICE_FIELD_NUMBER])
  884. assert name == fd.name
  885. assert fd.name.endswith(".proto")
  886. output = response.file.add()
  887. output.name = fd.name[:-6].replace("-", "_").replace(".", "/") + "_pb2.pyi"
  888. output.content = pkg_writer.write()
  889. def generate_mypy_grpc_stubs(
  890. descriptors: Descriptors,
  891. response: plugin_pb2.CodeGeneratorResponse,
  892. quiet: bool,
  893. readable_stubs: bool,
  894. relax_strict_optional_primitives: bool,
  895. ) -> None:
  896. for name, fd in descriptors.to_generate.items():
  897. pkg_writer = PkgWriter(
  898. fd,
  899. descriptors,
  900. readable_stubs,
  901. relax_strict_optional_primitives,
  902. grpc=True,
  903. )
  904. pkg_writer.write_grpc_async_hacks()
  905. pkg_writer.write_grpc_services(fd.service, [d.FileDescriptorProto.SERVICE_FIELD_NUMBER])
  906. assert name == fd.name
  907. assert fd.name.endswith(".proto")
  908. output = response.file.add()
  909. output.name = fd.name[:-6].replace("-", "_").replace(".", "/") + "_pb2_grpc.pyi"
  910. output.content = pkg_writer.write()
  911. @contextmanager
  912. def code_generation() -> Iterator[Tuple[plugin_pb2.CodeGeneratorRequest, plugin_pb2.CodeGeneratorResponse],]:
  913. if len(sys.argv) > 1 and sys.argv[1] in ("-V", "--version"):
  914. print("mypy-protobuf " + __version__)
  915. sys.exit(0)
  916. # Read request message from stdin
  917. data = sys.stdin.buffer.read()
  918. # Parse request
  919. request = plugin_pb2.CodeGeneratorRequest()
  920. request.ParseFromString(data)
  921. # Create response
  922. response = plugin_pb2.CodeGeneratorResponse()
  923. # Declare support for optional proto3 fields
  924. response.supported_features |= plugin_pb2.CodeGeneratorResponse.FEATURE_PROTO3_OPTIONAL
  925. yield request, response
  926. # Serialise response message
  927. output = response.SerializeToString()
  928. # Write to stdout
  929. sys.stdout.buffer.write(output)
  930. def main() -> None:
  931. # Generate mypy
  932. with code_generation() as (request, response):
  933. generate_mypy_stubs(
  934. Descriptors(request),
  935. response,
  936. "quiet" in request.parameter,
  937. "readable_stubs" in request.parameter,
  938. "relax_strict_optional_primitives" in request.parameter,
  939. )
  940. def grpc() -> None:
  941. # Generate grpc mypy
  942. with code_generation() as (request, response):
  943. generate_mypy_grpc_stubs(
  944. Descriptors(request),
  945. response,
  946. "quiet" in request.parameter,
  947. "readable_stubs" in request.parameter,
  948. "relax_strict_optional_primitives" in request.parameter,
  949. )
# Allow running the plugin directly; protoc normally invokes it as a subprocess.
if __name__ == "__main__":
    main()