12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
370470570670770870971071171271371471571671771871972072172272372472572672772872973073173273373473573673773873974074174274374474574674774874975075175275375475575675775875976076176276376476576676776876977077177277377477577677777877978078178278378478578678778878979079179279379479579679779879980080180280380480580680780880981081181281381481581681781881982082182282382482582682782882983083183283383483583683783883984084184284384484584684784884985085185285385485585685785885986086186286386486586686786886987087187287387487587687787887988088188288388488588688788888989089189289389489589689789889990090190290390490590690790890991091191291391491591691791891992092192292392492592692792892993093193293393493593693793893994094194294394494594694794894995095195295395495595695795895996096196296396496596696796896997097197297397497597697797897998098198298398498598698798898999099199299399499599699799899910001001100210031004100510061007100810091010101110121013101410151016101710181019102010211022 |
- #!/usr/bin/env python
- """Protoc Plugin to generate mypy stubs."""
- from __future__ import annotations
- import sys
- from collections import defaultdict
- from contextlib import contextmanager
- from typing import (
- Any,
- Callable,
- Dict,
- Iterable,
- Iterator,
- List,
- Optional,
- Set,
- Sequence,
- Tuple,
- )
- import google.protobuf.descriptor_pb2 as d
- from google.protobuf.compiler import plugin_pb2 as plugin_pb2
- from google.protobuf.internal.containers import RepeatedCompositeFieldContainer
- from google.protobuf.internal.well_known_types import WKTBASES
- from . import extensions_pb2
__version__ = "3.3.0"

# Protobuf reports source positions through `SourceCodeInfo.Location`, whose
# `path` is a sequence of field numbers and list indices leading to an element:
# https://github.com/protocolbuffers/protobuf/blob/master/src/google/protobuf/descriptor.proto
SourceCodeLocation = List[int]

# The marker is assembled from two halves so that review tooling (phabricator)
# scanning this plugin's own source does not flag mypy_protobuf.py as generated.
GENERATED = "".join(("@ge", "nerated"))

# Prepended to the module docstring of every stub this plugin emits.
HEADER = "\n".join(
    [
        "",
        f"{GENERATED} by mypy-protobuf. Do not edit manually!",
        "isort:skip_file",
        "",
    ]
)
# See https://github.com/nipunn1313/mypy-protobuf/issues/73 for details
# Python keywords: a proto name equal to one of these cannot be used as an
# attribute/class name in the stub, so it is skipped or prefixed with "_r_".
PYTHON_RESERVED = set(
    """
    False None True
    and as async await assert break class continue def del elif else except
    finally for from global if import in is lambda nonlocal not or pass raise
    return try while with yield
    """.split()
)

# Names already defined on protobuf's enum type wrapper; enum values with these
# names are left off the generated wrapper class.
PROTO_ENUM_RESERVED = set("Name Value keys values items".split())
def _mangle_global_identifier(name: str) -> str:
    """Return the ``global___``-prefixed alias of a module-level message or enum.

    Inside a stub, a top-level message or enum called e.g. ``Name`` can be
    shadowed by an enum value or message field of the same name. Internal
    references therefore go through this mangled alias. Nested enums/messages
    need no alias because they are always referenced fully qualified within
    the file.
    """
    return "global___" + name
class Descriptors(object):
    """Lookup tables built from a protoc ``CodeGeneratorRequest``.

    Attributes:
        files: every proto file in the request, keyed by file name.
        to_generate: the files named in ``request.file_to_generate``.
        messages: message descriptors keyed by fully-qualified name with a
            leading dot (e.g. ".pkg.Outer.Inner").
        message_to_fd: the declaring file for every message and enum, plus
            an extra "<enum>.ValueType" entry per enum.
    """

    def __init__(self, request: plugin_pb2.CodeGeneratorRequest) -> None:
        by_name = {proto_file.name: proto_file for proto_file in request.proto_file}
        requested = {file_name: by_name[file_name] for file_name in request.file_to_generate}
        self.files: Dict[str, d.FileDescriptorProto] = by_name
        self.to_generate: Dict[str, d.FileDescriptorProto] = requested
        self.messages: Dict[str, d.DescriptorProto] = {}
        self.message_to_fd: Dict[str, d.FileDescriptorProto] = {}

        def index_enums(
            enums: "RepeatedCompositeFieldContainer[d.EnumDescriptorProto]",
            scope: str,
            owner: d.FileDescriptorProto,
        ) -> None:
            # Each enum is reachable both by its own name and by its ValueType.
            for enum in enums:
                full_name = scope + enum.name
                self.message_to_fd[full_name] = owner
                self.message_to_fd[full_name + ".ValueType"] = owner

        def index_messages(
            messages: "RepeatedCompositeFieldContainer[d.DescriptorProto]",
            scope: str,
            owner: d.FileDescriptorProto,
        ) -> None:
            # Depth-first: a message, then its nested messages, then its enums.
            for message in messages:
                full_name = scope + message.name
                self.messages[full_name] = message
                self.message_to_fd[full_name] = owner
                inner_scope = full_name + "."
                index_messages(message.nested_type, inner_scope, owner)
                index_enums(message.enum_type, inner_scope, owner)

        for proto_file in request.proto_file:
            root = f".{proto_file.package}." if proto_file.package else "."
            index_messages(proto_file.message_type, root, proto_file)
            index_enums(proto_file.enum_type, root, proto_file)
