main.py 39 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928929930931932933934935936937938939940941942943944945946947948949950951952953954955956957958959960961962963964965966967968969970971972973974975976977978979980981982983984985986987988989990991992993994995996997998999100010011002100310041005100610071008100910101011101210131014101510161017101810191020102110221023102410251026102710281029103010311032103310341035103610371038103910401041104210431044104510461047104810491050105110521053105410551056105710581059106010611062106310641065106610671068106910701071107210731074107510761077107810791080108110821083108410851086
  1. #!/usr/bin/env python
  2. """Protoc Plugin to generate mypy stubs. Loosely based on @zbarsky's go implementation"""
  3. import os
  4. import sys
  5. from collections import defaultdict
  6. from contextlib import contextmanager
  7. from functools import wraps
  8. from typing import (
  9. Any,
  10. Callable,
  11. Dict,
  12. Iterable,
  13. Iterator,
  14. List,
  15. Optional,
  16. Set,
  17. Sequence,
  18. Tuple,
  19. )
  20. import google.protobuf.descriptor_pb2 as d
  21. from google.protobuf.compiler import plugin_pb2 as plugin_pb2
  22. from google.protobuf.internal.containers import RepeatedCompositeFieldContainer
  23. from google.protobuf.internal.well_known_types import WKTBASES
  24. from . import extensions_pb2
  25. __version__ = "3.2.0"
  26. # SourceCodeLocation is defined by `message Location` here
  27. # https://github.com/protocolbuffers/protobuf/blob/master/src/google/protobuf/descriptor.proto
  28. SourceCodeLocation = List[int]
  29. # So phabricator doesn't think mypy_protobuf.py is generated
  30. GENERATED = "@ge" + "nerated"
  31. HEADER = f"""\"\"\"
  32. {GENERATED} by mypy-protobuf. Do not edit manually!
  33. isort:skip_file
  34. \"\"\"
  35. """
  36. # See https://github.com/dropbox/mypy-protobuf/issues/73 for details
  37. PYTHON_RESERVED = {
  38. "False",
  39. "None",
  40. "True",
  41. "and",
  42. "as",
  43. "async",
  44. "await",
  45. "assert",
  46. "break",
  47. "class",
  48. "continue",
  49. "def",
  50. "del",
  51. "elif",
  52. "else",
  53. "except",
  54. "finally",
  55. "for",
  56. "from",
  57. "global",
  58. "if",
  59. "import",
  60. "in",
  61. "is",
  62. "lambda",
  63. "nonlocal",
  64. "not",
  65. "or",
  66. "pass",
  67. "raise",
  68. "return",
  69. "try",
  70. "while",
  71. "with",
  72. "yield",
  73. }
  74. PROTO_ENUM_RESERVED = {
  75. "Name",
  76. "Value",
  77. "keys",
  78. "values",
  79. "items",
  80. }
  81. def _mangle_global_identifier(name: str) -> str:
  82. """
  83. Module level identifiers are mangled and aliased so that they can be disambiguated
  84. from fields/enum variants with the same name within the file.
  85. Eg:
  86. Enum variant `Name` or message field `Name` might conflict with a top level
  87. message or enum named `Name`, so mangle it with a global___ prefix for
  88. internal references. Note that this doesn't affect inner enums/messages
  89. because they get fuly qualified when referenced within a file"""
  90. return f"global___{name}"
  91. class Descriptors(object):
  92. def __init__(self, request: plugin_pb2.CodeGeneratorRequest) -> None:
  93. files = {f.name: f for f in request.proto_file}
  94. to_generate = {n: files[n] for n in request.file_to_generate}
  95. self.files: Dict[str, d.FileDescriptorProto] = files
  96. self.to_generate: Dict[str, d.FileDescriptorProto] = to_generate
  97. self.messages: Dict[str, d.DescriptorProto] = {}
  98. self.message_to_fd: Dict[str, d.FileDescriptorProto] = {}
  99. def _add_enums(
  100. enums: "RepeatedCompositeFieldContainer[d.EnumDescriptorProto]",
  101. prefix: str,
  102. _fd: d.FileDescriptorProto,
  103. ) -> None:
  104. for enum in enums:
  105. self.message_to_fd[prefix + enum.name] = _fd
  106. self.message_to_fd[prefix + enum.name + ".ValueType"] = _fd
  107. def _add_messages(
  108. messages: "RepeatedCompositeFieldContainer[d.DescriptorProto]",
  109. prefix: str,
  110. _fd: d.FileDescriptorProto,
  111. ) -> None:
  112. for message in messages:
  113. self.messages[prefix + message.name] = message
  114. self.message_to_fd[prefix + message.name] = _fd
  115. sub_prefix = prefix + message.name + "."
  116. _add_messages(message.nested_type, sub_prefix, _fd)
  117. _add_enums(message.enum_type, sub_prefix, _fd)
  118. for fd in request.proto_file:
  119. start_prefix = "." + fd.package + "." if fd.package else "."
  120. _add_messages(fd.message_type, start_prefix, fd)
  121. _add_enums(fd.enum_type, start_prefix, fd)
  122. class PkgWriter(object):
  123. """Writes a single pyi file"""
  124. def __init__(
  125. self,
  126. fd: d.FileDescriptorProto,
  127. descriptors: Descriptors,
  128. readable_stubs: bool,
  129. relax_strict_optional_primitives: bool,
  130. grpc: bool,
  131. ) -> None:
  132. self.fd = fd
  133. self.descriptors = descriptors
  134. self.readable_stubs = readable_stubs
  135. self.relax_strict_optional_primitives = relax_strict_optional_primitives
  136. self.grpc = grpc
  137. self.lines: List[str] = []
  138. self.indent = ""
  139. # Set of {x}, where {x} corresponds to to `import {x}`
  140. self.imports: Set[str] = set()
  141. # dictionary of x->(y,z) for `from {x} import {y} as {z}`
  142. # if {z} is None, then it shortens to `from {x} import {y}`
  143. self.from_imports: Dict[str, Set[Tuple[str, Optional[str]]]] = defaultdict(set)
  144. # Comments
  145. self.source_code_info_by_scl = {
  146. tuple(location.path): location for location in fd.source_code_info.location
  147. }
  148. def _import(self, path: str, name: str) -> str:
  149. """Imports a stdlib path and returns a handle to it
  150. eg. self._import("typing", "Optional") -> "Optional"
  151. """
  152. imp = path.replace("/", ".")
  153. if self.readable_stubs:
  154. self.from_imports[imp].add((name, None))
  155. return name
  156. else:
  157. self.imports.add(imp)
  158. return imp + "." + name
  159. def _import_message(self, name: str) -> str:
  160. """Import a referenced message and return a handle"""
  161. message_fd = self.descriptors.message_to_fd[name]
  162. assert message_fd.name.endswith(".proto")
  163. # Strip off package name
  164. if message_fd.package:
  165. assert name.startswith("." + message_fd.package + ".")
  166. name = name[len("." + message_fd.package + ".") :]
  167. else:
  168. assert name.startswith(".")
  169. name = name[1:]
  170. # Use prepended "_r_" to disambiguate message names that alias python reserved keywords
  171. split = name.split(".")
  172. for i, part in enumerate(split):
  173. if part in PYTHON_RESERVED:
  174. split[i] = "_r_" + part
  175. name = ".".join(split)
  176. # Message defined in this file. Note: GRPC stubs in same .proto are generated into separate files
  177. if not self.grpc and message_fd.name == self.fd.name:
  178. return name if self.readable_stubs else _mangle_global_identifier(name)
  179. # Not in file. Must import
  180. # Python generated code ignores proto packages, so the only relevant factor is
  181. # whether it is in the file or not.
  182. import_name = self._import(
  183. message_fd.name[:-6].replace("-", "_") + "_pb2", split[0]
  184. )
  185. remains = ".".join(split[1:])
  186. if not remains:
  187. return import_name
  188. # remains could either be a direct import of a nested enum or message
  189. # from another package.
  190. return import_name + "." + remains
  191. def _builtin(self, name: str) -> str:
  192. return self._import("builtins", name)
  193. @contextmanager
  194. def _indent(self) -> Iterator[None]:
  195. self.indent = self.indent + " "
  196. yield
  197. self.indent = self.indent[:-4]
  198. def _write_line(self, line: str, *args: Any) -> None:
  199. if args:
  200. line = line.format(*args)
  201. if line == "":
  202. self.lines.append(line)
  203. else:
  204. self.lines.append(self.indent + line)
  205. def _break_text(self, text_block: str) -> List[str]:
  206. if text_block == "":
  207. return []
  208. return [
  209. l[1:] if l.startswith(" ") else l for l in text_block.rstrip().split("\n")
  210. ]
  211. def _has_comments(self, scl: SourceCodeLocation) -> bool:
  212. sci_loc = self.source_code_info_by_scl.get(tuple(scl))
  213. return sci_loc is not None and bool(
  214. sci_loc.leading_detached_comments
  215. or sci_loc.leading_comments
  216. or sci_loc.trailing_comments
  217. )
  218. def _write_comments(self, scl: SourceCodeLocation) -> bool:
  219. """Return true if any comments were written"""
  220. if not self._has_comments(scl):
  221. return False
  222. sci_loc = self.source_code_info_by_scl.get(tuple(scl))
  223. assert sci_loc is not None
  224. lines = []
  225. for leading_detached_comment in sci_loc.leading_detached_comments:
  226. lines.extend(self._break_text(leading_detached_comment))
  227. lines.append("")
  228. if sci_loc.leading_comments is not None:
  229. lines.extend(self._break_text(sci_loc.leading_comments))
  230. # Trailing comments also go in the header - to make sure it gets into the docstring
  231. if sci_loc.trailing_comments is not None:
  232. lines.extend(self._break_text(sci_loc.trailing_comments))
  233. lines = [
  234. # Escape triple-quotes that would otherwise end the docstring early.
  235. line.replace("\\", "\\\\").replace('"""', r"\"\"\"")
  236. for line in lines
  237. ]
  238. if len(lines) == 1:
  239. line = lines[0]
  240. if line.endswith(('"', "\\")):
  241. # Docstrings are terminated with triple-quotes, so if the documentation itself ends in a quote,
  242. # insert some whitespace to separate it from the closing quotes.
  243. # This is not necessary with multiline comments
  244. # because in that case we always insert a newline before the trailing triple-quotes.
  245. line = line + " "
  246. self._write_line(f'"""{line}"""')
  247. else:
  248. for i, line in enumerate(lines):
  249. if i == 0:
  250. self._write_line(f'"""{line}')
  251. else:
  252. self._write_line(f"{line}")
  253. self._write_line('"""')
  254. return True
  255. def write_enum_values(
  256. self,
  257. values: Iterable[Tuple[int, d.EnumValueDescriptorProto]],
  258. value_type: str,
  259. scl_prefix: SourceCodeLocation,
  260. ) -> None:
  261. for i, val in values:
  262. if val.name in PYTHON_RESERVED:
  263. continue
  264. scl = scl_prefix + [i]
  265. self._write_line(
  266. f"{val.name}: {value_type} # {val.number}",
  267. )
  268. if self._write_comments(scl):
  269. self._write_line("") # Extra newline to separate
  270. def write_module_attributes(self) -> None:
  271. l = self._write_line
  272. fd_type = self._import("google.protobuf.descriptor", "FileDescriptor")
  273. l(f"DESCRIPTOR: {fd_type}")
  274. l("")
  275. def write_enums(
  276. self,
  277. enums: Iterable[d.EnumDescriptorProto],
  278. prefix: str,
  279. scl_prefix: SourceCodeLocation,
  280. ) -> None:
  281. l = self._write_line
  282. for i, enum in enumerate(enums):
  283. class_name = (
  284. enum.name if enum.name not in PYTHON_RESERVED else "_r_" + enum.name
  285. )
  286. value_type_fq = prefix + class_name + ".ValueType"
  287. enum_helper_class = "_" + enum.name
  288. value_type_helper_fq = prefix + enum_helper_class + ".ValueType"
  289. etw_helper_class = "_" + enum.name + "EnumTypeWrapper"
  290. scl = scl_prefix + [i]
  291. l(f"class {enum_helper_class}:")
  292. with self._indent():
  293. l(
  294. "ValueType = {}('ValueType', {})",
  295. self._import("typing", "NewType"),
  296. self._builtin("int"),
  297. )
  298. # Alias to the classic shorter definition "V"
  299. l("V: {} = ValueType", self._import("typing_extensions", "TypeAlias"))
  300. l(
  301. "class {}({}[{}], {}):",
  302. etw_helper_class,
  303. self._import(
  304. "google.protobuf.internal.enum_type_wrapper", "_EnumTypeWrapper"
  305. ),
  306. value_type_helper_fq,
  307. self._builtin("type"),
  308. )
  309. with self._indent():
  310. ed = self._import("google.protobuf.descriptor", "EnumDescriptor")
  311. l(f"DESCRIPTOR: {ed}")
  312. self.write_enum_values(
  313. [
  314. (i, v)
  315. for i, v in enumerate(enum.value)
  316. if v.name not in PROTO_ENUM_RESERVED
  317. ],
  318. value_type_helper_fq,
  319. scl + [d.EnumDescriptorProto.VALUE_FIELD_NUMBER],
  320. )
  321. l(f"class {class_name}({enum_helper_class}, metaclass={etw_helper_class}):")
  322. with self._indent():
  323. self._write_comments(scl)
  324. l("pass")
  325. l("")
  326. self.write_enum_values(
  327. enumerate(enum.value),
  328. value_type_fq,
  329. scl + [d.EnumDescriptorProto.VALUE_FIELD_NUMBER],
  330. )
  331. if prefix == "" and not self.readable_stubs:
  332. l(f"{_mangle_global_identifier(class_name)} = {class_name}")
  333. l("")
  334. l("")
  335. def write_messages(
  336. self,
  337. messages: Iterable[d.DescriptorProto],
  338. prefix: str,
  339. scl_prefix: SourceCodeLocation,
  340. ) -> None:
  341. l = self._write_line
  342. for i, desc in enumerate(messages):
  343. qualified_name = prefix + desc.name
  344. # Reproduce some hardcoded logic from the protobuf implementation - where
  345. # some specific "well_known_types" generated protos to have additional
  346. # base classes
  347. addl_base = ""
  348. if self.fd.package + "." + desc.name in WKTBASES:
  349. # chop off the .proto - and import the well known type
  350. # eg `from google.protobuf.duration import Duration`
  351. well_known_type = WKTBASES[self.fd.package + "." + desc.name]
  352. addl_base = ", " + self._import(
  353. "google.protobuf.internal.well_known_types",
  354. well_known_type.__name__,
  355. )
  356. class_name = (
  357. desc.name if desc.name not in PYTHON_RESERVED else "_r_" + desc.name
  358. )
  359. message_class = self._import("google.protobuf.message", "Message")
  360. l(f"class {class_name}({message_class}{addl_base}):")
  361. with self._indent():
  362. scl = scl_prefix + [i]
  363. self._write_comments(scl)
  364. desc_type = self._import("google.protobuf.descriptor", "Descriptor")
  365. l(f"DESCRIPTOR: {desc_type}")
  366. # Nested enums/messages
  367. self.write_enums(
  368. desc.enum_type,
  369. qualified_name + ".",
  370. scl + [d.DescriptorProto.ENUM_TYPE_FIELD_NUMBER],
  371. )
  372. self.write_messages(
  373. desc.nested_type,
  374. qualified_name + ".",
  375. scl + [d.DescriptorProto.NESTED_TYPE_FIELD_NUMBER],
  376. )
  377. # integer constants for field numbers
  378. for f in desc.field:
  379. l(f"{f.name.upper()}_FIELD_NUMBER: {self._builtin('int')}")
  380. for idx, field in enumerate(desc.field):
  381. if field.name in PYTHON_RESERVED:
  382. continue
  383. field_type = self.python_type(field)
  384. if (
  385. is_scalar(field)
  386. and field.label != d.FieldDescriptorProto.LABEL_REPEATED
  387. ):
  388. # Scalar non repeated fields are r/w
  389. l(f"{field.name}: {field_type}")
  390. if self._write_comments(
  391. scl + [d.DescriptorProto.FIELD_FIELD_NUMBER, idx]
  392. ):
  393. l("")
  394. else:
  395. # r/o Getters for non-scalar fields and scalar-repeated fields
  396. scl_field = scl + [d.DescriptorProto.FIELD_FIELD_NUMBER, idx]
  397. l("@property")
  398. body = " ..." if not self._has_comments(scl_field) else ""
  399. l(f"def {field.name}(self) -> {field_type}:{body}")
  400. if self._has_comments(scl_field):
  401. with self._indent():
  402. self._write_comments(scl_field)
  403. l("pass")
  404. self.write_extensions(
  405. desc.extension, scl + [d.DescriptorProto.EXTENSION_FIELD_NUMBER]
  406. )
  407. # Constructor
  408. if any(f.name == "self" for f in desc.field):
  409. l("# pyright: reportSelfClsParameterName=false")
  410. l(f"def __init__(self_,")
  411. else:
  412. l(f"def __init__(self,")
  413. with self._indent():
  414. constructor_fields = [
  415. f for f in desc.field if f.name not in PYTHON_RESERVED
  416. ]
  417. if len(constructor_fields) > 0:
  418. # Only positional args allowed
  419. # See https://github.com/dropbox/mypy-protobuf/issues/71
  420. l("*,")
  421. for field in constructor_fields:
  422. field_type = self.python_type(field, generic_container=True)
  423. if (
  424. self.fd.syntax == "proto3"
  425. and is_scalar(field)
  426. and field.label != d.FieldDescriptorProto.LABEL_REPEATED
  427. and not self.relax_strict_optional_primitives
  428. and not field.proto3_optional
  429. ):
  430. l(f"{field.name}: {field_type} = ...,")
  431. else:
  432. opt = self._import("typing", "Optional")
  433. l(f"{field.name}: {opt}[{field_type}] = ...,")
  434. l(") -> None: ...")
  435. self.write_stringly_typed_fields(desc)
  436. if prefix == "" and not self.readable_stubs:
  437. l(f"{_mangle_global_identifier(class_name)} = {class_name}")
  438. l("")
  439. def write_stringly_typed_fields(self, desc: d.DescriptorProto) -> None:
  440. """Type the stringly-typed methods as a Union[Literal, Literal ...]"""
  441. l = self._write_line
  442. # HasField, ClearField, WhichOneof accepts both bytes/str
  443. # HasField only supports singular. ClearField supports repeated as well
  444. # In proto3, HasField only supports message fields and optional fields
  445. # HasField always supports oneof fields
  446. hf_fields = [
  447. f.name
  448. for f in desc.field
  449. if f.HasField("oneof_index")
  450. or (
  451. f.label != d.FieldDescriptorProto.LABEL_REPEATED
  452. and (
  453. self.fd.syntax != "proto3"
  454. or f.type == d.FieldDescriptorProto.TYPE_MESSAGE
  455. or f.proto3_optional
  456. )
  457. )
  458. ]
  459. cf_fields = [f.name for f in desc.field]
  460. wo_fields = {
  461. oneof.name: [
  462. f.name
  463. for f in desc.field
  464. if f.HasField("oneof_index") and f.oneof_index == idx
  465. ]
  466. for idx, oneof in enumerate(desc.oneof_decl)
  467. }
  468. hf_fields.extend(wo_fields.keys())
  469. cf_fields.extend(wo_fields.keys())
  470. hf_fields_text = ",".join(sorted(f'"{name}",b"{name}"' for name in hf_fields))
  471. cf_fields_text = ",".join(sorted(f'"{name}",b"{name}"' for name in cf_fields))
  472. if not hf_fields and not cf_fields and not wo_fields:
  473. return
  474. if hf_fields:
  475. l(
  476. "def HasField(self, field_name: {}[{}]) -> {}: ...",
  477. self._import("typing_extensions", "Literal"),
  478. hf_fields_text,
  479. self._builtin("bool"),
  480. )
  481. if cf_fields:
  482. l(
  483. "def ClearField(self, field_name: {}[{}]) -> None: ...",
  484. self._import("typing_extensions", "Literal"),
  485. cf_fields_text,
  486. )
  487. for wo_field, members in sorted(wo_fields.items()):
  488. if len(wo_fields) > 1:
  489. l("@{}", self._import("typing", "overload"))
  490. l(
  491. "def WhichOneof(self, oneof_group: {}[{}]) -> {}[{}[{}]]: ...",
  492. self._import("typing_extensions", "Literal"),
  493. # Accepts both str and bytes
  494. f'"{wo_field}",b"{wo_field}"',
  495. self._import("typing", "Optional"),
  496. self._import("typing_extensions", "Literal"),
  497. # Returns `str`
  498. ",".join(f'"{m}"' for m in members),
  499. )
  500. def write_extensions(
  501. self,
  502. extensions: Sequence[d.FieldDescriptorProto],
  503. scl_prefix: SourceCodeLocation,
  504. ) -> None:
  505. l = self._write_line
  506. for ext in extensions:
  507. l(f"{ext.name.upper()}_FIELD_NUMBER: {self._builtin('int')}")
  508. for i, ext in enumerate(extensions):
  509. scl = scl_prefix + [i]
  510. l(
  511. "{}: {}[{}, {}]",
  512. ext.name,
  513. self._import(
  514. "google.protobuf.internal.extension_dict",
  515. "_ExtensionFieldDescriptor",
  516. ),
  517. self._import_message(ext.extendee),
  518. self.python_type(ext),
  519. )
  520. self._write_comments(scl)
  521. l("")
  522. def write_methods(
  523. self,
  524. service: d.ServiceDescriptorProto,
  525. class_name: str,
  526. is_abstract: bool,
  527. scl_prefix: SourceCodeLocation,
  528. ) -> None:
  529. l = self._write_line
  530. l(
  531. "DESCRIPTOR: {}",
  532. self._import("google.protobuf.descriptor", "ServiceDescriptor"),
  533. )
  534. methods = [
  535. (i, m)
  536. for i, m in enumerate(service.method)
  537. if m.name not in PYTHON_RESERVED
  538. ]
  539. if not methods:
  540. l("pass")
  541. for i, method in methods:
  542. if is_abstract:
  543. l("@{}", self._import("abc", "abstractmethod"))
  544. l(f"def {method.name}(")
  545. with self._indent():
  546. l(f"inst: {class_name},")
  547. l(
  548. "rpc_controller: {},",
  549. self._import("google.protobuf.service", "RpcController"),
  550. )
  551. l("request: {},", self._import_message(method.input_type))
  552. l(
  553. "callback: {}[{}[[{}], None]]{},",
  554. self._import("typing", "Optional"),
  555. self._import("typing", "Callable"),
  556. self._import_message(method.output_type),
  557. "" if is_abstract else " = None",
  558. )
  559. scl_method = scl_prefix + [d.ServiceDescriptorProto.METHOD_FIELD_NUMBER, i]
  560. l(
  561. ") -> {}[{}]:{}",
  562. self._import("concurrent.futures", "Future"),
  563. self._import_message(method.output_type),
  564. " ..." if not self._has_comments(scl_method) else "",
  565. )
  566. if self._has_comments(scl_method):
  567. with self._indent():
  568. self._write_comments(scl_method)
  569. l("pass")
  570. def write_services(
  571. self,
  572. services: Iterable[d.ServiceDescriptorProto],
  573. scl_prefix: SourceCodeLocation,
  574. ) -> None:
  575. l = self._write_line
  576. for i, service in enumerate(services):
  577. scl = scl_prefix + [i]
  578. class_name = (
  579. service.name
  580. if service.name not in PYTHON_RESERVED
  581. else "_r_" + service.name
  582. )
  583. # The service definition interface
  584. l(
  585. "class {}({}, metaclass={}):",
  586. class_name,
  587. self._import("google.protobuf.service", "Service"),
  588. self._import("abc", "ABCMeta"),
  589. )
  590. with self._indent():
  591. self._write_comments(scl)
  592. self.write_methods(
  593. service, class_name, is_abstract=True, scl_prefix=scl
  594. )
  595. # The stub client
  596. stub_class_name = service.name + "_Stub"
  597. l("class {}({}):", stub_class_name, class_name)
  598. with self._indent():
  599. self._write_comments(scl)
  600. l(
  601. "def __init__(self, rpc_channel: {}) -> None: ...",
  602. self._import("google.protobuf.service", "RpcChannel"),
  603. )
  604. self.write_methods(
  605. service, stub_class_name, is_abstract=False, scl_prefix=scl
  606. )
  607. def _import_casttype(self, casttype: str) -> str:
  608. split = casttype.split(".")
  609. assert (
  610. len(split) == 2
  611. ), "mypy_protobuf.[casttype,keytype,valuetype] is expected to be of format path/to/file.TypeInFile"
  612. pkg = split[0].replace("/", ".")
  613. return self._import(pkg, split[1])
  614. def _map_key_value_types(
  615. self,
  616. map_field: d.FieldDescriptorProto,
  617. key_field: d.FieldDescriptorProto,
  618. value_field: d.FieldDescriptorProto,
  619. ) -> Tuple[str, str]:
  620. key_casttype = map_field.options.Extensions[extensions_pb2.keytype]
  621. ktype = (
  622. self._import_casttype(key_casttype)
  623. if key_casttype
  624. else self.python_type(key_field)
  625. )
  626. value_casttype = map_field.options.Extensions[extensions_pb2.valuetype]
  627. vtype = (
  628. self._import_casttype(value_casttype)
  629. if value_casttype
  630. else self.python_type(value_field)
  631. )
  632. return ktype, vtype
  633. def _callable_type(self, method: d.MethodDescriptorProto) -> str:
  634. if method.client_streaming:
  635. if method.server_streaming:
  636. return self._import("grpc", "StreamStreamMultiCallable")
  637. else:
  638. return self._import("grpc", "StreamUnaryMultiCallable")
  639. else:
  640. if method.server_streaming:
  641. return self._import("grpc", "UnaryStreamMultiCallable")
  642. else:
  643. return self._import("grpc", "UnaryUnaryMultiCallable")
  644. def _input_type(
  645. self, method: d.MethodDescriptorProto, use_stream_iterator: bool = True
  646. ) -> str:
  647. result = self._import_message(method.input_type)
  648. if use_stream_iterator and method.client_streaming:
  649. result = f"{self._import('typing', 'Iterator')}[{result}]"
  650. return result
  651. def _output_type(
  652. self, method: d.MethodDescriptorProto, use_stream_iterator: bool = True
  653. ) -> str:
  654. result = self._import_message(method.output_type)
  655. if use_stream_iterator and method.server_streaming:
  656. result = f"{self._import('typing', 'Iterator')}[{result}]"
  657. return result
  658. def write_grpc_methods(
  659. self, service: d.ServiceDescriptorProto, scl_prefix: SourceCodeLocation
  660. ) -> None:
  661. l = self._write_line
  662. methods = [
  663. (i, m)
  664. for i, m in enumerate(service.method)
  665. if m.name not in PYTHON_RESERVED
  666. ]
  667. if not methods:
  668. l("pass")
  669. l("")
  670. for i, method in methods:
  671. scl = scl_prefix + [d.ServiceDescriptorProto.METHOD_FIELD_NUMBER, i]
  672. l("@{}", self._import("abc", "abstractmethod"))
  673. l("def {}(self,", method.name)
  674. with self._indent():
  675. input_name = (
  676. "request_iterator" if method.client_streaming else "request"
  677. )
  678. input_type = self._input_type(method)
  679. l(f"{input_name}: {input_type},")
  680. l("context: {},", self._import("grpc", "ServicerContext"))
  681. l(
  682. ") -> {}:{}",
  683. self._output_type(method),
  684. " ..." if not self._has_comments(scl) else "",
  685. ),
  686. if self._has_comments(scl):
  687. with self._indent():
  688. self._write_comments(scl)
  689. l("pass")
  690. l("")
  691. def write_grpc_stub_methods(
  692. self, service: d.ServiceDescriptorProto, scl_prefix: SourceCodeLocation
  693. ) -> None:
  694. l = self._write_line
  695. methods = [
  696. (i, m)
  697. for i, m in enumerate(service.method)
  698. if m.name not in PYTHON_RESERVED
  699. ]
  700. if not methods:
  701. l("pass")
  702. l("")
  703. for i, method in methods:
  704. scl = scl_prefix + [d.ServiceDescriptorProto.METHOD_FIELD_NUMBER, i]
  705. l("{}: {}[", method.name, self._callable_type(method))
  706. with self._indent():
  707. l("{},", self._input_type(method, False))
  708. l("{}]", self._output_type(method, False))
  709. self._write_comments(scl)
  710. l("")
  711. def write_grpc_services(
  712. self,
  713. services: Iterable[d.ServiceDescriptorProto],
  714. scl_prefix: SourceCodeLocation,
  715. ) -> None:
  716. l = self._write_line
  717. for i, service in enumerate(services):
  718. if service.name in PYTHON_RESERVED:
  719. continue
  720. scl = scl_prefix + [i]
  721. # The stub client
  722. l(f"class {service.name}Stub:")
  723. with self._indent():
  724. self._write_comments(scl)
  725. l(
  726. "def __init__(self, channel: {}) -> None: ...",
  727. self._import("grpc", "Channel"),
  728. )
  729. self.write_grpc_stub_methods(service, scl)
  730. l("")
  731. # The service definition interface
  732. l(
  733. "class {}Servicer(metaclass={}):",
  734. service.name,
  735. self._import("abc", "ABCMeta"),
  736. )
  737. with self._indent():
  738. self._write_comments(scl)
  739. self.write_grpc_methods(service, scl)
  740. l("")
  741. l(
  742. "def add_{}Servicer_to_server(servicer: {}Servicer, server: {}) -> None: ...",
  743. service.name,
  744. service.name,
  745. self._import("grpc", "Server"),
  746. )
  747. l("")
  748. def python_type(
  749. self, field: d.FieldDescriptorProto, generic_container: bool = False
  750. ) -> str:
  751. """
  752. generic_container
  753. if set, type the field with generic interfaces. Eg.
  754. - Iterable[int] rather than RepeatedScalarFieldContainer[int]
  755. - Mapping[k, v] rather than MessageMap[k, v]
  756. Can be useful for input types (eg constructor)
  757. """
  758. casttype = field.options.Extensions[extensions_pb2.casttype]
  759. if casttype:
  760. return self._import_casttype(casttype)
  761. mapping: Dict[d.FieldDescriptorProto.Type.V, Callable[[], str]] = {
  762. d.FieldDescriptorProto.TYPE_DOUBLE: lambda: self._builtin("float"),
  763. d.FieldDescriptorProto.TYPE_FLOAT: lambda: self._builtin("float"),
  764. d.FieldDescriptorProto.TYPE_INT64: lambda: self._builtin("int"),
  765. d.FieldDescriptorProto.TYPE_UINT64: lambda: self._builtin("int"),
  766. d.FieldDescriptorProto.TYPE_FIXED64: lambda: self._builtin("int"),
  767. d.FieldDescriptorProto.TYPE_SFIXED64: lambda: self._builtin("int"),
  768. d.FieldDescriptorProto.TYPE_SINT64: lambda: self._builtin("int"),
  769. d.FieldDescriptorProto.TYPE_INT32: lambda: self._builtin("int"),
  770. d.FieldDescriptorProto.TYPE_UINT32: lambda: self._builtin("int"),
  771. d.FieldDescriptorProto.TYPE_FIXED32: lambda: self._builtin("int"),
  772. d.FieldDescriptorProto.TYPE_SFIXED32: lambda: self._builtin("int"),
  773. d.FieldDescriptorProto.TYPE_SINT32: lambda: self._builtin("int"),
  774. d.FieldDescriptorProto.TYPE_BOOL: lambda: self._builtin("bool"),
  775. d.FieldDescriptorProto.TYPE_STRING: lambda: self._import("typing", "Text"),
  776. d.FieldDescriptorProto.TYPE_BYTES: lambda: self._builtin("bytes"),
  777. d.FieldDescriptorProto.TYPE_ENUM: lambda: self._import_message(
  778. field.type_name + ".ValueType"
  779. ),
  780. d.FieldDescriptorProto.TYPE_MESSAGE: lambda: self._import_message(
  781. field.type_name
  782. ),
  783. d.FieldDescriptorProto.TYPE_GROUP: lambda: self._import_message(
  784. field.type_name
  785. ),
  786. }
  787. assert field.type in mapping, "Unrecognized type: " + repr(field.type)
  788. field_type = mapping[field.type]()
  789. # For non-repeated fields, we're done!
  790. if field.label != d.FieldDescriptorProto.LABEL_REPEATED:
  791. return field_type
  792. # Scalar repeated fields go in RepeatedScalarFieldContainer
  793. if is_scalar(field):
  794. container = (
  795. self._import("typing", "Iterable")
  796. if generic_container
  797. else self._import(
  798. "google.protobuf.internal.containers",
  799. "RepeatedScalarFieldContainer",
  800. )
  801. )
  802. return f"{container}[{field_type}]"
  803. # non-scalar repeated map fields go in ScalarMap/MessageMap
  804. msg = self.descriptors.messages[field.type_name]
  805. if msg.options.map_entry:
  806. # map generates a special Entry wrapper message
  807. if generic_container:
  808. container = self._import("typing", "Mapping")
  809. elif is_scalar(msg.field[1]):
  810. container = self._import(
  811. "google.protobuf.internal.containers", "ScalarMap"
  812. )
  813. else:
  814. container = self._import(
  815. "google.protobuf.internal.containers", "MessageMap"
  816. )
  817. ktype, vtype = self._map_key_value_types(field, msg.field[0], msg.field[1])
  818. return f"{container}[{ktype}, {vtype}]"
  819. # non-scalar repetated fields go in RepeatedCompositeFieldContainer
  820. container = (
  821. self._import("typing", "Iterable")
  822. if generic_container
  823. else self._import(
  824. "google.protobuf.internal.containers",
  825. "RepeatedCompositeFieldContainer",
  826. )
  827. )
  828. return f"{container}[{field_type}]"
  829. def write(self) -> str:
  830. for reexport_idx in self.fd.public_dependency:
  831. reexport_file = self.fd.dependency[reexport_idx]
  832. reexport_fd = self.descriptors.files[reexport_file]
  833. reexport_imp = (
  834. reexport_file[:-6].replace("-", "_").replace("/", ".") + "_pb2"
  835. )
  836. names = (
  837. [m.name for m in reexport_fd.message_type]
  838. + [m.name for m in reexport_fd.enum_type]
  839. + [v.name for m in reexport_fd.enum_type for v in m.value]
  840. + [m.name for m in reexport_fd.extension]
  841. )
  842. if reexport_fd.options.py_generic_services:
  843. names.extend(m.name for m in reexport_fd.service)
  844. if names:
  845. # n,n to force a reexport (from x import y as y)
  846. self.from_imports[reexport_imp].update((n, n) for n in names)
  847. import_lines = []
  848. for pkg in sorted(self.imports):
  849. import_lines.append(f"import {pkg}")
  850. for pkg, items in sorted(self.from_imports.items()):
  851. import_lines.append(f"from {pkg} import (")
  852. for (name, reexport_name) in sorted(items):
  853. if reexport_name is None:
  854. import_lines.append(f" {name},")
  855. else:
  856. import_lines.append(f" {name} as {reexport_name},")
  857. import_lines.append(")\n")
  858. import_lines.append("")
  859. return "\n".join(import_lines + self.lines)
  860. def is_scalar(fd: d.FieldDescriptorProto) -> bool:
  861. return not (
  862. fd.type == d.FieldDescriptorProto.TYPE_MESSAGE
  863. or fd.type == d.FieldDescriptorProto.TYPE_GROUP
  864. )
  865. def generate_mypy_stubs(
  866. descriptors: Descriptors,
  867. response: plugin_pb2.CodeGeneratorResponse,
  868. quiet: bool,
  869. readable_stubs: bool,
  870. relax_strict_optional_primitives: bool,
  871. ) -> None:
  872. for name, fd in descriptors.to_generate.items():
  873. pkg_writer = PkgWriter(
  874. fd,
  875. descriptors,
  876. readable_stubs,
  877. relax_strict_optional_primitives,
  878. grpc=False,
  879. )
  880. pkg_writer.write_module_attributes()
  881. pkg_writer.write_enums(
  882. fd.enum_type, "", [d.FileDescriptorProto.ENUM_TYPE_FIELD_NUMBER]
  883. )
  884. pkg_writer.write_messages(
  885. fd.message_type, "", [d.FileDescriptorProto.MESSAGE_TYPE_FIELD_NUMBER]
  886. )
  887. pkg_writer.write_extensions(
  888. fd.extension, [d.FileDescriptorProto.EXTENSION_FIELD_NUMBER]
  889. )
  890. if fd.options.py_generic_services:
  891. pkg_writer.write_services(
  892. fd.service, [d.FileDescriptorProto.SERVICE_FIELD_NUMBER]
  893. )
  894. assert name == fd.name
  895. assert fd.name.endswith(".proto")
  896. output = response.file.add()
  897. output.name = fd.name[:-6].replace("-", "_").replace(".", "/") + "_pb2.pyi"
  898. output.content = HEADER + pkg_writer.write()
  899. def generate_mypy_grpc_stubs(
  900. descriptors: Descriptors,
  901. response: plugin_pb2.CodeGeneratorResponse,
  902. quiet: bool,
  903. readable_stubs: bool,
  904. relax_strict_optional_primitives: bool,
  905. ) -> None:
  906. for name, fd in descriptors.to_generate.items():
  907. pkg_writer = PkgWriter(
  908. fd,
  909. descriptors,
  910. readable_stubs,
  911. relax_strict_optional_primitives,
  912. grpc=True,
  913. )
  914. pkg_writer.write_grpc_services(
  915. fd.service, [d.FileDescriptorProto.SERVICE_FIELD_NUMBER]
  916. )
  917. assert name == fd.name
  918. assert fd.name.endswith(".proto")
  919. output = response.file.add()
  920. output.name = fd.name[:-6].replace("-", "_").replace(".", "/") + "_pb2_grpc.pyi"
  921. output.content = HEADER + pkg_writer.write()
  922. @contextmanager
  923. def code_generation() -> Iterator[
  924. Tuple[plugin_pb2.CodeGeneratorRequest, plugin_pb2.CodeGeneratorResponse],
  925. ]:
  926. if len(sys.argv) > 1 and sys.argv[1] in ("-V", "--version"):
  927. print("mypy-protobuf " + __version__)
  928. sys.exit(0)
  929. # Read request message from stdin
  930. data = sys.stdin.buffer.read()
  931. # Parse request
  932. request = plugin_pb2.CodeGeneratorRequest()
  933. request.ParseFromString(data)
  934. # Create response
  935. response = plugin_pb2.CodeGeneratorResponse()
  936. # Declare support for optional proto3 fields
  937. response.supported_features |= (
  938. plugin_pb2.CodeGeneratorResponse.FEATURE_PROTO3_OPTIONAL
  939. )
  940. yield request, response
  941. # Serialise response message
  942. output = response.SerializeToString()
  943. # Write to stdout
  944. sys.stdout.buffer.write(output)
  945. def main() -> None:
  946. # Generate mypy
  947. with code_generation() as (request, response):
  948. generate_mypy_stubs(
  949. Descriptors(request),
  950. response,
  951. "quiet" in request.parameter,
  952. "readable_stubs" in request.parameter,
  953. "relax_strict_optional_primitives" in request.parameter,
  954. )
  955. def grpc() -> None:
  956. # Generate grpc mypy
  957. with code_generation() as (request, response):
  958. generate_mypy_grpc_stubs(
  959. Descriptors(request),
  960. response,
  961. "quiet" in request.parameter,
  962. "readable_stubs" in request.parameter,
  963. "relax_strict_optional_primitives" in request.parameter,
  964. )
# Allow running this module directly as a script; equivalent to the
# non-gRPC console entry point.
if __name__ == "__main__":
    main()