#!/usr/bin/env python
"""Protoc Plugin to generate mypy stubs."""
from __future__ import annotations

import sys
from collections import defaultdict
from contextlib import contextmanager
from typing import (
    Any,
    Callable,
    Dict,
    Iterable,
    Iterator,
    List,
    Optional,
    Set,
    Sequence,
    Tuple,
)

import google.protobuf.descriptor_pb2 as d
from google.protobuf.compiler import plugin_pb2 as plugin_pb2
from google.protobuf.internal.containers import RepeatedCompositeFieldContainer
from google.protobuf.internal.well_known_types import WKTBASES

from . import extensions_pb2
  24. __version__ = "3.3.0"
  25. # SourceCodeLocation is defined by `message Location` here
  26. #
  27. SourceCodeLocation = List[int]
  28. # So phabricator doesn't think is generated
  29. GENERATED = "@ge" + "nerated"
  30. HEADER = f"""
  31. {GENERATED} by mypy-protobuf. Do not edit manually!
  32. isort:skip_file
  33. """
  34. # See for details
  36. "False",
  37. "None",
  38. "True",
  39. "and",
  40. "as",
  41. "async",
  42. "await",
  43. "assert",
  44. "break",
  45. "class",
  46. "continue",
  47. "def",
  48. "del",
  49. "elif",
  50. "else",
  51. "except",
  52. "finally",
  53. "for",
  54. "from",
  55. "global",
  56. "if",
  57. "import",
  58. "in",
  59. "is",
  60. "lambda",
  61. "nonlocal",
  62. "not",
  63. "or",
  64. "pass",
  65. "raise",
  66. "return",
  67. "try",
  68. "while",
  69. "with",
  70. "yield",
  71. }
  73. "Name",
  74. "Value",
  75. "keys",
  76. "values",
  77. "items",
  78. }
  79. def _mangle_global_identifier(name: str) -> str:
  80. """
  81. Module level identifiers are mangled and aliased so that they can be disambiguated
  82. from fields/enum variants with the same name within the file.
  83. Eg:
  84. Enum variant `Name` or message field `Name` might conflict with a top level
  85. message or enum named `Name`, so mangle it with a global___ prefix for
  86. internal references. Note that this doesn't affect inner enums/messages
  87. because they get fuly qualified when referenced within a file"""
  88. return f"global___{name}"
  89. class Descriptors(object):
  90. def __init__(self, request: plugin_pb2.CodeGeneratorRequest) -> None:
  91. files = { f for f in request.proto_file}
  92. to_generate = {n: files[n] for n in request.file_to_generate}
  93. self.files: Dict[str, d.FileDescriptorProto] = files
  94. self.to_generate: Dict[str, d.FileDescriptorProto] = to_generate
  95. self.messages: Dict[str, d.DescriptorProto] = {}
  96. self.message_to_fd: Dict[str, d.FileDescriptorProto] = {}
  97. def _add_enums(
  98. enums: "RepeatedCompositeFieldContainer[d.EnumDescriptorProto]",
  99. prefix: str,
  100. _fd: d.FileDescriptorProto,
  101. ) -> None:
  102. for enum in enums:
  103. self.message_to_fd[prefix +] = _fd
  104. self.message_to_fd[prefix + + ".ValueType"] = _fd
  105. def _add_messages(
  106. messages: "RepeatedCompositeFieldContainer[d.DescriptorProto]",
  107. prefix: str,
  108. _fd: d.FileDescriptorProto,
  109. ) -> None:
  110. for message in messages:
  111. self.messages[prefix +] = message
  112. self.message_to_fd[prefix +] = _fd
  113. sub_prefix = prefix + + "."
  114. _add_messages(message.nested_type, sub_prefix, _fd)
  115. _add_enums(message.enum_type, sub_prefix, _fd)
  116. for fd in request.proto_file:
  117. start_prefix = "." + fd.package + "." if fd.package else "."
  118. _add_messages(fd.message_type, start_prefix, fd)
  119. _add_enums(fd.enum_type, start_prefix, fd)
  120. class PkgWriter(object):
  121. """Writes a single pyi file"""
  122. def __init__(
  123. self,
  124. fd: d.FileDescriptorProto,
  125. descriptors: Descriptors,
  126. readable_stubs: bool,
  127. relax_strict_optional_primitives: bool,
  128. grpc: bool,
  129. ) -> None:
  130. self.fd = fd
  131. self.descriptors = descriptors
  132. self.readable_stubs = readable_stubs
  133. self.relax_strict_optional_primitives = relax_strict_optional_primitives
  134. self.grpc = grpc
  135. self.lines: List[str] = []
  136. self.indent = ""
  137. # Set of {x}, where {x} corresponds to to `import {x}`
  138. self.imports: Set[str] = set()
  139. # dictionary of x->(y,z) for `from {x} import {y} as {z}`
  140. # if {z} is None, then it shortens to `from {x} import {y}`
  141. self.from_imports: Dict[str, Set[Tuple[str, str | None]]] = defaultdict(set)
  142. self.typing_extensions_min: Optional[Tuple[int, int]] = None
  143. # Comments
  144. self.source_code_info_by_scl = {tuple(location.path): location for location in fd.source_code_info.location}
  145. def _import(self, path: str, name: str) -> str:
  146. """Imports a stdlib path and returns a handle to it
  147. eg. self._import("typing", "Literal") -> "Literal"
  148. """
  149. if path == "typing_extensions":
  150. stabilization = {
  151. "Literal": (3, 8),
  152. "TypeAlias": (3, 10),
  153. }
  154. assert name in stabilization
  155. if not self.typing_extensions_min or self.typing_extensions_min < stabilization[name]:
  156. self.typing_extensions_min = stabilization[name]
  157. return "typing_extensions." + name
  158. imp = path.replace("/", ".")
  159. if self.readable_stubs:
  160. self.from_imports[imp].add((name, None))
  161. return name
  162. else:
  163. self.imports.add(imp)
  164. return imp + "." + name
  165. def _import_message(self, name: str) -> str:
  166. """Import a referenced message and return a handle"""
  167. message_fd = self.descriptors.message_to_fd[name]
  168. assert".proto")
  169. # Strip off package name
  170. if message_fd.package:
  171. assert name.startswith("." + message_fd.package + ".")
  172. name = name[len("." + message_fd.package + ".") :]
  173. else:
  174. assert name.startswith(".")
  175. name = name[1:]
  176. # Use prepended "_r_" to disambiguate message names that alias python reserved keywords
  177. split = name.split(".")
  178. for i, part in enumerate(split):
  179. if part in PYTHON_RESERVED:
  180. split[i] = "_r_" + part
  181. name = ".".join(split)
  182. # Message defined in this file. Note: GRPC stubs in same .proto are generated into separate files
  183. if not self.grpc and ==
  184. return name if self.readable_stubs else _mangle_global_identifier(name)
  185. # Not in file. Must import
  186. # Python generated code ignores proto packages, so the only relevant factor is
  187. # whether it is in the file or not.
  188. import_name = self._import([:-6].replace("-", "_") + "_pb2", split[0])
  189. remains = ".".join(split[1:])
  190. if not remains:
  191. return import_name
  192. # remains could either be a direct import of a nested enum or message
  193. # from another package.
  194. return import_name + "." + remains
  195. def _builtin(self, name: str) -> str:
  196. return self._import("builtins", name)
  197. @contextmanager
  198. def _indent(self) -> Iterator[None]:
  199. self.indent = self.indent + " "
  200. yield
  201. self.indent = self.indent[:-4]
  202. def _write_line(self, line: str, *args: Any) -> None:
  203. if args:
  204. line = line.format(*args)
  205. if line == "":
  206. self.lines.append(line)
  207. else:
  208. self.lines.append(self.indent + line)
  209. def _break_text(self, text_block: str) -> List[str]:
  210. if text_block == "":
  211. return []
  212. return [line[1:] if line.startswith(" ") else line for line in text_block.rstrip().split("\n")]
  213. def _has_comments(self, scl: SourceCodeLocation) -> bool:
  214. sci_loc = self.source_code_info_by_scl.get(tuple(scl))
  215. return sci_loc is not None and bool(sci_loc.leading_detached_comments or sci_loc.leading_comments or sci_loc.trailing_comments)
  216. def _write_comments(self, scl: SourceCodeLocation) -> bool:
  217. """Return true if any comments were written"""
  218. if not self._has_comments(scl):
  219. return False
  220. sci_loc = self.source_code_info_by_scl.get(tuple(scl))
  221. assert sci_loc is not None
  222. leading_detached_lines = []
  223. leading_lines = []
  224. trailing_lines = []
  225. for leading_detached_comment in sci_loc.leading_detached_comments:
  226. leading_detached_lines = self._break_text(leading_detached_comment)
  227. if sci_loc.leading_comments is not None:
  228. leading_lines = self._break_text(sci_loc.leading_comments)
  229. # Trailing comments also go in the header - to make sure it gets into the docstring
  230. if sci_loc.trailing_comments is not None:
  231. trailing_lines = self._break_text(sci_loc.trailing_comments)
  232. lines = leading_detached_lines
  233. if leading_detached_lines and (leading_lines or trailing_lines):
  234. lines.append("")
  235. lines.extend(leading_lines)
  236. lines.extend(trailing_lines)
  237. lines = [
  238. # Escape triple-quotes that would otherwise end the docstring early.
  239. line.replace("\\", "\\\\").replace('"""', r"\"\"\"")
  240. for line in lines
  241. ]
  242. if len(lines) == 1:
  243. line = lines[0]
  244. if line.endswith(('"', "\\")):
  245. # Docstrings are terminated with triple-quotes, so if the documentation itself ends in a quote,
  246. # insert some whitespace to separate it from the closing quotes.
  247. # This is not necessary with multiline comments
  248. # because in that case we always insert a newline before the trailing triple-quotes.
  249. line = line + " "
  250. self._write_line(f'"""{line}"""')
  251. else:
  252. for i, line in enumerate(lines):
  253. if i == 0:
  254. self._write_line(f'"""{line}')
  255. else:
  256. self._write_line(f"{line}")
  257. self._write_line('"""')
  258. return True
  259. def write_enum_values(
  260. self,
  261. values: Iterable[Tuple[int, d.EnumValueDescriptorProto]],
  262. value_type: str,
  263. scl_prefix: SourceCodeLocation,
  264. ) -> None:
  265. for i, val in values:
  266. if in PYTHON_RESERVED:
  267. continue
  268. scl = scl_prefix + [i]
  269. self._write_line(
  270. f"{}: {value_type} # {val.number}",
  271. )
  272. self._write_comments(scl)
  273. def write_module_attributes(self) -> None:
  274. wl = self._write_line
  275. fd_type = self._import("google.protobuf.descriptor", "FileDescriptor")
  276. wl(f"DESCRIPTOR: {fd_type}")
  277. wl("")
  278. def write_enums(
  279. self,
  280. enums: Iterable[d.EnumDescriptorProto],
  281. prefix: str,
  282. scl_prefix: SourceCodeLocation,
  283. ) -> None:
  284. wl = self._write_line
  285. for i, enum in enumerate(enums):
  286. class_name = if not in PYTHON_RESERVED else "_r_" +
  287. value_type_fq = prefix + class_name + ".ValueType"
  288. enum_helper_class = "_" +
  289. value_type_helper_fq = prefix + enum_helper_class + ".ValueType"
  290. etw_helper_class = "_" + + "EnumTypeWrapper"
  291. scl = scl_prefix + [i]
  292. wl(f"class {enum_helper_class}:")
  293. with self._indent():
  294. wl(
  295. 'ValueType = {}("ValueType", {})',
  296. self._import("typing", "NewType"),
  297. self._builtin("int"),
  298. )
  299. # Alias to the classic shorter definition "V"
  300. wl("V: {} = ValueType", self._import("typing_extensions", "TypeAlias"))
  301. wl("")
  302. wl(
  303. "class {}({}[{}], {}): # noqa: F821",
  304. etw_helper_class,
  305. self._import("google.protobuf.internal.enum_type_wrapper", "_EnumTypeWrapper"),
  306. value_type_helper_fq,
  307. self._builtin("type"),
  308. )
  309. with self._indent():
  310. ed = self._import("google.protobuf.descriptor", "EnumDescriptor")
  311. wl(f"DESCRIPTOR: {ed}")
  312. self.write_enum_values(
  313. [(i, v) for i, v in enumerate(enum.value) if not in PROTO_ENUM_RESERVED],
  314. value_type_helper_fq,
  315. scl + [d.EnumDescriptorProto.VALUE_FIELD_NUMBER],
  316. )
  317. wl("")
  318. if self._has_comments(scl):
  319. wl(f"class {class_name}({enum_helper_class}, metaclass={etw_helper_class}):")
  320. with self._indent():
  321. self._write_comments(scl)
  322. wl("")
  323. else:
  324. wl(f"class {class_name}({enum_helper_class}, metaclass={etw_helper_class}): ...")
  325. if prefix == "":
  326. wl("")
  327. self.write_enum_values(
  328. enumerate(enum.value),
  329. value_type_fq,
  330. scl + [d.EnumDescriptorProto.VALUE_FIELD_NUMBER],
  331. )
  332. if prefix == "" and not self.readable_stubs:
  333. wl(f"{_mangle_global_identifier(class_name)} = {class_name}")
  334. wl("")
  335. def write_messages(
  336. self,
  337. messages: Iterable[d.DescriptorProto],
  338. prefix: str,
  339. scl_prefix: SourceCodeLocation,
  340. ) -> None:
  341. wl = self._write_line
  342. for i, desc in enumerate(messages):
  343. qualified_name = prefix +
  344. # Reproduce some hardcoded logic from the protobuf implementation - where
  345. # some specific "well_known_types" generated protos to have additional
  346. # base classes
  347. addl_base = ""
  348. if self.fd.package + "." + in WKTBASES:
  349. # chop off the .proto - and import the well known type
  350. # eg `from google.protobuf.duration import Duration`
  351. well_known_type = WKTBASES[self.fd.package + "." +]
  352. addl_base = ", " + self._import(
  353. "google.protobuf.internal.well_known_types",
  354. well_known_type.__name__,
  355. )
  356. class_name = if not in PYTHON_RESERVED else "_r_" +
  357. message_class = self._import("google.protobuf.message", "Message")
  358. wl(f"class {class_name}({message_class}{addl_base}):")
  359. with self._indent():
  360. scl = scl_prefix + [i]
  361. if self._write_comments(scl):
  362. wl("")
  363. desc_type = self._import("google.protobuf.descriptor", "Descriptor")
  364. wl(f"DESCRIPTOR: {desc_type}")
  365. wl("")
  366. # Nested enums/messages
  367. self.write_enums(
  368. desc.enum_type,
  369. qualified_name + ".",
  370. scl + [d.DescriptorProto.ENUM_TYPE_FIELD_NUMBER],
  371. )
  372. self.write_messages(
  373. desc.nested_type,
  374. qualified_name + ".",
  375. scl + [d.DescriptorProto.NESTED_TYPE_FIELD_NUMBER],
  376. )
  377. # integer constants for field numbers
  378. for f in desc.field:
  379. wl(f"{}_FIELD_NUMBER: {self._builtin('int')}")
  380. for idx, field in enumerate(desc.field):
  381. if in PYTHON_RESERVED:
  382. continue
  383. field_type = self.python_type(field)
  384. if is_scalar(field) and field.label != d.FieldDescriptorProto.LABEL_REPEATED:
  385. # Scalar non repeated fields are r/w
  386. wl(f"{}: {field_type}")
  387. self._write_comments(scl + [d.DescriptorProto.FIELD_FIELD_NUMBER, idx])
  388. else:
  389. # r/o Getters for non-scalar fields and scalar-repeated fields
  390. scl_field = scl + [d.DescriptorProto.FIELD_FIELD_NUMBER, idx]
  391. wl("@property")
  392. body = " ..." if not self._has_comments(scl_field) else ""
  393. wl(f"def {}(self) -> {field_type}:{body}")
  394. if self._has_comments(scl_field):
  395. with self._indent():
  396. self._write_comments(scl_field)
  397. self.write_extensions(desc.extension, scl + [d.DescriptorProto.EXTENSION_FIELD_NUMBER])
  398. # Constructor
  399. wl("def __init__(")
  400. with self._indent():
  401. if any( == "self" for f in desc.field):
  402. wl("# pyright: reportSelfClsParameterName=false")
  403. wl("self_,")
  404. else:
  405. wl("self,")
  406. with self._indent():
  407. constructor_fields = [f for f in desc.field if not in PYTHON_RESERVED]
  408. if len(constructor_fields) > 0:
  409. # Only positional args allowed
  410. # See
  411. wl("*,")
  412. for field in constructor_fields:
  413. field_type = self.python_type(field, generic_container=True)
  414. if self.fd.syntax == "proto3" and is_scalar(field) and field.label != d.FieldDescriptorProto.LABEL_REPEATED and not self.relax_strict_optional_primitives and not field.proto3_optional:
  415. wl(f"{}: {field_type} = ...,")
  416. else:
  417. wl(f"{}: {field_type} | None = ...,")
  418. wl(") -> None: ...")
  419. self.write_stringly_typed_fields(desc)
  420. if prefix == "" and not self.readable_stubs:
  421. wl("")
  422. wl(f"{_mangle_global_identifier(class_name)} = {class_name}")
  423. wl("")
  424. def write_stringly_typed_fields(self, desc: d.DescriptorProto) -> None:
  425. """Type the stringly-typed methods as a Union[Literal, Literal ...]"""
  426. wl = self._write_line
  427. # HasField, ClearField, WhichOneof accepts both bytes/str
  428. # HasField only supports singular. ClearField supports repeated as well
  429. # In proto3, HasField only supports message fields and optional fields
  430. # HasField always supports oneof fields
  431. hf_fields = [ for f in desc.field if f.HasField("oneof_index") or (f.label != d.FieldDescriptorProto.LABEL_REPEATED and (self.fd.syntax != "proto3" or f.type == d.FieldDescriptorProto.TYPE_MESSAGE or f.proto3_optional))]
  432. cf_fields = [ for f in desc.field]
  433. wo_fields = { [ for f in desc.field if f.HasField("oneof_index") and f.oneof_index == idx] for idx, oneof in enumerate(desc.oneof_decl)}
  434. hf_fields.extend(wo_fields.keys())
  435. cf_fields.extend(wo_fields.keys())
  436. hf_fields_text = ", ".join(sorted(f'"{name}", b"{name}"' for name in hf_fields))
  437. cf_fields_text = ", ".join(sorted(f'"{name}", b"{name}"' for name in cf_fields))
  438. if not hf_fields and not cf_fields and not wo_fields:
  439. return
  440. if hf_fields:
  441. wl(
  442. "def HasField(self, field_name: {}[{}]) -> {}: ...",
  443. self._import("typing_extensions", "Literal"),
  444. hf_fields_text,
  445. self._builtin("bool"),
  446. )
  447. if cf_fields:
  448. wl(
  449. "def ClearField(self, field_name: {}[{}]) -> None: ...",
  450. self._import("typing_extensions", "Literal"),
  451. cf_fields_text,
  452. )
  453. for wo_field, members in sorted(wo_fields.items()):
  454. if len(wo_fields) > 1:
  455. wl("@{}", self._import("typing", "overload"))
  456. wl(
  457. "def WhichOneof(self, oneof_group: {}[{}]) -> {}[{}] | None: ...",
  458. self._import("typing_extensions", "Literal"),
  459. # Accepts both str and bytes
  460. f'"{wo_field}", b"{wo_field}"',
  461. self._import("typing_extensions", "Literal"),
  462. # Returns `str`
  463. ", ".join(f'"{m}"' for m in members),
  464. )
  465. def write_extensions(
  466. self,
  467. extensions: Sequence[d.FieldDescriptorProto],
  468. scl_prefix: SourceCodeLocation,
  469. ) -> None:
  470. wl = self._write_line
  471. for ext in extensions:
  472. wl(f"{}_FIELD_NUMBER: {self._builtin('int')}")
  473. for i, ext in enumerate(extensions):
  474. scl = scl_prefix + [i]
  475. wl(
  476. "{}: {}[{}, {}]",
  478. self._import(
  479. "google.protobuf.internal.extension_dict",
  480. "_ExtensionFieldDescriptor",
  481. ),
  482. self._import_message(ext.extendee),
  483. self.python_type(ext),
  484. )
  485. self._write_comments(scl)
  486. def write_methods(
  487. self,
  488. service: d.ServiceDescriptorProto,
  489. class_name: str,
  490. is_abstract: bool,
  491. scl_prefix: SourceCodeLocation,
  492. ) -> None:
  493. wl = self._write_line
  494. wl(
  495. "DESCRIPTOR: {}",
  496. self._import("google.protobuf.descriptor", "ServiceDescriptor"),
  497. )
  498. methods = [(i, m) for i, m in enumerate(service.method) if not in PYTHON_RESERVED]
  499. if not methods:
  500. wl("...")
  501. for i, method in methods:
  502. if is_abstract:
  503. wl("@{}", self._import("abc", "abstractmethod"))
  504. wl(f"def {}(")
  505. with self._indent():
  506. wl(f"inst: {class_name},")
  507. wl(
  508. "rpc_controller: {},",
  509. self._import("google.protobuf.service", "RpcController"),
  510. )
  511. wl("request: {},", self._import_message(method.input_type))
  512. wl(
  513. "callback: {}[[{}], None] | None{},",
  514. self._import("", "Callable"),
  515. self._import_message(method.output_type),
  516. "" if is_abstract else " = ...",
  517. )
  518. scl_method = scl_prefix + [d.ServiceDescriptorProto.METHOD_FIELD_NUMBER, i]
  519. wl(
  520. ") -> {}[{}]:{}",
  521. self._import("concurrent.futures", "Future"),
  522. self._import_message(method.output_type),
  523. " ..." if not self._has_comments(scl_method) else "",
  524. )
  525. if self._has_comments(scl_method):
  526. with self._indent():
  527. if not self._write_comments(scl_method):
  528. wl("...")
  529. def write_services(
  530. self,
  531. services: Iterable[d.ServiceDescriptorProto],
  532. scl_prefix: SourceCodeLocation,
  533. ) -> None:
  534. wl = self._write_line
  535. for i, service in enumerate(services):
  536. scl = scl_prefix + [i]
  537. class_name = if not in PYTHON_RESERVED else "_r_" +
  538. # The service definition interface
  539. wl(
  540. "class {}({}, metaclass={}):",
  541. class_name,
  542. self._import("google.protobuf.service", "Service"),
  543. self._import("abc", "ABCMeta"),
  544. )
  545. with self._indent():
  546. if self._write_comments(scl):
  547. wl("")
  548. self.write_methods(service, class_name, is_abstract=True, scl_prefix=scl)
  549. wl("")
  550. # The stub client
  551. stub_class_name = + "_Stub"
  552. wl("class {}({}):", stub_class_name, class_name)
  553. with self._indent():
  554. if self._write_comments(scl):
  555. wl("")
  556. wl(
  557. "def __init__(self, rpc_channel: {}) -> None: ...",
  558. self._import("google.protobuf.service", "RpcChannel"),
  559. )
  560. self.write_methods(service, stub_class_name, is_abstract=False, scl_prefix=scl)
  561. wl("")
  562. def _import_casttype(self, casttype: str) -> str:
  563. split = casttype.split(".")
  564. assert len(split) == 2, "mypy_protobuf.[casttype,keytype,valuetype] is expected to be of format path/to/file.TypeInFile"
  565. pkg = split[0].replace("/", ".")
  566. return self._import(pkg, split[1])
  567. def _map_key_value_types(
  568. self,
  569. map_field: d.FieldDescriptorProto,
  570. key_field: d.FieldDescriptorProto,
  571. value_field: d.FieldDescriptorProto,
  572. ) -> Tuple[str, str]:
  573. oldstyle_keytype = map_field.options.Extensions[extensions_pb2.keytype]
  574. if oldstyle_keytype:
  575. print(f"Warning: Map Field {}: (mypy_protobuf.keytype) is deprecated. Prefer (mypy_protobuf.options).keytype", file=sys.stderr)
  576. key_casttype = map_field.options.Extensions[extensions_pb2.options].keytype or oldstyle_keytype
  577. ktype = self._import_casttype(key_casttype) if key_casttype else self.python_type(key_field)
  578. oldstyle_valuetype = map_field.options.Extensions[extensions_pb2.valuetype]
  579. if oldstyle_valuetype:
  580. print(f"Warning: Map Field {}: (mypy_protobuf.valuetype) is deprecated. Prefer (mypy_protobuf.options).valuetype", file=sys.stderr)
  581. value_casttype = map_field.options.Extensions[extensions_pb2.options].valuetype or map_field.options.Extensions[extensions_pb2.valuetype]
  582. vtype = self._import_casttype(value_casttype) if value_casttype else self.python_type(value_field)
  583. return ktype, vtype
  584. def _callable_type(self, method: d.MethodDescriptorProto) -> str:
  585. if method.client_streaming:
  586. if method.server_streaming:
  587. return self._import("grpc", "StreamStreamMultiCallable")
  588. else:
  589. return self._import("grpc", "StreamUnaryMultiCallable")
  590. else:
  591. if method.server_streaming:
  592. return self._import("grpc", "UnaryStreamMultiCallable")
  593. else:
  594. return self._import("grpc", "UnaryUnaryMultiCallable")
  595. def _input_type(self, method: d.MethodDescriptorProto, use_stream_iterator: bool = True) -> str:
  596. result = self._import_message(method.input_type)
  597. if use_stream_iterator and method.client_streaming:
  598. result = f"{self._import('', 'Iterator')}[{result}]"
  599. return result
  600. def _output_type(self, method: d.MethodDescriptorProto, use_stream_iterator: bool = True) -> str:
  601. result = self._import_message(method.output_type)
  602. if use_stream_iterator and method.server_streaming:
  603. result = f"{self._import('', 'Iterator')}[{result}]"
  604. return result
  605. def write_grpc_methods(self, service: d.ServiceDescriptorProto, scl_prefix: SourceCodeLocation) -> None:
  606. wl = self._write_line
  607. methods = [(i, m) for i, m in enumerate(service.method) if not in PYTHON_RESERVED]
  608. if not methods:
  609. wl("...")
  610. wl("")
  611. for i, method in methods:
  612. scl = scl_prefix + [d.ServiceDescriptorProto.METHOD_FIELD_NUMBER, i]
  613. wl("@{}", self._import("abc", "abstractmethod"))
  614. wl("def {}(",
  615. with self._indent():
  616. wl("self,")
  617. input_name = "request_iterator" if method.client_streaming else "request"
  618. input_type = self._input_type(method)
  619. wl(f"{input_name}: {input_type},")
  620. wl("context: {},", self._import("grpc", "ServicerContext"))
  621. wl(
  622. ") -> {}:{}",
  623. self._output_type(method),
  624. " ..." if not self._has_comments(scl) else "",
  625. ),
  626. if self._has_comments(scl):
  627. with self._indent():
  628. if not self._write_comments(scl):
  629. wl("...")
  630. def write_grpc_stub_methods(self, service: d.ServiceDescriptorProto, scl_prefix: SourceCodeLocation) -> None:
  631. wl = self._write_line
  632. methods = [(i, m) for i, m in enumerate(service.method) if not in PYTHON_RESERVED]
  633. if not methods:
  634. wl("...")
  635. wl("")
  636. for i, method in methods:
  637. scl = scl_prefix + [d.ServiceDescriptorProto.METHOD_FIELD_NUMBER, i]
  638. wl("{}: {}[",, self._callable_type(method))
  639. with self._indent():
  640. wl("{},", self._input_type(method, False))
  641. wl("{},", self._output_type(method, False))
  642. wl("]")
  643. self._write_comments(scl)
  644. def write_grpc_services(
  645. self,
  646. services: Iterable[d.ServiceDescriptorProto],
  647. scl_prefix: SourceCodeLocation,
  648. ) -> None:
  649. wl = self._write_line
  650. for i, service in enumerate(services):
  651. if in PYTHON_RESERVED:
  652. continue
  653. scl = scl_prefix + [i]
  654. # The stub client
  655. wl(f"class {}Stub:")
  656. with self._indent():
  657. if self._write_comments(scl):
  658. wl("")
  659. wl(
  660. "def __init__(self, channel: {}) -> None: ...",
  661. self._import("grpc", "Channel"),
  662. )
  663. self.write_grpc_stub_methods(service, scl)
  664. wl("")
  665. # The service definition interface
  666. wl(
  667. "class {}Servicer(metaclass={}):",
  669. self._import("abc", "ABCMeta"),
  670. )
  671. with self._indent():
  672. if self._write_comments(scl):
  673. wl("")
  674. self.write_grpc_methods(service, scl)
  675. wl("")
  676. wl(
  677. "def add_{}Servicer_to_server(servicer: {}Servicer, server: {}) -> None: ...",
  680. self._import("grpc", "Server"),
  681. )
  682. wl("")
  683. def python_type(self, field: d.FieldDescriptorProto, generic_container: bool = False) -> str:
  684. """
  685. generic_container
  686. if set, type the field with generic interfaces. Eg.
  687. - Iterable[int] rather than RepeatedScalarFieldContainer[int]
  688. - Mapping[k, v] rather than MessageMap[k, v]
  689. Can be useful for input types (eg constructor)
  690. """
  691. oldstyle_casttype = field.options.Extensions[extensions_pb2.casttype]
  692. if oldstyle_casttype:
  693. print(f"Warning: Field {}: (mypy_protobuf.casttype) is deprecated. Prefer (mypy_protobuf.options).casttype", file=sys.stderr)
  694. casttype = field.options.Extensions[extensions_pb2.options].casttype or oldstyle_casttype
  695. if casttype:
  696. return self._import_casttype(casttype)
  697. mapping: Dict[d.FieldDescriptorProto.Type.V, Callable[[], str]] = {
  698. d.FieldDescriptorProto.TYPE_DOUBLE: lambda: self._builtin("float"),
  699. d.FieldDescriptorProto.TYPE_FLOAT: lambda: self._builtin("float"),
  700. d.FieldDescriptorProto.TYPE_INT64: lambda: self._builtin("int"),
  701. d.FieldDescriptorProto.TYPE_UINT64: lambda: self._builtin("int"),
  702. d.FieldDescriptorProto.TYPE_FIXED64: lambda: self._builtin("int"),
  703. d.FieldDescriptorProto.TYPE_SFIXED64: lambda: self._builtin("int"),
  704. d.FieldDescriptorProto.TYPE_SINT64: lambda: self._builtin("int"),
  705. d.FieldDescriptorProto.TYPE_INT32: lambda: self._builtin("int"),
  706. d.FieldDescriptorProto.TYPE_UINT32: lambda: self._builtin("int"),
  707. d.FieldDescriptorProto.TYPE_FIXED32: lambda: self._builtin("int"),
  708. d.FieldDescriptorProto.TYPE_SFIXED32: lambda: self._builtin("int"),
  709. d.FieldDescriptorProto.TYPE_SINT32: lambda: self._builtin("int"),
  710. d.FieldDescriptorProto.TYPE_BOOL: lambda: self._builtin("bool"),
  711. d.FieldDescriptorProto.TYPE_STRING: lambda: self._builtin("str"),
  712. d.FieldDescriptorProto.TYPE_BYTES: lambda: self._builtin("bytes"),
  713. d.FieldDescriptorProto.TYPE_ENUM: lambda: self._import_message(field.type_name + ".ValueType"),
  714. d.FieldDescriptorProto.TYPE_MESSAGE: lambda: self._import_message(field.type_name),
  715. d.FieldDescriptorProto.TYPE_GROUP: lambda: self._import_message(field.type_name),
  716. }
  717. assert field.type in mapping, "Unrecognized type: " + repr(field.type)
  718. field_type = mapping[field.type]()
  719. # For non-repeated fields, we're done!
  720. if field.label != d.FieldDescriptorProto.LABEL_REPEATED:
  721. return field_type
  722. # Scalar repeated fields go in RepeatedScalarFieldContainer
  723. if is_scalar(field):
  724. container = (
  725. self._import("", "Iterable")
  726. if generic_container
  727. else self._import(
  728. "google.protobuf.internal.containers",
  729. "RepeatedScalarFieldContainer",
  730. )
  731. )
  732. return f"{container}[{field_type}]"
  733. # non-scalar repeated map fields go in ScalarMap/MessageMap
  734. msg = self.descriptors.messages[field.type_name]
  735. if msg.options.map_entry:
  736. # map generates a special Entry wrapper message
  737. if generic_container:
  738. container = self._import("", "Mapping")
  739. elif is_scalar(msg.field[1]):
  740. container = self._import("google.protobuf.internal.containers", "ScalarMap")
  741. else:
  742. container = self._import("google.protobuf.internal.containers", "MessageMap")
  743. ktype, vtype = self._map_key_value_types(field, msg.field[0], msg.field[1])
  744. return f"{container}[{ktype}, {vtype}]"
  745. # non-scalar repetated fields go in RepeatedCompositeFieldContainer
  746. container = (
  747. self._import("", "Iterable")
  748. if generic_container
  749. else self._import(
  750. "google.protobuf.internal.containers",
  751. "RepeatedCompositeFieldContainer",
  752. )
  753. )
  754. return f"{container}[{field_type}]"
  755. def write(self) -> str:
  756. # save current module content, so that imports and module docstring can be inserted
  757. saved_lines = self.lines
  758. self.lines = []
  759. # module docstring may exist as comment before syntax (optional) or package name
  760. if not self._write_comments([d.FileDescriptorProto.PACKAGE_FIELD_NUMBER]):
  761. self._write_comments([d.FileDescriptorProto.SYNTAX_FIELD_NUMBER])
  762. if self.lines:
  763. assert self.lines[0].startswith('"""')
  764. self.lines[0] = f'"""{HEADER}{self.lines[0][3:]}'
  765. else:
  766. self._write_line(f'"""{HEADER}"""')
  767. for reexport_idx in self.fd.public_dependency:
  768. reexport_file = self.fd.dependency[reexport_idx]
  769. reexport_fd = self.descriptors.files[reexport_file]
  770. reexport_imp = reexport_file[:-6].replace("-", "_").replace("/", ".") + "_pb2"
  771. names = [ for m in reexport_fd.message_type] + [ for m in reexport_fd.enum_type] + [ for m in reexport_fd.enum_type for v in m.value] + [ for m in reexport_fd.extension]
  772. if reexport_fd.options.py_generic_services:
  773. names.extend( for m in reexport_fd.service)
  774. if names:
  775. # n,n to force a reexport (from x import y as y)
  776. self.from_imports[reexport_imp].update((n, n) for n in names)
  777. if self.typing_extensions_min:
  778. self.imports.add("sys")
  779. for pkg in sorted(self.imports):
  780. self._write_line(f"import {pkg}")
  781. if self.typing_extensions_min:
  782. self._write_line("")
  783. self._write_line(f"if sys.version_info >= {self.typing_extensions_min}:")
  784. self._write_line(" import typing as typing_extensions")
  785. self._write_line("else:")
  786. self._write_line(" import typing_extensions")
  787. for pkg, items in sorted(self.from_imports.items()):
  788. self._write_line(f"from {pkg} import (")
  789. for (name, reexport_name) in sorted(items):
  790. if reexport_name is None:
  791. self._write_line(f" {name},")
  792. else:
  793. self._write_line(f" {name} as {reexport_name},")
  794. self._write_line(")")
  795. self._write_line("")
  796. # restore module content
  797. self.lines += saved_lines
  798. content = "\n".join(self.lines)
  799. if not content.endswith("\n"):
  800. content = content + "\n"
  801. return content
  802. def is_scalar(fd: d.FieldDescriptorProto) -> bool:
  803. return not (fd.type == d.FieldDescriptorProto.TYPE_MESSAGE or fd.type == d.FieldDescriptorProto.TYPE_GROUP)
  804. def generate_mypy_stubs(
  805. descriptors: Descriptors,
  806. response: plugin_pb2.CodeGeneratorResponse,
  807. quiet: bool,
  808. readable_stubs: bool,
  809. relax_strict_optional_primitives: bool,
  810. ) -> None:
  811. for name, fd in descriptors.to_generate.items():
  812. pkg_writer = PkgWriter(
  813. fd,
  814. descriptors,
  815. readable_stubs,
  816. relax_strict_optional_primitives,
  817. grpc=False,
  818. )
  819. pkg_writer.write_module_attributes()
  820. pkg_writer.write_enums(fd.enum_type, "", [d.FileDescriptorProto.ENUM_TYPE_FIELD_NUMBER])
  821. pkg_writer.write_messages(fd.message_type, "", [d.FileDescriptorProto.MESSAGE_TYPE_FIELD_NUMBER])
  822. pkg_writer.write_extensions(fd.extension, [d.FileDescriptorProto.EXTENSION_FIELD_NUMBER])
  823. if fd.options.py_generic_services:
  824. pkg_writer.write_services(fd.service, [d.FileDescriptorProto.SERVICE_FIELD_NUMBER])
  825. assert name ==
  826. assert".proto")
  827. output = response.file.add()
  828. =[:-6].replace("-", "_").replace(".", "/") + "_pb2.pyi"
  829. output.content = pkg_writer.write()
  830. def generate_mypy_grpc_stubs(
  831. descriptors: Descriptors,
  832. response: plugin_pb2.CodeGeneratorResponse,
  833. quiet: bool,
  834. readable_stubs: bool,
  835. relax_strict_optional_primitives: bool,
  836. ) -> None:
  837. for name, fd in descriptors.to_generate.items():
  838. pkg_writer = PkgWriter(
  839. fd,
  840. descriptors,
  841. readable_stubs,
  842. relax_strict_optional_primitives,
  843. grpc=True,
  844. )
  845. pkg_writer.write_grpc_services(fd.service, [d.FileDescriptorProto.SERVICE_FIELD_NUMBER])
  846. assert name ==
  847. assert".proto")
  848. output = response.file.add()
  849. =[:-6].replace("-", "_").replace(".", "/") + "_pb2_grpc.pyi"
  850. output.content = pkg_writer.write()
  851. @contextmanager
  852. def code_generation() -> Iterator[
  853. Tuple[plugin_pb2.CodeGeneratorRequest, plugin_pb2.CodeGeneratorResponse],
  854. ]:
  855. if len(sys.argv) > 1 and sys.argv[1] in ("-V", "--version"):
  856. print("mypy-protobuf " + __version__)
  857. sys.exit(0)
  858. # Read request message from stdin
  859. data =
  860. # Parse request
  861. request = plugin_pb2.CodeGeneratorRequest()
  862. request.ParseFromString(data)
  863. # Create response
  864. response = plugin_pb2.CodeGeneratorResponse()
  865. # Declare support for optional proto3 fields
  866. response.supported_features |= plugin_pb2.CodeGeneratorResponse.FEATURE_PROTO3_OPTIONAL
  867. yield request, response
  868. # Serialise response message
  869. output = response.SerializeToString()
  870. # Write to stdout
  871. sys.stdout.buffer.write(output)
  872. def main() -> None:
  873. # Generate mypy
  874. with code_generation() as (request, response):
  875. generate_mypy_stubs(
  876. Descriptors(request),
  877. response,
  878. "quiet" in request.parameter,
  879. "readable_stubs" in request.parameter,
  880. "relax_strict_optional_primitives" in request.parameter,
  881. )
  882. def grpc() -> None:
  883. # Generate grpc mypy
  884. with code_generation() as (request, response):
  885. generate_mypy_grpc_stubs(
  886. Descriptors(request),
  887. response,
  888. "quiet" in request.parameter,
  889. "readable_stubs" in request.parameter,
  890. "relax_strict_optional_primitives" in request.parameter,
  891. )
  892. if __name__ == "__main__":
  893. main()