class PkgWriter(object):
    """Writes a single pyi file.

    The ``write_*`` methods append stub lines to ``self.lines`` and record every
    name those lines reference in ``self.imports`` / ``self.from_imports``.
    :meth:`write` then prepends the module docstring and the import block and
    returns the finished file content.
    """

    # One level of indentation in the generated stub; used by _indent().
    _INDENT = "    "

    def __init__(
        self,
        fd: d.FileDescriptorProto,
        descriptors: Descriptors,
        readable_stubs: bool,
        relax_strict_optional_primitives: bool,
        grpc: bool,
    ) -> None:
        """
        fd: the .proto file being turned into a stub.
        descriptors: index of all files/messages/enums in the request, used to
            resolve references to types declared in other files.
        readable_stubs: reference imported names directly (``from x import y``)
            instead of module-qualified, and skip the ``global___`` aliases.
        relax_strict_optional_primitives: also type proto3 scalar constructor
            arguments as ``X | None``.
        grpc: this writer produces the separate gRPC stub, so messages declared
            in ``fd`` itself are imported rather than referenced locally.
        """
        self.fd = fd
        self.descriptors = descriptors
        self.readable_stubs = readable_stubs
        self.relax_strict_optional_primitives = relax_strict_optional_primitives
        self.grpc = grpc
        # Generated stub body, one entry per line (without trailing newline).
        self.lines: List[str] = []
        # Prefix added to every non-empty line; managed by _indent().
        self.indent = ""

        # Set of {x}, where {x} corresponds to `import {x}`
        self.imports: Set[str] = set()
        # dictionary of x->(y,z) for `from {x} import {y} as {z}`
        # if {z} is None, then it shortens to `from {x} import {y}`
        self.from_imports: Dict[str, Set[Tuple[str, str | None]]] = defaultdict(set)
        # Highest (major, minor) Python version at which a requested
        # typing_extensions name became available in `typing`; None if unused.
        self.typing_extensions_min: Optional[Tuple[int, int]] = None

        # Comments, keyed by source-code-location path
        self.source_code_info_by_scl = {tuple(location.path): location for location in fd.source_code_info.location}

    def _import(self, path: str, name: str) -> str:
        """Record that the stub needs ``name`` from module ``path`` and return
        the expression generated code should use to refer to it.

        eg. self._import("typing", "NewType") -> "typing.NewType"
        (or just "NewType" when readable_stubs is set).

        ``typing_extensions`` names are always returned as
        ``typing_extensions.<name>``; :meth:`write` emits a version-gated import
        that aliases ``typing`` as ``typing_extensions`` on new enough Pythons.
        """
        if path == "typing_extensions":
            # Python version in which each supported name landed in `typing`.
            stabilization = {
                "Literal": (3, 8),
                "TypeAlias": (3, 10),
            }
            assert name in stabilization
            if not self.typing_extensions_min or self.typing_extensions_min < stabilization[name]:
                self.typing_extensions_min = stabilization[name]
            return "typing_extensions." + name

        imp = path.replace("/", ".")
        if self.readable_stubs:
            self.from_imports[imp].add((name, None))
            return name
        self.imports.add(imp)
        return imp + "." + name

    def _import_message(self, name: str) -> str:
        """Return the expression referring to the message/enum whose
        fully-qualified proto name is ``name`` (e.g. ".pkg.Outer.Inner"),
        importing its ``_pb2`` module when it is not local to this stub."""
        message_fd = self.descriptors.message_to_fd[name]
        assert message_fd.name.endswith(".proto")

        # Strip off package name
        if message_fd.package:
            assert name.startswith("." + message_fd.package + ".")
            name = name[len("." + message_fd.package + ".") :]
        else:
            assert name.startswith(".")
            name = name[1:]

        # Use prepended "_r_" to disambiguate message names that alias python reserved keywords
        split = name.split(".")
        for i, part in enumerate(split):
            if part in PYTHON_RESERVED:
                split[i] = "_r_" + part
        name = ".".join(split)

        # Message defined in this file. Note: GRPC stubs in same .proto are generated into separate files
        if not self.grpc and message_fd.name == self.fd.name:
            return name if self.readable_stubs else _mangle_global_identifier(name)

        # Not in file. Must import
        # Python generated code ignores proto packages, so the only relevant factor is
        # whether it is in the file or not.
        import_name = self._import(message_fd.name[:-6].replace("-", "_") + "_pb2", split[0])

        remains = ".".join(split[1:])
        if not remains:
            return import_name

        # remains could either be a direct import of a nested enum or message
        # from another package.
        return import_name + "." + remains

    def _builtin(self, name: str) -> str:
        """Return the expression for builtin ``name`` (via the ``builtins`` module)."""
        return self._import("builtins", name)

    @contextmanager
    def _indent(self) -> Iterator[None]:
        """Indent every line written inside the ``with`` block by one level.

        The previous indentation is restored on exit, including when the body
        raises, so nested blocks always unwind to exactly where they started.
        """
        previous = self.indent
        self.indent = previous + self._INDENT
        try:
            yield
        finally:
            self.indent = previous

    def _write_line(self, line: str, *args: Any) -> None:
        """Append one line at the current indentation.

        With ``args``, ``line`` is a ``str.format`` template. Empty lines are
        appended without indentation so the stub has no trailing whitespace.
        """
        if args:
            line = line.format(*args)
        if line == "":
            self.lines.append(line)
        else:
            self.lines.append(self.indent + line)

    def _break_text(self, text_block: str) -> List[str]:
        """Split a comment into lines, dropping trailing whitespace of the block
        and one leading space from each line (presumably the space after ``//``)."""
        if not text_block:
            return []
        return [line[1:] if line.startswith(" ") else line for line in text_block.rstrip().split("\n")]

    def _has_comments(self, scl: SourceCodeLocation) -> bool:
        """Return True if protoc attached any comment to location ``scl``."""
        sci_loc = self.source_code_info_by_scl.get(tuple(scl))
        return sci_loc is not None and bool(sci_loc.leading_detached_comments or sci_loc.leading_comments or sci_loc.trailing_comments)

    def _write_comments(self, scl: SourceCodeLocation) -> bool:
        """Write the comments attached to location ``scl`` as a docstring.

        Order: every leading detached comment (blocks separated by a blank
        line), a blank line, then the leading and trailing comments.

        Returns True if any comments were written.
        """
        if not self._has_comments(scl):
            return False

        sci_loc = self.source_code_info_by_scl.get(tuple(scl))
        assert sci_loc is not None

        lines: List[str] = []
        # Keep every detached comment block, not only the last one.
        for detached_comment in sci_loc.leading_detached_comments:
            detached_block = self._break_text(detached_comment)
            if lines and detached_block:
                lines.append("")
            lines.extend(detached_block)
        leading_lines = self._break_text(sci_loc.leading_comments)
        # Trailing comments also go in the header - to make sure it gets into the docstring
        trailing_lines = self._break_text(sci_loc.trailing_comments)

        if lines and (leading_lines or trailing_lines):
            lines.append("")
        lines.extend(leading_lines)
        lines.extend(trailing_lines)

        # Callers decided on a body from _has_comments(), so always emit a
        # (possibly empty) docstring rather than a lone, unterminated `"""`.
        if not lines:
            lines = [""]

        # Escape backslashes and triple-quotes that would otherwise end the docstring early.
        lines = [line.replace("\\", "\\\\").replace('"""', r"\"\"\"") for line in lines]

        if len(lines) == 1:
            line = lines[0]
            if line.endswith(('"', "\\")):
                # Docstrings are terminated with triple-quotes, so if the documentation itself ends in a quote,
                # insert some whitespace to separate it from the closing quotes.
                # This is not necessary with multiline comments
                # because in that case we always insert a newline before the trailing triple-quotes.
                line = line + " "
            self._write_line(f'"""{line}"""')
        else:
            self._write_line(f'"""{lines[0]}')
            for line in lines[1:]:
                self._write_line(line)
            self._write_line('"""')

        return True

    def write_enum_values(
        self,
        values: Iterable[Tuple[int, d.EnumValueDescriptorProto]],
        value_type: str,
        scl_prefix: SourceCodeLocation,
    ) -> None:
        """Write ``NAME: <value_type>  # <number>`` (plus any docstring) for each
        ``(index, value)`` pair, skipping values named like Python keywords."""
        for i, val in values:
            if val.name in PYTHON_RESERVED:
                continue

            scl = scl_prefix + [i]
            self._write_line(
                f"{val.name}: {value_type}  # {val.number}",
            )
            self._write_comments(scl)

    def write_module_attributes(self) -> None:
        """Write the module-level ``DESCRIPTOR`` attribute."""
        wl = self._write_line
        fd_type = self._import("google.protobuf.descriptor", "FileDescriptor")
        wl(f"DESCRIPTOR: {fd_type}")
        wl("")

    def write_enums(
        self,
        enums: Iterable[d.EnumDescriptorProto],
        prefix: str,
        scl_prefix: SourceCodeLocation,
    ) -> None:
        """Write each enum as a ``_<Name>`` holder for the ``ValueType`` NewType,
        a ``_<Name>EnumTypeWrapper`` metaclass listing the values, and the public
        class combining both, followed by the values at the enclosing scope.

        prefix: qualified name of the enclosing message plus "." ("" at module level).
        scl_prefix: source-code-location path of this enum list.
        """
        wl = self._write_line
        for i, enum in enumerate(enums):
            class_name = enum.name if enum.name not in PYTHON_RESERVED else "_r_" + enum.name
            value_type_fq = prefix + class_name + ".ValueType"
            enum_helper_class = "_" + enum.name
            value_type_helper_fq = prefix + enum_helper_class + ".ValueType"
            etw_helper_class = "_" + enum.name + "EnumTypeWrapper"
            scl = scl_prefix + [i]

            wl(f"class {enum_helper_class}:")
            with self._indent():
                wl(
                    'ValueType = {}("ValueType", {})',
                    self._import("typing", "NewType"),
                    self._builtin("int"),
                )
                # Alias to the classic shorter definition "V"
                wl("V: {} = ValueType", self._import("typing_extensions", "TypeAlias"))
            wl("")
            wl(
                "class {}({}[{}], {}):  # noqa: F821",
                etw_helper_class,
                self._import("google.protobuf.internal.enum_type_wrapper", "_EnumTypeWrapper"),
                value_type_helper_fq,
                self._builtin("type"),
            )
            with self._indent():
                ed = self._import("google.protobuf.descriptor", "EnumDescriptor")
                wl(f"DESCRIPTOR: {ed}")
                # Values clashing with _EnumTypeWrapper's own API are left off the wrapper.
                self.write_enum_values(
                    [(idx, v) for idx, v in enumerate(enum.value) if v.name not in PROTO_ENUM_RESERVED],
                    value_type_helper_fq,
                    scl + [d.EnumDescriptorProto.VALUE_FIELD_NUMBER],
                )
            wl("")

            if self._has_comments(scl):
                wl(f"class {class_name}({enum_helper_class}, metaclass={etw_helper_class}):")
                with self._indent():
                    self._write_comments(scl)
                wl("")
            else:
                wl(f"class {class_name}({enum_helper_class}, metaclass={etw_helper_class}): ...")
                if prefix == "":
                    wl("")

            self.write_enum_values(
                enumerate(enum.value),
                value_type_fq,
                scl + [d.EnumDescriptorProto.VALUE_FIELD_NUMBER],
            )
            if prefix == "" and not self.readable_stubs:
                wl(f"{_mangle_global_identifier(class_name)} = {class_name}")
            wl("")

    def write_messages(
        self,
        messages: Iterable[d.DescriptorProto],
        prefix: str,
        scl_prefix: SourceCodeLocation,
    ) -> None:
        """Write a class per message.

        Each class contains: DESCRIPTOR, nested enums/messages,
        ``<FIELD>_FIELD_NUMBER`` constants, and field attributes (singular
        scalars as plain attributes, everything else as read-only properties).
        It then gets nested extensions, a keyword-only ``__init__`` and the
        stringly-typed helpers.

        prefix: qualified name of the enclosing message plus "." ("" at module level).
        scl_prefix: source-code-location path of this message list.
        """
        wl = self._write_line

        for i, desc in enumerate(messages):
            qualified_name = prefix + desc.name

            # Reproduce some hardcoded logic from the protobuf implementation - where
            # some specific "well_known_types" generated protos to have additional
            # base classes
            addl_base = ""
            if self.fd.package + "." + desc.name in WKTBASES:
                # chop off the .proto - and import the well known type
                # eg `from google.protobuf.duration import Duration`
                well_known_type = WKTBASES[self.fd.package + "." + desc.name]
                addl_base = ", " + self._import(
                    "google.protobuf.internal.well_known_types",
                    well_known_type.__name__,
                )

            class_name = desc.name if desc.name not in PYTHON_RESERVED else "_r_" + desc.name
            message_class = self._import("google.protobuf.message", "Message")
            wl(f"class {class_name}({message_class}{addl_base}):")
            with self._indent():
                scl = scl_prefix + [i]
                if self._write_comments(scl):
                    wl("")

                desc_type = self._import("google.protobuf.descriptor", "Descriptor")
                wl(f"DESCRIPTOR: {desc_type}")
                wl("")

                # Nested enums/messages
                self.write_enums(
                    desc.enum_type,
                    qualified_name + ".",
                    scl + [d.DescriptorProto.ENUM_TYPE_FIELD_NUMBER],
                )
                self.write_messages(
                    desc.nested_type,
                    qualified_name + ".",
                    scl + [d.DescriptorProto.NESTED_TYPE_FIELD_NUMBER],
                )

                # integer constants for field numbers
                for f in desc.field:
                    wl(f"{f.name.upper()}_FIELD_NUMBER: {self._builtin('int')}")

                for idx, field in enumerate(desc.field):
                    if field.name in PYTHON_RESERVED:
                        continue
                    field_type = self.python_type(field)
                    scl_field = scl + [d.DescriptorProto.FIELD_FIELD_NUMBER, idx]

                    if is_scalar(field) and field.label != d.FieldDescriptorProto.LABEL_REPEATED:
                        # Scalar non repeated fields are r/w
                        wl(f"{field.name}: {field_type}")
                        self._write_comments(scl_field)
                    else:
                        # r/o Getters for non-scalar fields and scalar-repeated fields
                        wl("@property")
                        body = " ..." if not self._has_comments(scl_field) else ""
                        wl(f"def {field.name}(self) -> {field_type}:{body}")
                        if self._has_comments(scl_field):
                            with self._indent():
                                self._write_comments(scl_field)

                self.write_extensions(desc.extension, scl + [d.DescriptorProto.EXTENSION_FIELD_NUMBER])

                # Constructor
                wl("def __init__(")
                with self._indent():
                    if any(f.name == "self" for f in desc.field):
                        # A field called `self` takes the name; rename the receiver.
                        wl("# pyright: reportSelfClsParameterName=false")
                        wl("self_,")
                    else:
                        wl("self,")
                    constructor_fields = [f for f in desc.field if f.name not in PYTHON_RESERVED]
                    if len(constructor_fields) > 0:
                        # Only keyword args allowed
                        # See https://github.com/nipunn1313/mypy-protobuf/issues/71
                        wl("*,")
                    for field in constructor_fields:
                        field_type = self.python_type(field, generic_container=True)
                        if self.fd.syntax == "proto3" and is_scalar(field) and field.label != d.FieldDescriptorProto.LABEL_REPEATED and not self.relax_strict_optional_primitives and not field.proto3_optional:
                            wl(f"{field.name}: {field_type} = ...,")
                        else:
                            wl(f"{field.name}: {field_type} | None = ...,")
                wl(") -> None: ...")

                self.write_stringly_typed_fields(desc)

            if prefix == "" and not self.readable_stubs:
                wl("")
                wl(f"{_mangle_global_identifier(class_name)} = {class_name}")
            wl("")

    def write_stringly_typed_fields(self, desc: d.DescriptorProto) -> None:
        """Write ``HasField``/``ClearField``/``WhichOneof`` whose name argument is
        a ``Literal`` of the names valid for ``desc`` (each as ``str`` or ``bytes``)."""
        wl = self._write_line

        # HasField, ClearField, WhichOneof accepts both bytes/str
        # HasField only supports singular. ClearField supports repeated as well
        # In proto3, HasField only supports message fields and optional fields
        # HasField always supports oneof fields
        hf_fields = [f.name for f in desc.field if f.HasField("oneof_index") or (f.label != d.FieldDescriptorProto.LABEL_REPEATED and (self.fd.syntax != "proto3" or f.type == d.FieldDescriptorProto.TYPE_MESSAGE or f.proto3_optional))]
        cf_fields = [f.name for f in desc.field]
        wo_fields = {oneof.name: [f.name for f in desc.field if f.HasField("oneof_index") and f.oneof_index == idx] for idx, oneof in enumerate(desc.oneof_decl)}

        hf_fields.extend(wo_fields.keys())
        cf_fields.extend(wo_fields.keys())

        if not hf_fields and not cf_fields and not wo_fields:
            return

        hf_fields_text = ", ".join(sorted(f'"{name}", b"{name}"' for name in hf_fields))
        cf_fields_text = ", ".join(sorted(f'"{name}", b"{name}"' for name in cf_fields))

        if hf_fields:
            wl(
                "def HasField(self, field_name: {}[{}]) -> {}: ...",
                self._import("typing_extensions", "Literal"),
                hf_fields_text,
                self._builtin("bool"),
            )
        if cf_fields:
            wl(
                "def ClearField(self, field_name: {}[{}]) -> None: ...",
                self._import("typing_extensions", "Literal"),
                cf_fields_text,
            )

        for wo_field, members in sorted(wo_fields.items()):
            if len(wo_fields) > 1:
                wl("@{}", self._import("typing", "overload"))
            wl(
                "def WhichOneof(self, oneof_group: {}[{}]) -> {}[{}] | None: ...",
                self._import("typing_extensions", "Literal"),
                # Accepts both str and bytes
                f'"{wo_field}", b"{wo_field}"',
                self._import("typing_extensions", "Literal"),
                # Returns `str`
                ", ".join(f'"{m}"' for m in members),
            )

    def write_extensions(
        self,
        extensions: Sequence[d.FieldDescriptorProto],
        scl_prefix: SourceCodeLocation,
    ) -> None:
        """Write ``<NAME>_FIELD_NUMBER`` constants, then one
        ``_ExtensionFieldDescriptor[extendee, type]`` attribute per extension."""
        wl = self._write_line

        for ext in extensions:
            wl(f"{ext.name.upper()}_FIELD_NUMBER: {self._builtin('int')}")

        for i, ext in enumerate(extensions):
            scl = scl_prefix + [i]

            wl(
                "{}: {}[{}, {}]",
                ext.name,
                self._import(
                    "google.protobuf.internal.extension_dict",
                    "_ExtensionFieldDescriptor",
                ),
                self._import_message(ext.extendee),
                self.python_type(ext),
            )
            self._write_comments(scl)

    def write_methods(
        self,
        service: d.ServiceDescriptorProto,
        class_name: str,
        is_abstract: bool,
        scl_prefix: SourceCodeLocation,
    ) -> None:
        """Write ``DESCRIPTOR`` and one method per RPC of a generic service.

        With ``is_abstract`` the methods are ``@abstractmethod`` and their
        ``callback`` parameter has no default.
        """
        wl = self._write_line
        wl(
            "DESCRIPTOR: {}",
            self._import("google.protobuf.descriptor", "ServiceDescriptor"),
        )
        methods = [(i, m) for i, m in enumerate(service.method) if m.name not in PYTHON_RESERVED]
        if not methods:
            wl("...")
        for i, method in methods:
            if is_abstract:
                wl("@{}", self._import("abc", "abstractmethod"))
            wl(f"def {method.name}(")
            with self._indent():
                wl(f"inst: {class_name},")
                wl(
                    "rpc_controller: {},",
                    self._import("google.protobuf.service", "RpcController"),
                )
                wl("request: {},", self._import_message(method.input_type))
                wl(
                    "callback: {}[[{}], None] | None{},",
                    self._import("collections.abc", "Callable"),
                    self._import_message(method.output_type),
                    "" if is_abstract else " = ...",
                )

            scl_method = scl_prefix + [d.ServiceDescriptorProto.METHOD_FIELD_NUMBER, i]
            wl(
                ") -> {}[{}]:{}",
                self._import("concurrent.futures", "Future"),
                self._import_message(method.output_type),
                " ..." if not self._has_comments(scl_method) else "",
            )
            if self._has_comments(scl_method):
                with self._indent():
                    if not self._write_comments(scl_method):
                        wl("...")

    def write_services(
        self,
        services: Iterable[d.ServiceDescriptorProto],
        scl_prefix: SourceCodeLocation,
    ) -> None:
        """Write, per service, the abstract ``Service`` class and its
        ``<Name>_Stub`` client subclass (``py_generic_services`` output)."""
        wl = self._write_line
        for i, service in enumerate(services):
            scl = scl_prefix + [i]
            class_name = service.name if service.name not in PYTHON_RESERVED else "_r_" + service.name
            # The service definition interface
            wl(
                "class {}({}, metaclass={}):",
                class_name,
                self._import("google.protobuf.service", "Service"),
                self._import("abc", "ABCMeta"),
            )
            with self._indent():
                if self._write_comments(scl):
                    wl("")
                self.write_methods(service, class_name, is_abstract=True, scl_prefix=scl)
            wl("")

            # The stub client
            stub_class_name = service.name + "_Stub"
            wl("class {}({}):", stub_class_name, class_name)
            with self._indent():
                if self._write_comments(scl):
                    wl("")
                wl(
                    "def __init__(self, rpc_channel: {}) -> None: ...",
                    self._import("google.protobuf.service", "RpcChannel"),
                )
                self.write_methods(service, stub_class_name, is_abstract=False, scl_prefix=scl)
            wl("")

    def _import_casttype(self, casttype: str) -> str:
        """Import a ``path/to/file.TypeInFile`` casttype option value."""
        split = casttype.split(".")
        assert len(split) == 2, "mypy_protobuf.[casttype,keytype,valuetype] is expected to be of format path/to/file.TypeInFile"
        pkg = split[0].replace("/", ".")
        return self._import(pkg, split[1])

    def _map_key_value_types(
        self,
        map_field: d.FieldDescriptorProto,
        key_field: d.FieldDescriptorProto,
        value_field: d.FieldDescriptorProto,
    ) -> Tuple[str, str]:
        """Return the (key, value) types of a map field.

        The ``(mypy_protobuf.options)`` keytype/valuetype override takes
        precedence; the deprecated standalone options are used as a fallback
        (with a warning on stderr).
        """
        options = map_field.options.Extensions[extensions_pb2.options]

        oldstyle_keytype = map_field.options.Extensions[extensions_pb2.keytype]
        if oldstyle_keytype:
            print(f"Warning: Map Field {map_field.name}: (mypy_protobuf.keytype) is deprecated. Prefer (mypy_protobuf.options).keytype", file=sys.stderr)
        key_casttype = options.keytype or oldstyle_keytype
        ktype = self._import_casttype(key_casttype) if key_casttype else self.python_type(key_field)

        oldstyle_valuetype = map_field.options.Extensions[extensions_pb2.valuetype]
        if oldstyle_valuetype:
            print(f"Warning: Map Field {map_field.name}: (mypy_protobuf.valuetype) is deprecated. Prefer (mypy_protobuf.options).valuetype", file=sys.stderr)
        value_casttype = options.valuetype or oldstyle_valuetype
        vtype = self._import_casttype(value_casttype) if value_casttype else self.python_type(value_field)

        return ktype, vtype

    def _callable_type(self, method: d.MethodDescriptorProto) -> str:
        """Return the grpc ``*MultiCallable`` type matching the method's
        client/server streaming flags."""
        if method.client_streaming:
            if method.server_streaming:
                return self._import("grpc", "StreamStreamMultiCallable")
            return self._import("grpc", "StreamUnaryMultiCallable")
        if method.server_streaming:
            return self._import("grpc", "UnaryStreamMultiCallable")
        return self._import("grpc", "UnaryUnaryMultiCallable")

    def _input_type(self, method: d.MethodDescriptorProto, use_stream_iterator: bool = True) -> str:
        """Return the request type, as ``Iterator[...]`` for client-streaming
        methods when ``use_stream_iterator`` is set."""
        result = self._import_message(method.input_type)
        if use_stream_iterator and method.client_streaming:
            result = f"{self._import('collections.abc', 'Iterator')}[{result}]"
        return result

    def _output_type(self, method: d.MethodDescriptorProto, use_stream_iterator: bool = True) -> str:
        """Return the response type, as ``Iterator[...]`` for server-streaming
        methods when ``use_stream_iterator`` is set."""
        result = self._import_message(method.output_type)
        if use_stream_iterator and method.server_streaming:
            result = f"{self._import('collections.abc', 'Iterator')}[{result}]"
        return result

    def write_grpc_methods(self, service: d.ServiceDescriptorProto, scl_prefix: SourceCodeLocation) -> None:
        """Write the abstract servicer methods (one per RPC) of a gRPC service."""
        wl = self._write_line
        methods = [(i, m) for i, m in enumerate(service.method) if m.name not in PYTHON_RESERVED]
        if not methods:
            wl("...")
            wl("")
        for i, method in methods:
            scl = scl_prefix + [d.ServiceDescriptorProto.METHOD_FIELD_NUMBER, i]

            wl("@{}", self._import("abc", "abstractmethod"))
            wl("def {}(", method.name)
            with self._indent():
                wl("self,")
                input_name = "request_iterator" if method.client_streaming else "request"
                input_type = self._input_type(method)
                wl(f"{input_name}: {input_type},")
                wl("context: {},", self._import("grpc", "ServicerContext"))
            wl(
                ") -> {}:{}",
                self._output_type(method),
                " ..." if not self._has_comments(scl) else "",
            )
            if self._has_comments(scl):
                with self._indent():
                    if not self._write_comments(scl):
                        wl("...")

    def write_grpc_stub_methods(self, service: d.ServiceDescriptorProto, scl_prefix: SourceCodeLocation) -> None:
        """Write one ``<MultiCallable>[request, response]`` attribute per RPC of a gRPC stub."""
        wl = self._write_line
        methods = [(i, m) for i, m in enumerate(service.method) if m.name not in PYTHON_RESERVED]
        if not methods:
            wl("...")
            wl("")
        for i, method in methods:
            scl = scl_prefix + [d.ServiceDescriptorProto.METHOD_FIELD_NUMBER, i]

            wl("{}: {}[", method.name, self._callable_type(method))
            with self._indent():
                wl("{},", self._input_type(method, False))
                wl("{},", self._output_type(method, False))
            wl("]")
            self._write_comments(scl)

    def write_grpc_services(
        self,
        services: Iterable[d.ServiceDescriptorProto],
        scl_prefix: SourceCodeLocation,
    ) -> None:
        """Write, per service, the ``<Name>Stub`` client, the abstract
        ``<Name>Servicer`` and the ``add_<Name>Servicer_to_server`` function.
        Services named like Python keywords are skipped."""
        wl = self._write_line
        for i, service in enumerate(services):
            if service.name in PYTHON_RESERVED:
                continue

            scl = scl_prefix + [i]

            # The stub client
            wl(f"class {service.name}Stub:")
            with self._indent():
                if self._write_comments(scl):
                    wl("")
                wl(
                    "def __init__(self, channel: {}) -> None: ...",
                    self._import("grpc", "Channel"),
                )
                self.write_grpc_stub_methods(service, scl)
            wl("")

            # The service definition interface
            wl(
                "class {}Servicer(metaclass={}):",
                service.name,
                self._import("abc", "ABCMeta"),
            )
            with self._indent():
                if self._write_comments(scl):
                    wl("")
                self.write_grpc_methods(service, scl)
            wl("")
            wl(
                "def add_{}Servicer_to_server(servicer: {}Servicer, server: {}) -> None: ...",
                service.name,
                service.name,
                self._import("grpc", "Server"),
            )
            wl("")

    def python_type(self, field: d.FieldDescriptorProto, generic_container: bool = False) -> str:
        """Return the stub type expression for ``field``.

        generic_container
            if set, type the field with generic interfaces. Eg.
            - Iterable[int] rather than RepeatedScalarFieldContainer[int]
            - Mapping[k, v] rather than MessageMap[k, v]
            Can be useful for input types (eg constructor)

        A ``(mypy_protobuf.options).casttype`` (or the deprecated standalone
        casttype option) replaces the computed type entirely.
        """
        oldstyle_casttype = field.options.Extensions[extensions_pb2.casttype]
        if oldstyle_casttype:
            print(f"Warning: Field {field.name}: (mypy_protobuf.casttype) is deprecated. Prefer (mypy_protobuf.options).casttype", file=sys.stderr)
        casttype = field.options.Extensions[extensions_pb2.options].casttype or oldstyle_casttype
        if casttype:
            return self._import_casttype(casttype)

        # Lazily evaluated so only the needed import is recorded.
        mapping: Dict[d.FieldDescriptorProto.Type.V, Callable[[], str]] = {
            d.FieldDescriptorProto.TYPE_DOUBLE: lambda: self._builtin("float"),
            d.FieldDescriptorProto.TYPE_FLOAT: lambda: self._builtin("float"),
            d.FieldDescriptorProto.TYPE_INT64: lambda: self._builtin("int"),
            d.FieldDescriptorProto.TYPE_UINT64: lambda: self._builtin("int"),
            d.FieldDescriptorProto.TYPE_FIXED64: lambda: self._builtin("int"),
            d.FieldDescriptorProto.TYPE_SFIXED64: lambda: self._builtin("int"),
            d.FieldDescriptorProto.TYPE_SINT64: lambda: self._builtin("int"),
            d.FieldDescriptorProto.TYPE_INT32: lambda: self._builtin("int"),
            d.FieldDescriptorProto.TYPE_UINT32: lambda: self._builtin("int"),
            d.FieldDescriptorProto.TYPE_FIXED32: lambda: self._builtin("int"),
            d.FieldDescriptorProto.TYPE_SFIXED32: lambda: self._builtin("int"),
            d.FieldDescriptorProto.TYPE_SINT32: lambda: self._builtin("int"),
            d.FieldDescriptorProto.TYPE_BOOL: lambda: self._builtin("bool"),
            d.FieldDescriptorProto.TYPE_STRING: lambda: self._builtin("str"),
            d.FieldDescriptorProto.TYPE_BYTES: lambda: self._builtin("bytes"),
            d.FieldDescriptorProto.TYPE_ENUM: lambda: self._import_message(field.type_name + ".ValueType"),
            d.FieldDescriptorProto.TYPE_MESSAGE: lambda: self._import_message(field.type_name),
            d.FieldDescriptorProto.TYPE_GROUP: lambda: self._import_message(field.type_name),
        }

        assert field.type in mapping, "Unrecognized type: " + repr(field.type)
        field_type = mapping[field.type]()

        # For non-repeated fields, we're done!
        if field.label != d.FieldDescriptorProto.LABEL_REPEATED:
            return field_type

        # Scalar repeated fields go in RepeatedScalarFieldContainer
        if is_scalar(field):
            container = (
                self._import("collections.abc", "Iterable")
                if generic_container
                else self._import(
                    "google.protobuf.internal.containers",
                    "RepeatedScalarFieldContainer",
                )
            )
            return f"{container}[{field_type}]"

        # non-scalar repeated map fields go in ScalarMap/MessageMap
        msg = self.descriptors.messages[field.type_name]
        if msg.options.map_entry:
            # map generates a special Entry wrapper message
            if generic_container:
                container = self._import("collections.abc", "Mapping")
            elif is_scalar(msg.field[1]):
                container = self._import("google.protobuf.internal.containers", "ScalarMap")
            else:
                container = self._import("google.protobuf.internal.containers", "MessageMap")
            ktype, vtype = self._map_key_value_types(field, msg.field[0], msg.field[1])
            return f"{container}[{ktype}, {vtype}]"

        # non-scalar repeated fields go in RepeatedCompositeFieldContainer
        container = (
            self._import("collections.abc", "Iterable")
            if generic_container
            else self._import(
                "google.protobuf.internal.containers",
                "RepeatedCompositeFieldContainer",
            )
        )
        return f"{container}[{field_type}]"

    def write(self) -> str:
        """Return the complete .pyi content.

        The output starts with the module docstring: HEADER plus the comment
        attached to ``package``, or failing that to ``syntax``. The import
        block follows, then the lines already produced by the ``write_*``
        methods.
        """
        # save current module content, so that imports and module docstring can be inserted
        saved_lines = self.lines
        self.lines = []

        # module docstring may exist as comment before syntax (optional) or package name
        if not self._write_comments([d.FileDescriptorProto.PACKAGE_FIELD_NUMBER]):
            self._write_comments([d.FileDescriptorProto.SYNTAX_FIELD_NUMBER])

        if self.lines:
            assert self.lines[0].startswith('"""')
            self.lines[0] = f'"""{HEADER}{self.lines[0][3:]}'
        else:
            self._write_line(f'"""{HEADER}"""')

        for reexport_idx in self.fd.public_dependency:
            reexport_file = self.fd.dependency[reexport_idx]
            reexport_fd = self.descriptors.files[reexport_file]
            reexport_imp = reexport_file[:-6].replace("-", "_").replace("/", ".") + "_pb2"
            names = [m.name for m in reexport_fd.message_type] + [m.name for m in reexport_fd.enum_type] + [v.name for m in reexport_fd.enum_type for v in m.value] + [m.name for m in reexport_fd.extension]
            if reexport_fd.options.py_generic_services:
                names.extend(m.name for m in reexport_fd.service)
            if names:
                # n,n to force a reexport (from x import y as y)
                self.from_imports[reexport_imp].update((n, n) for n in names)

        if self.typing_extensions_min:
            self.imports.add("sys")
        for pkg in sorted(self.imports):
            self._write_line(f"import {pkg}")
        if self.typing_extensions_min:
            # Use `typing` directly once every requested name is part of it.
            self._write_line("")
            self._write_line(f"if sys.version_info >= {self.typing_extensions_min}:")
            with self._indent():
                self._write_line("import typing as typing_extensions")
            self._write_line("else:")
            with self._indent():
                self._write_line("import typing_extensions")

        for pkg, items in sorted(self.from_imports.items()):
            self._write_line(f"from {pkg} import (")
            with self._indent():
                # A name can be recorded both as a plain import (reexport_name None)
                # and as a public re-export; sort None as "" so None is never
                # compared with a str (which would raise TypeError).
                for name, reexport_name in sorted(items, key=lambda item: (item[0], item[1] or "")):
                    if reexport_name is None:
                        self._write_line(f"{name},")
                    else:
                        self._write_line(f"{name} as {reexport_name},")
            self._write_line(")")
        self._write_line("")

        # restore module content
        self.lines += saved_lines

        content = "\n".join(self.lines)
        if not content.endswith("\n"):
            content = content + "\n"
        return content
