main.py 40 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
370470570670770870971071171271371471571671771871972072172272372472572672772872973073173273373473573673773873974074174274374474574674774874975075175275375475575675775875976076176276376476576676776876977077177277377477577677777877978078178278378478578678778878979079179279379479579679779879980080180280380480580680780880981081181281381481581681781881982082182282382482582682782882983083183283383483583683783883984084184284384484584684784884985085185285385485585685785885986086186286386486586686786886987087187287387487587687787887988088188288388488588688788888989089189289389489589689789889990090190290390490590690790890991091191291391491591691791891992092192292392492592692792892993093193293393493593693793893994094194294394494594694794894995095195295395495595695795895996096196296396496596696796896997097197297397497597697797897998098198298398498598698798898999099199299399499599699799899910001001100210031004100510061007100810091010101110121013101410151016101710181019102010211022
  1. #!/usr/bin/env python
  2. """Protoc Plugin to generate mypy stubs."""
  3. from __future__ import annotations
  4. import sys
  5. from collections import defaultdict
  6. from contextlib import contextmanager
  7. from typing import (
  8. Any,
  9. Callable,
  10. Dict,
  11. Iterable,
  12. Iterator,
  13. List,
  14. Optional,
  15. Set,
  16. Sequence,
  17. Tuple,
  18. )
  19. import google.protobuf.descriptor_pb2 as d
  20. from google.protobuf.compiler import plugin_pb2 as plugin_pb2
  21. from google.protobuf.internal.containers import RepeatedCompositeFieldContainer
  22. from google.protobuf.internal.well_known_types import WKTBASES
  23. from . import extensions_pb2
  24. __version__ = "3.3.0"
  25. # SourceCodeLocation is defined by `message Location` here
  26. # https://github.com/protocolbuffers/protobuf/blob/master/src/google/protobuf/descriptor.proto
  27. SourceCodeLocation = List[int]
  28. # So phabricator doesn't think mypy_protobuf.py is generated
  29. GENERATED = "@ge" + "nerated"
  30. HEADER = f"""
  31. {GENERATED} by mypy-protobuf. Do not edit manually!
  32. isort:skip_file
  33. """
  34. # See https://github.com/nipunn1313/mypy-protobuf/issues/73 for details
  35. PYTHON_RESERVED = {
  36. "False",
  37. "None",
  38. "True",
  39. "and",
  40. "as",
  41. "async",
  42. "await",
  43. "assert",
  44. "break",
  45. "class",
  46. "continue",
  47. "def",
  48. "del",
  49. "elif",
  50. "else",
  51. "except",
  52. "finally",
  53. "for",
  54. "from",
  55. "global",
  56. "if",
  57. "import",
  58. "in",
  59. "is",
  60. "lambda",
  61. "nonlocal",
  62. "not",
  63. "or",
  64. "pass",
  65. "raise",
  66. "return",
  67. "try",
  68. "while",
  69. "with",
  70. "yield",
  71. }
  72. PROTO_ENUM_RESERVED = {
  73. "Name",
  74. "Value",
  75. "keys",
  76. "values",
  77. "items",
  78. }
  79. def _mangle_global_identifier(name: str) -> str:
  80. """
  81. Module level identifiers are mangled and aliased so that they can be disambiguated
  82. from fields/enum variants with the same name within the file.
  83. Eg:
  84. Enum variant `Name` or message field `Name` might conflict with a top level
  85. message or enum named `Name`, so mangle it with a global___ prefix for
  86. internal references. Note that this doesn't affect inner enums/messages
  87. because they get fuly qualified when referenced within a file"""
  88. return f"global___{name}"
  89. class Descriptors(object):
  90. def __init__(self, request: plugin_pb2.CodeGeneratorRequest) -> None:
  91. files = {f.name: f for f in request.proto_file}
  92. to_generate = {n: files[n] for n in request.file_to_generate}
  93. self.files: Dict[str, d.FileDescriptorProto] = files
  94. self.to_generate: Dict[str, d.FileDescriptorProto] = to_generate
  95. self.messages: Dict[str, d.DescriptorProto] = {}
  96. self.message_to_fd: Dict[str, d.FileDescriptorProto] = {}
  97. def _add_enums(
  98. enums: "RepeatedCompositeFieldContainer[d.EnumDescriptorProto]",
  99. prefix: str,
  100. _fd: d.FileDescriptorProto,
  101. ) -> None:
  102. for enum in enums:
  103. self.message_to_fd[prefix + enum.name] = _fd
  104. self.message_to_fd[prefix + enum.name + ".ValueType"] = _fd
  105. def _add_messages(
  106. messages: "RepeatedCompositeFieldContainer[d.DescriptorProto]",
  107. prefix: str,
  108. _fd: d.FileDescriptorProto,
  109. ) -> None:
  110. for message in messages:
  111. self.messages[prefix + message.name] = message
  112. self.message_to_fd[prefix + message.name] = _fd
  113. sub_prefix = prefix + message.name + "."
  114. _add_messages(message.nested_type, sub_prefix, _fd)
  115. _add_enums(message.enum_type, sub_prefix, _fd)
  116. for fd in request.proto_file:
  117. start_prefix = "." + fd.package + "." if fd.package else "."
  118. _add_messages(fd.message_type, start_prefix, fd)
  119. _add_enums(fd.enum_type, start_prefix, fd)
  120. class PkgWriter(object):
  121. """Writes a single pyi file"""
  122. def __init__(
  123. self,
  124. fd: d.FileDescriptorProto,
  125. descriptors: Descriptors,
  126. readable_stubs: bool,
  127. relax_strict_optional_primitives: bool,
  128. grpc: bool,
  129. ) -> None:
  130. self.fd = fd
  131. self.descriptors = descriptors
  132. self.readable_stubs = readable_stubs
  133. self.relax_strict_optional_primitives = relax_strict_optional_primitives
  134. self.grpc = grpc
  135. self.lines: List[str] = []
  136. self.indent = ""
  137. # Set of {x}, where {x} corresponds to to `import {x}`
  138. self.imports: Set[str] = set()
  139. # dictionary of x->(y,z) for `from {x} import {y} as {z}`
  140. # if {z} is None, then it shortens to `from {x} import {y}`
  141. self.from_imports: Dict[str, Set[Tuple[str, str | None]]] = defaultdict(set)
  142. self.typing_extensions_min: Optional[Tuple[int, int]] = None
  143. # Comments
  144. self.source_code_info_by_scl = {tuple(location.path): location for location in fd.source_code_info.location}
  145. def _import(self, path: str, name: str) -> str:
  146. """Imports a stdlib path and returns a handle to it
  147. eg. self._import("typing", "Literal") -> "Literal"
  148. """
  149. if path == "typing_extensions":
  150. stabilization = {
  151. "Literal": (3, 8),
  152. "TypeAlias": (3, 10),
  153. }
  154. assert name in stabilization
  155. if not self.typing_extensions_min or self.typing_extensions_min < stabilization[name]:
  156. self.typing_extensions_min = stabilization[name]
  157. return "typing_extensions." + name
  158. imp = path.replace("/", ".")
  159. if self.readable_stubs:
  160. self.from_imports[imp].add((name, None))
  161. return name
  162. else:
  163. self.imports.add(imp)
  164. return imp + "." + name
  165. def _import_message(self, name: str) -> str:
  166. """Import a referenced message and return a handle"""
  167. message_fd = self.descriptors.message_to_fd[name]
  168. assert message_fd.name.endswith(".proto")
  169. # Strip off package name
  170. if message_fd.package:
  171. assert name.startswith("." + message_fd.package + ".")
  172. name = name[len("." + message_fd.package + ".") :]
  173. else:
  174. assert name.startswith(".")
  175. name = name[1:]
  176. # Use prepended "_r_" to disambiguate message names that alias python reserved keywords
  177. split = name.split(".")
  178. for i, part in enumerate(split):
  179. if part in PYTHON_RESERVED:
  180. split[i] = "_r_" + part
  181. name = ".".join(split)
  182. # Message defined in this file. Note: GRPC stubs in same .proto are generated into separate files
  183. if not self.grpc and message_fd.name == self.fd.name:
  184. return name if self.readable_stubs else _mangle_global_identifier(name)
  185. # Not in file. Must import
  186. # Python generated code ignores proto packages, so the only relevant factor is
  187. # whether it is in the file or not.
  188. import_name = self._import(message_fd.name[:-6].replace("-", "_") + "_pb2", split[0])
  189. remains = ".".join(split[1:])
  190. if not remains:
  191. return import_name
  192. # remains could either be a direct import of a nested enum or message
  193. # from another package.
  194. return import_name + "." + remains
  195. def _builtin(self, name: str) -> str:
  196. return self._import("builtins", name)
  197. @contextmanager
  198. def _indent(self) -> Iterator[None]:
  199. self.indent = self.indent + " "
  200. yield
  201. self.indent = self.indent[:-4]
  202. def _write_line(self, line: str, *args: Any) -> None:
  203. if args:
  204. line = line.format(*args)
  205. if line == "":
  206. self.lines.append(line)
  207. else:
  208. self.lines.append(self.indent + line)
  209. def _break_text(self, text_block: str) -> List[str]:
  210. if text_block == "":
  211. return []
  212. return [line[1:] if line.startswith(" ") else line for line in text_block.rstrip().split("\n")]
  213. def _has_comments(self, scl: SourceCodeLocation) -> bool:
  214. sci_loc = self.source_code_info_by_scl.get(tuple(scl))
  215. return sci_loc is not None and bool(sci_loc.leading_detached_comments or sci_loc.leading_comments or sci_loc.trailing_comments)
  216. def _write_comments(self, scl: SourceCodeLocation) -> bool:
  217. """Return true if any comments were written"""
  218. if not self._has_comments(scl):
  219. return False
  220. sci_loc = self.source_code_info_by_scl.get(tuple(scl))
  221. assert sci_loc is not None
  222. leading_detached_lines = []
  223. leading_lines = []
  224. trailing_lines = []
  225. for leading_detached_comment in sci_loc.leading_detached_comments:
  226. leading_detached_lines = self._break_text(leading_detached_comment)
  227. if sci_loc.leading_comments is not None:
  228. leading_lines = self._break_text(sci_loc.leading_comments)
  229. # Trailing comments also go in the header - to make sure it gets into the docstring
  230. if sci_loc.trailing_comments is not None:
  231. trailing_lines = self._break_text(sci_loc.trailing_comments)
  232. lines = leading_detached_lines
  233. if leading_detached_lines and (leading_lines or trailing_lines):
  234. lines.append("")
  235. lines.extend(leading_lines)
  236. lines.extend(trailing_lines)
  237. lines = [
  238. # Escape triple-quotes that would otherwise end the docstring early.
  239. line.replace("\\", "\\\\").replace('"""', r"\"\"\"")
  240. for line in lines
  241. ]
  242. if len(lines) == 1:
  243. line = lines[0]
  244. if line.endswith(('"', "\\")):
  245. # Docstrings are terminated with triple-quotes, so if the documentation itself ends in a quote,
  246. # insert some whitespace to separate it from the closing quotes.
  247. # This is not necessary with multiline comments
  248. # because in that case we always insert a newline before the trailing triple-quotes.
  249. line = line + " "
  250. self._write_line(f'"""{line}"""')
  251. else:
  252. for i, line in enumerate(lines):
  253. if i == 0:
  254. self._write_line(f'"""{line}')
  255. else:
  256. self._write_line(f"{line}")
  257. self._write_line('"""')
  258. return True
  259. def write_enum_values(
  260. self,
  261. values: Iterable[Tuple[int, d.EnumValueDescriptorProto]],
  262. value_type: str,
  263. scl_prefix: SourceCodeLocation,
  264. ) -> None:
  265. for i, val in values:
  266. if val.name in PYTHON_RESERVED:
  267. continue
  268. scl = scl_prefix + [i]
  269. self._write_line(
  270. f"{val.name}: {value_type} # {val.number}",
  271. )
  272. self._write_comments(scl)
  273. def write_module_attributes(self) -> None:
  274. wl = self._write_line
  275. fd_type = self._import("google.protobuf.descriptor", "FileDescriptor")
  276. wl(f"DESCRIPTOR: {fd_type}")
  277. wl("")
    def write_enums(
        self,
        enums: Iterable[d.EnumDescriptorProto],
        prefix: str,
        scl_prefix: SourceCodeLocation,
    ) -> None:
        """Write the stub classes for each enum in `enums`.

        For every enum three classes are written: a `_<Name>` helper holding
        the `ValueType` NewType (and its `V` alias), a `_<Name>EnumTypeWrapper`
        class listing the values, and the public class that combines the two.
        The values are then written once more after the public class.

        enums
          enum descriptors to write, in declaration order
        prefix
          qualified-name prefix of the enclosing scope; "" means module level,
          in which case an extra blank line and (unless readable_stubs is set)
          a mangled global alias are written
        scl_prefix
          source code location path that each enum's index is appended to
        """
        wl = self._write_line
        for i, enum in enumerate(enums):
            # Reserved names get an "_r_" prefix for the public class only;
            # the helper class names are derived from the raw enum name.
            class_name = enum.name if enum.name not in PYTHON_RESERVED else "_r_" + enum.name
            value_type_fq = prefix + class_name + ".ValueType"
            enum_helper_class = "_" + enum.name
            value_type_helper_fq = prefix + enum_helper_class + ".ValueType"
            etw_helper_class = "_" + enum.name + "EnumTypeWrapper"
            scl = scl_prefix + [i]

            wl(f"class {enum_helper_class}:")
            with self._indent():
                wl(
                    'ValueType = {}("ValueType", {})',
                    self._import("typing", "NewType"),
                    self._builtin("int"),
                )
                # Alias to the classic shorter definition "V"
                wl("V: {} = ValueType", self._import("typing_extensions", "TypeAlias"))
            wl("")
            wl(
                "class {}({}[{}], {}): # noqa: F821",
                etw_helper_class,
                self._import("google.protobuf.internal.enum_type_wrapper", "_EnumTypeWrapper"),
                value_type_helper_fq,
                self._builtin("type"),
            )
            with self._indent():
                ed = self._import("google.protobuf.descriptor", "EnumDescriptor")
                wl(f"DESCRIPTOR: {ed}")
                # Values named like entries of PROTO_ENUM_RESERVED are left out of the wrapper class
                self.write_enum_values(
                    [(i, v) for i, v in enumerate(enum.value) if v.name not in PROTO_ENUM_RESERVED],
                    value_type_helper_fq,
                    scl + [d.EnumDescriptorProto.VALUE_FIELD_NUMBER],
                )
            wl("")

            # The public class only needs a body when there is a docstring to put in it
            if self._has_comments(scl):
                wl(f"class {class_name}({enum_helper_class}, metaclass={etw_helper_class}):")
                with self._indent():
                    self._write_comments(scl)
                wl("")
            else:
                wl(f"class {class_name}({enum_helper_class}, metaclass={etw_helper_class}): ...")
            if prefix == "":
                wl("")

            # Values are written again in the enclosing scope, this time unfiltered
            # and typed with the public class's ValueType
            self.write_enum_values(
                enumerate(enum.value),
                value_type_fq,
                scl + [d.EnumDescriptorProto.VALUE_FIELD_NUMBER],
            )
            if prefix == "" and not self.readable_stubs:
                wl(f"{_mangle_global_identifier(class_name)} = {class_name}")
            wl("")
    def write_messages(
        self,
        messages: Iterable[d.DescriptorProto],
        prefix: str,
        scl_prefix: SourceCodeLocation,
    ) -> None:
        """Write a stub class for each message in `messages`.

        Each class contains, in order: its docstring, DESCRIPTOR, nested enums
        and messages, the *_FIELD_NUMBER constants, the fields, the extensions,
        the constructor and the stringly-typed helper methods.

        messages
          message descriptors to write, in declaration order
        prefix
          qualified-name prefix of the enclosing scope; "" means module level,
          in which case (unless readable_stubs is set) a mangled global alias
          is written after the class
        scl_prefix
          source code location path that each message's index is appended to
        """
        wl = self._write_line
        for i, desc in enumerate(messages):
            qualified_name = prefix + desc.name

            # Reproduce some hardcoded logic from the protobuf implementation - where
            # some specific "well_known_types" generated protos to have additional
            # base classes
            addl_base = ""
            if self.fd.package + "." + desc.name in WKTBASES:
                # chop off the .proto - and import the well known type
                # eg `from google.protobuf.duration import Duration`
                well_known_type = WKTBASES[self.fd.package + "." + desc.name]
                addl_base = ", " + self._import(
                    "google.protobuf.internal.well_known_types",
                    well_known_type.__name__,
                )

            class_name = desc.name if desc.name not in PYTHON_RESERVED else "_r_" + desc.name
            message_class = self._import("google.protobuf.message", "Message")
            wl(f"class {class_name}({message_class}{addl_base}):")
            with self._indent():
                scl = scl_prefix + [i]
                if self._write_comments(scl):
                    wl("")

                desc_type = self._import("google.protobuf.descriptor", "Descriptor")
                wl(f"DESCRIPTOR: {desc_type}")
                wl("")

                # Nested enums/messages
                self.write_enums(
                    desc.enum_type,
                    qualified_name + ".",
                    scl + [d.DescriptorProto.ENUM_TYPE_FIELD_NUMBER],
                )
                self.write_messages(
                    desc.nested_type,
                    qualified_name + ".",
                    scl + [d.DescriptorProto.NESTED_TYPE_FIELD_NUMBER],
                )

                # integer constants for field numbers
                for f in desc.field:
                    wl(f"{f.name.upper()}_FIELD_NUMBER: {self._builtin('int')}")

                for idx, field in enumerate(desc.field):
                    # Fields named after a Python reserved word get no attribute
                    if field.name in PYTHON_RESERVED:
                        continue
                    field_type = self.python_type(field)

                    if is_scalar(field) and field.label != d.FieldDescriptorProto.LABEL_REPEATED:
                        # Scalar non repeated fields are r/w
                        wl(f"{field.name}: {field_type}")
                        self._write_comments(scl + [d.DescriptorProto.FIELD_FIELD_NUMBER, idx])
                    else:
                        # r/o Getters for non-scalar fields and scalar-repeated fields
                        scl_field = scl + [d.DescriptorProto.FIELD_FIELD_NUMBER, idx]
                        wl("@property")
                        # Without comments the property body is "..." on the same line;
                        # with comments the docstring written below is the body
                        body = " ..." if not self._has_comments(scl_field) else ""
                        wl(f"def {field.name}(self) -> {field_type}:{body}")
                        if self._has_comments(scl_field):
                            with self._indent():
                                self._write_comments(scl_field)

                self.write_extensions(desc.extension, scl + [d.DescriptorProto.EXTENSION_FIELD_NUMBER])

                # Constructor
                wl("def __init__(")
                with self._indent():
                    # A field literally named "self" would clash with the receiver,
                    # so the receiver is written as "self_" instead
                    if any(f.name == "self" for f in desc.field):
                        wl("# pyright: reportSelfClsParameterName=false")
                        wl("self_,")
                    else:
                        wl("self,")
                with self._indent():
                    constructor_fields = [f for f in desc.field if f.name not in PYTHON_RESERVED]
                    if len(constructor_fields) > 0:
                        # Only keyword args allowed (the bare `*` makes every field keyword-only)
                        # See https://github.com/nipunn1313/mypy-protobuf/issues/71
                        wl("*,")
                    for field in constructor_fields:
                        field_type = self.python_type(field, generic_container=True)
                        # Non-optional proto3 scalars do not accept None unless
                        # relax_strict_optional_primitives is set
                        if self.fd.syntax == "proto3" and is_scalar(field) and field.label != d.FieldDescriptorProto.LABEL_REPEATED and not self.relax_strict_optional_primitives and not field.proto3_optional:
                            wl(f"{field.name}: {field_type} = ...,")
                        else:
                            wl(f"{field.name}: {field_type} | None = ...,")
                wl(") -> None: ...")

                self.write_stringly_typed_fields(desc)

            if prefix == "" and not self.readable_stubs:
                wl("")
                wl(f"{_mangle_global_identifier(class_name)} = {class_name}")
            wl("")
  424. def write_stringly_typed_fields(self, desc: d.DescriptorProto) -> None:
  425. """Type the stringly-typed methods as a Union[Literal, Literal ...]"""
  426. wl = self._write_line
  427. # HasField, ClearField, WhichOneof accepts both bytes/str
  428. # HasField only supports singular. ClearField supports repeated as well
  429. # In proto3, HasField only supports message fields and optional fields
  430. # HasField always supports oneof fields
  431. hf_fields = [f.name for f in desc.field if f.HasField("oneof_index") or (f.label != d.FieldDescriptorProto.LABEL_REPEATED and (self.fd.syntax != "proto3" or f.type == d.FieldDescriptorProto.TYPE_MESSAGE or f.proto3_optional))]
  432. cf_fields = [f.name for f in desc.field]
  433. wo_fields = {oneof.name: [f.name for f in desc.field if f.HasField("oneof_index") and f.oneof_index == idx] for idx, oneof in enumerate(desc.oneof_decl)}
  434. hf_fields.extend(wo_fields.keys())
  435. cf_fields.extend(wo_fields.keys())
  436. hf_fields_text = ", ".join(sorted(f'"{name}", b"{name}"' for name in hf_fields))
  437. cf_fields_text = ", ".join(sorted(f'"{name}", b"{name}"' for name in cf_fields))
  438. if not hf_fields and not cf_fields and not wo_fields:
  439. return
  440. if hf_fields:
  441. wl(
  442. "def HasField(self, field_name: {}[{}]) -> {}: ...",
  443. self._import("typing_extensions", "Literal"),
  444. hf_fields_text,
  445. self._builtin("bool"),
  446. )
  447. if cf_fields:
  448. wl(
  449. "def ClearField(self, field_name: {}[{}]) -> None: ...",
  450. self._import("typing_extensions", "Literal"),
  451. cf_fields_text,
  452. )
  453. for wo_field, members in sorted(wo_fields.items()):
  454. if len(wo_fields) > 1:
  455. wl("@{}", self._import("typing", "overload"))
  456. wl(
  457. "def WhichOneof(self, oneof_group: {}[{}]) -> {}[{}] | None: ...",
  458. self._import("typing_extensions", "Literal"),
  459. # Accepts both str and bytes
  460. f'"{wo_field}", b"{wo_field}"',
  461. self._import("typing_extensions", "Literal"),
  462. # Returns `str`
  463. ", ".join(f'"{m}"' for m in members),
  464. )
  465. def write_extensions(
  466. self,
  467. extensions: Sequence[d.FieldDescriptorProto],
  468. scl_prefix: SourceCodeLocation,
  469. ) -> None:
  470. wl = self._write_line
  471. for ext in extensions:
  472. wl(f"{ext.name.upper()}_FIELD_NUMBER: {self._builtin('int')}")
  473. for i, ext in enumerate(extensions):
  474. scl = scl_prefix + [i]
  475. wl(
  476. "{}: {}[{}, {}]",
  477. ext.name,
  478. self._import(
  479. "google.protobuf.internal.extension_dict",
  480. "_ExtensionFieldDescriptor",
  481. ),
  482. self._import_message(ext.extendee),
  483. self.python_type(ext),
  484. )
  485. self._write_comments(scl)
  486. def write_methods(
  487. self,
  488. service: d.ServiceDescriptorProto,
  489. class_name: str,
  490. is_abstract: bool,
  491. scl_prefix: SourceCodeLocation,
  492. ) -> None:
  493. wl = self._write_line
  494. wl(
  495. "DESCRIPTOR: {}",
  496. self._import("google.protobuf.descriptor", "ServiceDescriptor"),
  497. )
  498. methods = [(i, m) for i, m in enumerate(service.method) if m.name not in PYTHON_RESERVED]
  499. if not methods:
  500. wl("...")
  501. for i, method in methods:
  502. if is_abstract:
  503. wl("@{}", self._import("abc", "abstractmethod"))
  504. wl(f"def {method.name}(")
  505. with self._indent():
  506. wl(f"inst: {class_name},")
  507. wl(
  508. "rpc_controller: {},",
  509. self._import("google.protobuf.service", "RpcController"),
  510. )
  511. wl("request: {},", self._import_message(method.input_type))
  512. wl(
  513. "callback: {}[[{}], None] | None{},",
  514. self._import("collections.abc", "Callable"),
  515. self._import_message(method.output_type),
  516. "" if is_abstract else " = ...",
  517. )
  518. scl_method = scl_prefix + [d.ServiceDescriptorProto.METHOD_FIELD_NUMBER, i]
  519. wl(
  520. ") -> {}[{}]:{}",
  521. self._import("concurrent.futures", "Future"),
  522. self._import_message(method.output_type),
  523. " ..." if not self._has_comments(scl_method) else "",
  524. )
  525. if self._has_comments(scl_method):
  526. with self._indent():
  527. if not self._write_comments(scl_method):
  528. wl("...")
  529. def write_services(
  530. self,
  531. services: Iterable[d.ServiceDescriptorProto],
  532. scl_prefix: SourceCodeLocation,
  533. ) -> None:
  534. wl = self._write_line
  535. for i, service in enumerate(services):
  536. scl = scl_prefix + [i]
  537. class_name = service.name if service.name not in PYTHON_RESERVED else "_r_" + service.name
  538. # The service definition interface
  539. wl(
  540. "class {}({}, metaclass={}):",
  541. class_name,
  542. self._import("google.protobuf.service", "Service"),
  543. self._import("abc", "ABCMeta"),
  544. )
  545. with self._indent():
  546. if self._write_comments(scl):
  547. wl("")
  548. self.write_methods(service, class_name, is_abstract=True, scl_prefix=scl)
  549. wl("")
  550. # The stub client
  551. stub_class_name = service.name + "_Stub"
  552. wl("class {}({}):", stub_class_name, class_name)
  553. with self._indent():
  554. if self._write_comments(scl):
  555. wl("")
  556. wl(
  557. "def __init__(self, rpc_channel: {}) -> None: ...",
  558. self._import("google.protobuf.service", "RpcChannel"),
  559. )
  560. self.write_methods(service, stub_class_name, is_abstract=False, scl_prefix=scl)
  561. wl("")
  562. def _import_casttype(self, casttype: str) -> str:
  563. split = casttype.split(".")
  564. assert len(split) == 2, "mypy_protobuf.[casttype,keytype,valuetype] is expected to be of format path/to/file.TypeInFile"
  565. pkg = split[0].replace("/", ".")
  566. return self._import(pkg, split[1])
  567. def _map_key_value_types(
  568. self,
  569. map_field: d.FieldDescriptorProto,
  570. key_field: d.FieldDescriptorProto,
  571. value_field: d.FieldDescriptorProto,
  572. ) -> Tuple[str, str]:
  573. oldstyle_keytype = map_field.options.Extensions[extensions_pb2.keytype]
  574. if oldstyle_keytype:
  575. print(f"Warning: Map Field {map_field.name}: (mypy_protobuf.keytype) is deprecated. Prefer (mypy_protobuf.options).keytype", file=sys.stderr)
  576. key_casttype = map_field.options.Extensions[extensions_pb2.options].keytype or oldstyle_keytype
  577. ktype = self._import_casttype(key_casttype) if key_casttype else self.python_type(key_field)
  578. oldstyle_valuetype = map_field.options.Extensions[extensions_pb2.valuetype]
  579. if oldstyle_valuetype:
  580. print(f"Warning: Map Field {map_field.name}: (mypy_protobuf.valuetype) is deprecated. Prefer (mypy_protobuf.options).valuetype", file=sys.stderr)
  581. value_casttype = map_field.options.Extensions[extensions_pb2.options].valuetype or map_field.options.Extensions[extensions_pb2.valuetype]
  582. vtype = self._import_casttype(value_casttype) if value_casttype else self.python_type(value_field)
  583. return ktype, vtype
  584. def _callable_type(self, method: d.MethodDescriptorProto) -> str:
  585. if method.client_streaming:
  586. if method.server_streaming:
  587. return self._import("grpc", "StreamStreamMultiCallable")
  588. else:
  589. return self._import("grpc", "StreamUnaryMultiCallable")
  590. else:
  591. if method.server_streaming:
  592. return self._import("grpc", "UnaryStreamMultiCallable")
  593. else:
  594. return self._import("grpc", "UnaryUnaryMultiCallable")
  595. def _input_type(self, method: d.MethodDescriptorProto, use_stream_iterator: bool = True) -> str:
  596. result = self._import_message(method.input_type)
  597. if use_stream_iterator and method.client_streaming:
  598. result = f"{self._import('collections.abc', 'Iterator')}[{result}]"
  599. return result
  600. def _output_type(self, method: d.MethodDescriptorProto, use_stream_iterator: bool = True) -> str:
  601. result = self._import_message(method.output_type)
  602. if use_stream_iterator and method.server_streaming:
  603. result = f"{self._import('collections.abc', 'Iterator')}[{result}]"
  604. return result
  605. def write_grpc_methods(self, service: d.ServiceDescriptorProto, scl_prefix: SourceCodeLocation) -> None:
  606. wl = self._write_line
  607. methods = [(i, m) for i, m in enumerate(service.method) if m.name not in PYTHON_RESERVED]
  608. if not methods:
  609. wl("...")
  610. wl("")
  611. for i, method in methods:
  612. scl = scl_prefix + [d.ServiceDescriptorProto.METHOD_FIELD_NUMBER, i]
  613. wl("@{}", self._import("abc", "abstractmethod"))
  614. wl("def {}(", method.name)
  615. with self._indent():
  616. wl("self,")
  617. input_name = "request_iterator" if method.client_streaming else "request"
  618. input_type = self._input_type(method)
  619. wl(f"{input_name}: {input_type},")
  620. wl("context: {},", self._import("grpc", "ServicerContext"))
  621. wl(
  622. ") -> {}:{}",
  623. self._output_type(method),
  624. " ..." if not self._has_comments(scl) else "",
  625. ),
  626. if self._has_comments(scl):
  627. with self._indent():
  628. if not self._write_comments(scl):
  629. wl("...")
  630. def write_grpc_stub_methods(self, service: d.ServiceDescriptorProto, scl_prefix: SourceCodeLocation) -> None:
  631. wl = self._write_line
  632. methods = [(i, m) for i, m in enumerate(service.method) if m.name not in PYTHON_RESERVED]
  633. if not methods:
  634. wl("...")
  635. wl("")
  636. for i, method in methods:
  637. scl = scl_prefix + [d.ServiceDescriptorProto.METHOD_FIELD_NUMBER, i]
  638. wl("{}: {}[", method.name, self._callable_type(method))
  639. with self._indent():
  640. wl("{},", self._input_type(method, False))
  641. wl("{},", self._output_type(method, False))
  642. wl("]")
  643. self._write_comments(scl)
  644. def write_grpc_services(
  645. self,
  646. services: Iterable[d.ServiceDescriptorProto],
  647. scl_prefix: SourceCodeLocation,
  648. ) -> None:
  649. wl = self._write_line
  650. for i, service in enumerate(services):
  651. if service.name in PYTHON_RESERVED:
  652. continue
  653. scl = scl_prefix + [i]
  654. # The stub client
  655. wl(f"class {service.name}Stub:")
  656. with self._indent():
  657. if self._write_comments(scl):
  658. wl("")
  659. wl(
  660. "def __init__(self, channel: {}) -> None: ...",
  661. self._import("grpc", "Channel"),
  662. )
  663. self.write_grpc_stub_methods(service, scl)
  664. wl("")
  665. # The service definition interface
  666. wl(
  667. "class {}Servicer(metaclass={}):",
  668. service.name,
  669. self._import("abc", "ABCMeta"),
  670. )
  671. with self._indent():
  672. if self._write_comments(scl):
  673. wl("")
  674. self.write_grpc_methods(service, scl)
  675. wl("")
  676. wl(
  677. "def add_{}Servicer_to_server(servicer: {}Servicer, server: {}) -> None: ...",
  678. service.name,
  679. service.name,
  680. self._import("grpc", "Server"),
  681. )
  682. wl("")
  683. def python_type(self, field: d.FieldDescriptorProto, generic_container: bool = False) -> str:
  684. """
  685. generic_container
  686. if set, type the field with generic interfaces. Eg.
  687. - Iterable[int] rather than RepeatedScalarFieldContainer[int]
  688. - Mapping[k, v] rather than MessageMap[k, v]
  689. Can be useful for input types (eg constructor)
  690. """
  691. oldstyle_casttype = field.options.Extensions[extensions_pb2.casttype]
  692. if oldstyle_casttype:
  693. print(f"Warning: Field {field.name}: (mypy_protobuf.casttype) is deprecated. Prefer (mypy_protobuf.options).casttype", file=sys.stderr)
  694. casttype = field.options.Extensions[extensions_pb2.options].casttype or oldstyle_casttype
  695. if casttype:
  696. return self._import_casttype(casttype)
  697. mapping: Dict[d.FieldDescriptorProto.Type.V, Callable[[], str]] = {
  698. d.FieldDescriptorProto.TYPE_DOUBLE: lambda: self._builtin("float"),
  699. d.FieldDescriptorProto.TYPE_FLOAT: lambda: self._builtin("float"),
  700. d.FieldDescriptorProto.TYPE_INT64: lambda: self._builtin("int"),
  701. d.FieldDescriptorProto.TYPE_UINT64: lambda: self._builtin("int"),
  702. d.FieldDescriptorProto.TYPE_FIXED64: lambda: self._builtin("int"),
  703. d.FieldDescriptorProto.TYPE_SFIXED64: lambda: self._builtin("int"),
  704. d.FieldDescriptorProto.TYPE_SINT64: lambda: self._builtin("int"),
  705. d.FieldDescriptorProto.TYPE_INT32: lambda: self._builtin("int"),
  706. d.FieldDescriptorProto.TYPE_UINT32: lambda: self._builtin("int"),
  707. d.FieldDescriptorProto.TYPE_FIXED32: lambda: self._builtin("int"),
  708. d.FieldDescriptorProto.TYPE_SFIXED32: lambda: self._builtin("int"),
  709. d.FieldDescriptorProto.TYPE_SINT32: lambda: self._builtin("int"),
  710. d.FieldDescriptorProto.TYPE_BOOL: lambda: self._builtin("bool"),
  711. d.FieldDescriptorProto.TYPE_STRING: lambda: self._builtin("str"),
  712. d.FieldDescriptorProto.TYPE_BYTES: lambda: self._builtin("bytes"),
  713. d.FieldDescriptorProto.TYPE_ENUM: lambda: self._import_message(field.type_name + ".ValueType"),
  714. d.FieldDescriptorProto.TYPE_MESSAGE: lambda: self._import_message(field.type_name),
  715. d.FieldDescriptorProto.TYPE_GROUP: lambda: self._import_message(field.type_name),
  716. }
  717. assert field.type in mapping, "Unrecognized type: " + repr(field.type)
  718. field_type = mapping[field.type]()
  719. # For non-repeated fields, we're done!
  720. if field.label != d.FieldDescriptorProto.LABEL_REPEATED:
  721. return field_type
  722. # Scalar repeated fields go in RepeatedScalarFieldContainer
  723. if is_scalar(field):
  724. container = (
  725. self._import("collections.abc", "Iterable")
  726. if generic_container
  727. else self._import(
  728. "google.protobuf.internal.containers",
  729. "RepeatedScalarFieldContainer",
  730. )
  731. )
  732. return f"{container}[{field_type}]"
  733. # non-scalar repeated map fields go in ScalarMap/MessageMap
  734. msg = self.descriptors.messages[field.type_name]
  735. if msg.options.map_entry:
  736. # map generates a special Entry wrapper message
  737. if generic_container:
  738. container = self._import("collections.abc", "Mapping")
  739. elif is_scalar(msg.field[1]):
  740. container = self._import("google.protobuf.internal.containers", "ScalarMap")
  741. else:
  742. container = self._import("google.protobuf.internal.containers", "MessageMap")
  743. ktype, vtype = self._map_key_value_types(field, msg.field[0], msg.field[1])
  744. return f"{container}[{ktype}, {vtype}]"
  745. # non-scalar repetated fields go in RepeatedCompositeFieldContainer
  746. container = (
  747. self._import("collections.abc", "Iterable")
  748. if generic_container
  749. else self._import(
  750. "google.protobuf.internal.containers",
  751. "RepeatedCompositeFieldContainer",
  752. )
  753. )
  754. return f"{container}[{field_type}]"
  755. def write(self) -> str:
  756. # save current module content, so that imports and module docstring can be inserted
  757. saved_lines = self.lines
  758. self.lines = []
  759. # module docstring may exist as comment before syntax (optional) or package name
  760. if not self._write_comments([d.FileDescriptorProto.PACKAGE_FIELD_NUMBER]):
  761. self._write_comments([d.FileDescriptorProto.SYNTAX_FIELD_NUMBER])
  762. if self.lines:
  763. assert self.lines[0].startswith('"""')
  764. self.lines[0] = f'"""{HEADER}{self.lines[0][3:]}'
  765. else:
  766. self._write_line(f'"""{HEADER}"""')
  767. for reexport_idx in self.fd.public_dependency:
  768. reexport_file = self.fd.dependency[reexport_idx]
  769. reexport_fd = self.descriptors.files[reexport_file]
  770. reexport_imp = reexport_file[:-6].replace("-", "_").replace("/", ".") + "_pb2"
  771. names = [m.name for m in reexport_fd.message_type] + [m.name for m in reexport_fd.enum_type] + [v.name for m in reexport_fd.enum_type for v in m.value] + [m.name for m in reexport_fd.extension]
  772. if reexport_fd.options.py_generic_services:
  773. names.extend(m.name for m in reexport_fd.service)
  774. if names:
  775. # n,n to force a reexport (from x import y as y)
  776. self.from_imports[reexport_imp].update((n, n) for n in names)
  777. if self.typing_extensions_min:
  778. self.imports.add("sys")
  779. for pkg in sorted(self.imports):
  780. self._write_line(f"import {pkg}")
  781. if self.typing_extensions_min:
  782. self._write_line("")
  783. self._write_line(f"if sys.version_info >= {self.typing_extensions_min}:")
  784. self._write_line(" import typing as typing_extensions")
  785. self._write_line("else:")
  786. self._write_line(" import typing_extensions")
  787. for pkg, items in sorted(self.from_imports.items()):
  788. self._write_line(f"from {pkg} import (")
  789. for (name, reexport_name) in sorted(items):
  790. if reexport_name is None:
  791. self._write_line(f" {name},")
  792. else:
  793. self._write_line(f" {name} as {reexport_name},")
  794. self._write_line(")")
  795. self._write_line("")
  796. # restore module content
  797. self.lines += saved_lines
  798. content = "\n".join(self.lines)
  799. if not content.endswith("\n"):
  800. content = content + "\n"
  801. return content
  802. def is_scalar(fd: d.FieldDescriptorProto) -> bool:
  803. return not (fd.type == d.FieldDescriptorProto.TYPE_MESSAGE or fd.type == d.FieldDescriptorProto.TYPE_GROUP)
  804. def generate_mypy_stubs(
  805. descriptors: Descriptors,
  806. response: plugin_pb2.CodeGeneratorResponse,
  807. quiet: bool,
  808. readable_stubs: bool,
  809. relax_strict_optional_primitives: bool,
  810. ) -> None:
  811. for name, fd in descriptors.to_generate.items():
  812. pkg_writer = PkgWriter(
  813. fd,
  814. descriptors,
  815. readable_stubs,
  816. relax_strict_optional_primitives,
  817. grpc=False,
  818. )
  819. pkg_writer.write_module_attributes()
  820. pkg_writer.write_enums(fd.enum_type, "", [d.FileDescriptorProto.ENUM_TYPE_FIELD_NUMBER])
  821. pkg_writer.write_messages(fd.message_type, "", [d.FileDescriptorProto.MESSAGE_TYPE_FIELD_NUMBER])
  822. pkg_writer.write_extensions(fd.extension, [d.FileDescriptorProto.EXTENSION_FIELD_NUMBER])
  823. if fd.options.py_generic_services:
  824. pkg_writer.write_services(fd.service, [d.FileDescriptorProto.SERVICE_FIELD_NUMBER])
  825. assert name == fd.name
  826. assert fd.name.endswith(".proto")
  827. output = response.file.add()
  828. output.name = fd.name[:-6].replace("-", "_").replace(".", "/") + "_pb2.pyi"
  829. output.content = pkg_writer.write()
  830. def generate_mypy_grpc_stubs(
  831. descriptors: Descriptors,
  832. response: plugin_pb2.CodeGeneratorResponse,
  833. quiet: bool,
  834. readable_stubs: bool,
  835. relax_strict_optional_primitives: bool,
  836. ) -> None:
  837. for name, fd in descriptors.to_generate.items():
  838. pkg_writer = PkgWriter(
  839. fd,
  840. descriptors,
  841. readable_stubs,
  842. relax_strict_optional_primitives,
  843. grpc=True,
  844. )
  845. pkg_writer.write_grpc_services(fd.service, [d.FileDescriptorProto.SERVICE_FIELD_NUMBER])
  846. assert name == fd.name
  847. assert fd.name.endswith(".proto")
  848. output = response.file.add()
  849. output.name = fd.name[:-6].replace("-", "_").replace(".", "/") + "_pb2_grpc.pyi"
  850. output.content = pkg_writer.write()
  851. @contextmanager
  852. def code_generation() -> Iterator[
  853. Tuple[plugin_pb2.CodeGeneratorRequest, plugin_pb2.CodeGeneratorResponse],
  854. ]:
  855. if len(sys.argv) > 1 and sys.argv[1] in ("-V", "--version"):
  856. print("mypy-protobuf " + __version__)
  857. sys.exit(0)
  858. # Read request message from stdin
  859. data = sys.stdin.buffer.read()
  860. # Parse request
  861. request = plugin_pb2.CodeGeneratorRequest()
  862. request.ParseFromString(data)
  863. # Create response
  864. response = plugin_pb2.CodeGeneratorResponse()
  865. # Declare support for optional proto3 fields
  866. response.supported_features |= plugin_pb2.CodeGeneratorResponse.FEATURE_PROTO3_OPTIONAL
  867. yield request, response
  868. # Serialise response message
  869. output = response.SerializeToString()
  870. # Write to stdout
  871. sys.stdout.buffer.write(output)
  872. def main() -> None:
  873. # Generate mypy
  874. with code_generation() as (request, response):
  875. generate_mypy_stubs(
  876. Descriptors(request),
  877. response,
  878. "quiet" in request.parameter,
  879. "readable_stubs" in request.parameter,
  880. "relax_strict_optional_primitives" in request.parameter,
  881. )
  882. def grpc() -> None:
  883. # Generate grpc mypy
  884. with code_generation() as (request, response):
  885. generate_mypy_grpc_stubs(
  886. Descriptors(request),
  887. response,
  888. "quiet" in request.parameter,
  889. "readable_stubs" in request.parameter,
  890. "relax_strict_optional_primitives" in request.parameter,
  891. )
# When executed as a script, run the (non-gRPC) mypy stub generator.
if __name__ == "__main__":
    main()