def is_scalar(fd: d.FieldDescriptorProto) -> bool:
    """Tell whether a field is scalar-typed.

    Fields of type message or group are treated as composite. Every other
    field type counts as scalar.
    """
    composite_types = (
        d.FieldDescriptorProto.TYPE_MESSAGE,
        d.FieldDescriptorProto.TYPE_GROUP,
    )
    return fd.type not in composite_types
def generate_mypy_stubs(
    descriptors: Descriptors,
    response: plugin_pb2.CodeGeneratorResponse,
    quiet: bool,
    readable_stubs: bool,
    relax_strict_optional_primitives: bool,
) -> None:
    """Add one ``<module>_pb2.pyi`` stub to ``response`` per file to generate.

    For each file, module attributes, enums, messages and extensions are
    written; services are written only when the file sets
    ``py_generic_services``. ``quiet`` is part of the signature but is not
    read in this function.
    """
    file_proto = d.FileDescriptorProto
    for proto_name, file_desc in descriptors.to_generate.items():
        writer = PkgWriter(
            file_desc,
            descriptors,
            readable_stubs,
            relax_strict_optional_primitives,
            grpc=False,
        )
        writer.write_module_attributes()
        writer.write_enums(file_desc.enum_type, "", [file_proto.ENUM_TYPE_FIELD_NUMBER])
        writer.write_messages(file_desc.message_type, "", [file_proto.MESSAGE_TYPE_FIELD_NUMBER])
        writer.write_extensions(file_desc.extension, [file_proto.EXTENSION_FIELD_NUMBER])
        if file_desc.options.py_generic_services:
            writer.write_services(file_desc.service, [file_proto.SERVICE_FIELD_NUMBER])

        # Sanity checks before deriving the output path from the file name.
        assert proto_name == file_desc.name
        assert file_desc.name.endswith(".proto")
        # Strip ".proto"; e.g. "pkg/my-file.proto" -> "pkg/my_file_pb2.pyi".
        stub_stem = file_desc.name[:-6].replace("-", "_").replace(".", "/")
        generated = response.file.add()
        generated.name = f"{stub_stem}_pb2.pyi"
        generated.content = writer.write()
def generate_mypy_grpc_stubs(
    descriptors: Descriptors,
    response: plugin_pb2.CodeGeneratorResponse,
    quiet: bool,
    readable_stubs: bool,
    relax_strict_optional_primitives: bool,
) -> None:
    """Add one ``<module>_pb2_grpc.pyi`` stub to ``response`` per file to generate.

    Only the gRPC service definitions of each file are written. ``quiet`` is
    part of the signature but is not read in this function.
    """
    for proto_name, file_desc in descriptors.to_generate.items():
        writer = PkgWriter(
            file_desc,
            descriptors,
            readable_stubs,
            relax_strict_optional_primitives,
            grpc=True,
        )
        writer.write_grpc_services(
            file_desc.service, [d.FileDescriptorProto.SERVICE_FIELD_NUMBER]
        )

        # Sanity checks before deriving the output path from the file name.
        assert proto_name == file_desc.name
        assert file_desc.name.endswith(".proto")
        # Strip ".proto"; e.g. "pkg/my-file.proto" -> "pkg/my_file_pb2_grpc.pyi".
        stub_stem = file_desc.name[:-6].replace("-", "_").replace(".", "/")
        generated = response.file.add()
        generated.name = f"{stub_stem}_pb2_grpc.pyi"
        generated.content = writer.write()
@contextmanager
def code_generation() -> Iterator[
    Tuple[plugin_pb2.CodeGeneratorRequest, plugin_pb2.CodeGeneratorResponse],
]:
    """Wrap one protoc plugin round-trip around the ``with`` body.

    If the first command-line argument is ``-V`` or ``--version``, the
    version is printed and the process exits with status 0.

    Otherwise:

    * a serialized ``CodeGeneratorRequest`` is read from stdin;
    * it is yielded with a fresh ``CodeGeneratorResponse`` that advertises
      proto3 ``optional`` support;
    * after the body completes, the serialized response is written to stdout.

    Nothing is written if the body raises, because the yield is not guarded
    by ``try``/``finally``.
    """
    argv = sys.argv
    if len(argv) > 1 and argv[1] in ("-V", "--version"):
        print("mypy-protobuf " + __version__)
        sys.exit(0)

    raw_request = sys.stdin.buffer.read()
    request = plugin_pb2.CodeGeneratorRequest()
    request.ParseFromString(raw_request)

    response = plugin_pb2.CodeGeneratorResponse()
    # Tell protoc this plugin understands proto3 `optional` fields.
    response.supported_features |= plugin_pb2.CodeGeneratorResponse.FEATURE_PROTO3_OPTIONAL

    yield request, response

    sys.stdout.buffer.write(response.SerializeToString())
def main() -> None:
    """Plugin entry point that generates ``_pb2.pyi`` stubs.

    Options are read from ``CodeGeneratorRequest.parameter``. They are
    expected as a comma-separated list of bare flags, which is the form
    protoc produces from ``--mypy_out=OPTS:DIR`` and repeated ``--mypy_opt``.

    Recognised flags: ``quiet``, ``readable_stubs``,
    ``relax_strict_optional_primitives``.
    """
    with code_generation() as (request, response):
        # Match whole options instead of substrings of the raw parameter
        # string, so an unrelated option that merely contains a flag's text
        # (e.g. "quietly") does not switch that flag on.
        options = {opt.strip() for opt in request.parameter.split(",")}
        generate_mypy_stubs(
            Descriptors(request),
            response,
            "quiet" in options,
            "readable_stubs" in options,
            "relax_strict_optional_primitives" in options,
        )
def grpc() -> None:
    """Plugin entry point that generates gRPC ``_pb2_grpc.pyi`` stubs.

    Options are read from ``CodeGeneratorRequest.parameter``. They are
    expected as a comma-separated list of bare flags, which is the form
    protoc produces when combining plugin options.

    Recognised flags: ``quiet``, ``readable_stubs``,
    ``relax_strict_optional_primitives``.
    """
    with code_generation() as (request, response):
        # Match whole options instead of substrings of the raw parameter
        # string, so an option that merely contains a flag's text does not
        # switch that flag on.
        options = {opt.strip() for opt in request.parameter.split(",")}
        generate_mypy_grpc_stubs(
            Descriptors(request),
            response,
            "quiet" in options,
            "readable_stubs" in options,
            "relax_strict_optional_primitives" in options,
        )
# When this module is executed directly, run the mypy stub plugin (main).
# grpc() is not invoked from this guard.
if __name__ == "__main__":
    main()
|