12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928929930931932933934935936937938939940941942943944945946947948949950951952953954955956957958959960961962963964965966967968969970971972973974975976977978979980981982983984985986987988989990991992993994995996997998999100010011002100310041005100610071008100910101011101210131014101510161017101810191020102110221023102410251026102710281029103010311032103310341035103610371038103910401041104210431044104510461047104810491050105110521053105410551056105710581059106010611062106310641065106610671068106910701071107210731074107510761077107810791080108110821083108410851086 |
- #!/usr/bin/env python
- """Protoc Plugin to generate mypy stubs. Loosely based on @zbarsky's go implementation"""
- import os
- import sys
- from collections import defaultdict
- from contextlib import contextmanager
- from functools import wraps
- from typing import (
- Any,
- Callable,
- Dict,
- Iterable,
- Iterator,
- List,
- Optional,
- Set,
- Sequence,
- Tuple,
- )
- import google.protobuf.descriptor_pb2 as d
- from google.protobuf.compiler import plugin_pb2 as plugin_pb2
- from google.protobuf.internal.containers import RepeatedCompositeFieldContainer
- from google.protobuf.internal.well_known_types import WKTBASES
- from . import extensions_pb2
# Version string reported by `-V` / `--version`.
__version__ = "3.2.0"

# A SourceCodeLocation is the `path` of a `message Location` entry in
# https://github.com/protocolbuffers/protobuf/blob/master/src/google/protobuf/descriptor.proto
SourceCodeLocation = List[int]

# The marker is assembled from two pieces so that phabricator does not treat
# this source file itself as machine-generated.
GENERATED = "".join(("@ge", "nerated"))

# Module docstring placed at the top of every emitted stub.
HEADER = (
    '"""\n'
    f"{GENERATED} by mypy-protobuf. Do not edit manually!\n"
    "isort:skip_file\n"
    '"""\n'
)
# Every Python keyword. Protobuf names that collide with one of these are
# either prefixed with "_r_" or skipped by the stub writer.
# See https://github.com/dropbox/mypy-protobuf/issues/73 for details
PYTHON_RESERVED = set(
    """
    False None True and as async await assert break class continue def del
    elif else except finally for from global if import in is lambda nonlocal
    not or pass raise return try while with yield
    """.split()
)

# Enum value names left out of the generated `_<Enum>EnumTypeWrapper` class.
# They match method names of protobuf's _EnumTypeWrapper (Name/Value/keys/
# values/items), so declaring them there would shadow those methods.
PROTO_ENUM_RESERVED = set("Name Value keys values items".split())
- def _mangle_global_identifier(name: str) -> str:
- """
- Module level identifiers are mangled and aliased so that they can be disambiguated
- from fields/enum variants with the same name within the file.
- Eg:
- Enum variant `Name` or message field `Name` might conflict with a top level
- message or enum named `Name`, so mangle it with a global___ prefix for
- internal references. Note that this doesn't affect inner enums/messages
- because they get fuly qualified when referenced within a file"""
- return f"global___{name}"
class Descriptors(object):
    """Lookup tables built from every file in a protoc `CodeGeneratorRequest`.

    Attributes:
        files: every `FileDescriptorProto` in the request, keyed by file name.
        to_generate: the entries of `files` named in `file_to_generate`.
        messages: every message (nested ones included) keyed by its fully
            qualified name with a leading dot, e.g. `.pkg.Outer.Inner`.
        message_to_fd: maps each message name, each enum name and each
            `<enum>.ValueType` name (same dotted form) to its defining file.
    """

    def __init__(self, request: plugin_pb2.CodeGeneratorRequest) -> None:
        by_name = {proto_file.name: proto_file for proto_file in request.proto_file}
        self.files: Dict[str, d.FileDescriptorProto] = by_name
        self.to_generate: Dict[str, d.FileDescriptorProto] = {
            wanted: by_name[wanted] for wanted in request.file_to_generate
        }
        self.messages: Dict[str, d.DescriptorProto] = {}
        self.message_to_fd: Dict[str, d.FileDescriptorProto] = {}

        def index_scope(
            scope_messages: "RepeatedCompositeFieldContainer[d.DescriptorProto]",
            scope_enums: "RepeatedCompositeFieldContainer[d.EnumDescriptorProto]",
            scope: str,
            owner: d.FileDescriptorProto,
        ) -> None:
            # Depth-first: each message, then everything nested inside it.
            for message in scope_messages:
                qualified = scope + message.name
                self.messages[qualified] = message
                self.message_to_fd[qualified] = owner
                index_scope(
                    message.nested_type, message.enum_type, qualified + ".", owner
                )
            # This scope's enums come after all of its messages. The
            # ".ValueType" alias is registered so enum-typed fields resolve too.
            for enum in scope_enums:
                qualified = scope + enum.name
                self.message_to_fd[qualified] = owner
                self.message_to_fd[qualified + ".ValueType"] = owner

        for proto_file in request.proto_file:
            root = f".{proto_file.package}." if proto_file.package else "."
            index_scope(proto_file.message_type, proto_file.enum_type, root, proto_file)
class PkgWriter(object):
    """Writes a single pyi file.

    An instance renders the stub for one `FileDescriptorProto`. The
    `write_*` methods append lines to `self.lines` and record each name
    they reference in `self.imports` / `self.from_imports`. `write()`
    then prepends the matching import section and returns the file text.
    """

    # One indentation level in the emitted stub. `_indent` adds and removes
    # exactly this string, so nested blocks always unwind to where they began.
    _INDENT_UNIT = "    "

    def __init__(
        self,
        fd: d.FileDescriptorProto,
        descriptors: Descriptors,
        readable_stubs: bool,
        relax_strict_optional_primitives: bool,
        grpc: bool,
    ) -> None:
        self.fd = fd
        self.descriptors = descriptors
        # True: emit `from x import Y` and refer to `Y`;
        # False: emit `import x` and refer to `x.Y`.
        self.readable_stubs = readable_stubs
        self.relax_strict_optional_primitives = relax_strict_optional_primitives
        # True when rendering a gRPC stub. Those live in a separate module, so
        # even messages from this .proto must be imported.
        self.grpc = grpc
        self.lines: List[str] = []
        self.indent = ""

        # Set of {x}, where {x} corresponds to `import {x}`
        self.imports: Set[str] = set()

        # dictionary of x->(y,z) for `from {x} import {y} as {z}`
        # if {z} is None, then it shortens to `from {x} import {y}`
        self.from_imports: Dict[str, Set[Tuple[str, Optional[str]]]] = defaultdict(set)

        # Source locations (which carry the comments) keyed by their path.
        self.source_code_info_by_scl = {
            tuple(location.path): location for location in fd.source_code_info.location
        }

    def _import(self, path: str, name: str) -> str:
        """Record an import of `name` from module `path` and return a handle to it.

        `path` may use "/" or "." as separator.
        eg. self._import("typing", "Optional") -> "Optional" (readable stubs)
            or "typing.Optional" (otherwise)
        """
        imp = path.replace("/", ".")
        if self.readable_stubs:
            self.from_imports[imp].add((name, None))
            return name
        else:
            self.imports.add(imp)
            return imp + "." + name

    def _import_message(self, name: str) -> str:
        """Return a handle for the message/enum with fully qualified proto `name`.

        `name` has a leading dot (e.g. ".pkg.Outer.Inner"). Types defined in
        the file being rendered are referenced directly. Others are imported
        from their `_pb2` module.
        """
        message_fd = self.descriptors.message_to_fd[name]
        assert message_fd.name.endswith(".proto")

        # Strip off package name
        if message_fd.package:
            assert name.startswith("." + message_fd.package + ".")
            name = name[len("." + message_fd.package + ".") :]
        else:
            assert name.startswith(".")
            name = name[1:]

        # Use prepended "_r_" to disambiguate message names that alias python reserved keywords
        split = name.split(".")
        for i, part in enumerate(split):
            if part in PYTHON_RESERVED:
                split[i] = "_r_" + part
        name = ".".join(split)

        # Message defined in this file. Note: GRPC stubs in same .proto are generated into separate files
        if not self.grpc and message_fd.name == self.fd.name:
            return name if self.readable_stubs else _mangle_global_identifier(name)

        # Not in file. Must import
        # Python generated code ignores proto packages, so the only relevant factor is
        # whether it is in the file or not.
        import_name = self._import(
            message_fd.name[:-6].replace("-", "_") + "_pb2", split[0]
        )

        remains = ".".join(split[1:])
        if not remains:
            return import_name

        # remains could either be a direct import of a nested enum or message
        # from another package.
        return import_name + "." + remains

    def _builtin(self, name: str) -> str:
        """Return a handle for builtin `name` (e.g. `builtins.int`)."""
        return self._import("builtins", name)

    @contextmanager
    def _indent(self) -> Iterator[None]:
        """Indent every line written inside the `with` body by one level.

        The previous indentation is restored on exit, even if the body raises.
        """
        previous = self.indent
        self.indent = previous + self._INDENT_UNIT
        try:
            yield
        finally:
            self.indent = previous

    def _write_line(self, line: str, *args: Any) -> None:
        """Append `line` at the current indentation.

        `str.format(*args)` is applied only when `args` are given, so lines
        written without args may contain literal braces. Empty lines are
        written without indentation (no trailing whitespace).
        """
        if args:
            line = line.format(*args)
        if line == "":
            self.lines.append(line)
        else:
            self.lines.append(self.indent + line)

    def _break_text(self, text_block: str) -> List[str]:
        """Split a comment into lines, dropping one leading space per line."""
        if text_block == "":
            return []
        return [
            line[1:] if line.startswith(" ") else line
            for line in text_block.rstrip().split("\n")
        ]

    def _has_comments(self, scl: SourceCodeLocation) -> bool:
        """Return True if any comment is attached to source location `scl`."""
        sci_loc = self.source_code_info_by_scl.get(tuple(scl))
        return sci_loc is not None and bool(
            sci_loc.leading_detached_comments
            or sci_loc.leading_comments
            or sci_loc.trailing_comments
        )

    def _write_comments(self, scl: SourceCodeLocation) -> bool:
        """Emit the comments attached to `scl` as a docstring.

        Leading-detached comments (each followed by a blank line), leading
        comments and trailing comments are folded into one docstring.

        Returns:
            True if any comments were written.
        """
        if not self._has_comments(scl):
            return False

        sci_loc = self.source_code_info_by_scl.get(tuple(scl))
        assert sci_loc is not None

        lines = []
        for leading_detached_comment in sci_loc.leading_detached_comments:
            lines.extend(self._break_text(leading_detached_comment))
            lines.append("")
        # Proto string fields read as "" when unset (never None), and
        # _break_text("") yields no lines, so no guard is needed.
        lines.extend(self._break_text(sci_loc.leading_comments))
        # Trailing comments also go in the header - to make sure it gets into the docstring
        lines.extend(self._break_text(sci_loc.trailing_comments))

        lines = [
            # Escape backslashes, then triple-quotes that would otherwise end the docstring early.
            line.replace("\\", "\\\\").replace('"""', r"\"\"\"")
            for line in lines
        ]

        if len(lines) == 1:
            line = lines[0]
            if line.endswith(('"', "\\")):
                # Docstrings are terminated with triple-quotes, so if the documentation itself ends in a quote,
                # insert some whitespace to separate it from the closing quotes.
                # This is not necessary with multiline comments
                # because in that case we always insert a newline before the trailing triple-quotes.
                line = line + " "
            self._write_line(f'"""{line}"""')
        else:
            for i, line in enumerate(lines):
                if i == 0:
                    self._write_line(f'"""{line}')
                else:
                    self._write_line(f"{line}")
            self._write_line('"""')

        return True

    def write_enum_values(
        self,
        values: Iterable[Tuple[int, d.EnumValueDescriptorProto]],
        value_type: str,
        scl_prefix: SourceCodeLocation,
    ) -> None:
        """Emit `NAME: value_type  # number` for each (index, value) pair.

        Values whose name is a Python keyword are skipped.
        """
        for i, val in values:
            if val.name in PYTHON_RESERVED:
                continue

            scl = scl_prefix + [i]
            self._write_line(
                f"{val.name}: {value_type} # {val.number}",
            )
            if self._write_comments(scl):
                self._write_line("")  # Extra newline to separate

    def write_module_attributes(self) -> None:
        """Emit the module-level `DESCRIPTOR` attribute."""
        l = self._write_line
        fd_type = self._import("google.protobuf.descriptor", "FileDescriptor")
        l(f"DESCRIPTOR: {fd_type}")
        l("")

    def write_enums(
        self,
        enums: Iterable[d.EnumDescriptorProto],
        prefix: str,
        scl_prefix: SourceCodeLocation,
    ) -> None:
        """Emit the helper, type-wrapper and public classes for each enum.

        `prefix` is the dotted name of the enclosing message ("" at module
        level). Each enum's values are also emitted as attributes of the
        enclosing scope.
        """
        l = self._write_line
        for i, enum in enumerate(enums):
            class_name = (
                enum.name if enum.name not in PYTHON_RESERVED else "_r_" + enum.name
            )
            value_type_fq = prefix + class_name + ".ValueType"
            enum_helper_class = "_" + enum.name
            value_type_helper_fq = prefix + enum_helper_class + ".ValueType"
            etw_helper_class = "_" + enum.name + "EnumTypeWrapper"
            scl = scl_prefix + [i]

            l(f"class {enum_helper_class}:")
            with self._indent():
                l(
                    "ValueType = {}('ValueType', {})",
                    self._import("typing", "NewType"),
                    self._builtin("int"),
                )
                # Alias to the classic shorter definition "V"
                l("V: {} = ValueType", self._import("typing_extensions", "TypeAlias"))
            l(
                "class {}({}[{}], {}):",
                etw_helper_class,
                self._import(
                    "google.protobuf.internal.enum_type_wrapper", "_EnumTypeWrapper"
                ),
                value_type_helper_fq,
                self._builtin("type"),
            )
            with self._indent():
                ed = self._import("google.protobuf.descriptor", "EnumDescriptor")
                l(f"DESCRIPTOR: {ed}")
                self.write_enum_values(
                    [
                        (i, v)
                        for i, v in enumerate(enum.value)
                        if v.name not in PROTO_ENUM_RESERVED
                    ],
                    value_type_helper_fq,
                    scl + [d.EnumDescriptorProto.VALUE_FIELD_NUMBER],
                )
            l(f"class {class_name}({enum_helper_class}, metaclass={etw_helper_class}):")
            with self._indent():
                self._write_comments(scl)
                l("pass")
            l("")

            self.write_enum_values(
                enumerate(enum.value),
                value_type_fq,
                scl + [d.EnumDescriptorProto.VALUE_FIELD_NUMBER],
            )
            if prefix == "" and not self.readable_stubs:
                l(f"{_mangle_global_identifier(class_name)} = {class_name}")
                l("")
            l("")

    def write_messages(
        self,
        messages: Iterable[d.DescriptorProto],
        prefix: str,
        scl_prefix: SourceCodeLocation,
    ) -> None:
        """Emit one `Message` subclass per descriptor, recursing into nested types.

        Each class gets DESCRIPTOR, nested enums/messages, *_FIELD_NUMBER
        constants, field attributes/properties, extensions, a keyword-only
        `__init__`, and typed HasField/ClearField/WhichOneof.
        """
        l = self._write_line

        for i, desc in enumerate(messages):
            qualified_name = prefix + desc.name

            # Reproduce some hardcoded logic from the protobuf implementation - where
            # some specific "well_known_types" generated protos get additional
            # base classes
            addl_base = ""
            if self.fd.package + "." + desc.name in WKTBASES:
                # chop off the .proto - and import the well known type
                # eg `from google.protobuf.duration import Duration`
                well_known_type = WKTBASES[self.fd.package + "." + desc.name]
                addl_base = ", " + self._import(
                    "google.protobuf.internal.well_known_types",
                    well_known_type.__name__,
                )

            class_name = (
                desc.name if desc.name not in PYTHON_RESERVED else "_r_" + desc.name
            )
            message_class = self._import("google.protobuf.message", "Message")
            l(f"class {class_name}({message_class}{addl_base}):")
            with self._indent():
                scl = scl_prefix + [i]
                self._write_comments(scl)

                desc_type = self._import("google.protobuf.descriptor", "Descriptor")
                l(f"DESCRIPTOR: {desc_type}")

                # Nested enums/messages
                self.write_enums(
                    desc.enum_type,
                    qualified_name + ".",
                    scl + [d.DescriptorProto.ENUM_TYPE_FIELD_NUMBER],
                )
                self.write_messages(
                    desc.nested_type,
                    qualified_name + ".",
                    scl + [d.DescriptorProto.NESTED_TYPE_FIELD_NUMBER],
                )

                # integer constants for field numbers
                for f in desc.field:
                    l(f"{f.name.upper()}_FIELD_NUMBER: {self._builtin('int')}")

                for idx, field in enumerate(desc.field):
                    if field.name in PYTHON_RESERVED:
                        continue

                    field_type = self.python_type(field)

                    if (
                        is_scalar(field)
                        and field.label != d.FieldDescriptorProto.LABEL_REPEATED
                    ):
                        # Scalar non repeated fields are r/w
                        l(f"{field.name}: {field_type}")
                        if self._write_comments(
                            scl + [d.DescriptorProto.FIELD_FIELD_NUMBER, idx]
                        ):
                            l("")
                    else:
                        # r/o Getters for non-scalar fields and scalar-repeated fields
                        scl_field = scl + [d.DescriptorProto.FIELD_FIELD_NUMBER, idx]
                        l("@property")
                        body = " ..." if not self._has_comments(scl_field) else ""
                        l(f"def {field.name}(self) -> {field_type}:{body}")
                        if self._has_comments(scl_field):
                            with self._indent():
                                self._write_comments(scl_field)
                                l("pass")

                self.write_extensions(
                    desc.extension, scl + [d.DescriptorProto.EXTENSION_FIELD_NUMBER]
                )

                # Constructor
                if any(f.name == "self" for f in desc.field):
                    # A field named "self" would clash with the receiver.
                    l("# pyright: reportSelfClsParameterName=false")
                    l("def __init__(self_,")
                else:
                    l("def __init__(self,")
                with self._indent():
                    constructor_fields = [
                        f for f in desc.field if f.name not in PYTHON_RESERVED
                    ]
                    if len(constructor_fields) > 0:
                        # Only keyword args allowed (everything after the bare `*`)
                        # See https://github.com/dropbox/mypy-protobuf/issues/71
                        l("*,")
                    for field in constructor_fields:
                        field_type = self.python_type(field, generic_container=True)
                        if (
                            self.fd.syntax == "proto3"
                            and is_scalar(field)
                            and field.label != d.FieldDescriptorProto.LABEL_REPEATED
                            and not self.relax_strict_optional_primitives
                            and not field.proto3_optional
                        ):
                            l(f"{field.name}: {field_type} = ...,")
                        else:
                            opt = self._import("typing", "Optional")
                            l(f"{field.name}: {opt}[{field_type}] = ...,")
                    l(") -> None: ...")

                self.write_stringly_typed_fields(desc)

            if prefix == "" and not self.readable_stubs:
                l(f"{_mangle_global_identifier(class_name)} = {class_name}")
            l("")

    def write_stringly_typed_fields(self, desc: d.DescriptorProto) -> None:
        """Type the stringly-typed methods as a Union[Literal, Literal ...]"""
        l = self._write_line
        # HasField, ClearField, WhichOneof accept both bytes/str
        # HasField only supports singular. ClearField supports repeated as well
        # In proto3, HasField only supports message fields and optional fields
        # HasField always supports oneof fields
        hf_fields = [
            f.name
            for f in desc.field
            if f.HasField("oneof_index")
            or (
                f.label != d.FieldDescriptorProto.LABEL_REPEATED
                and (
                    self.fd.syntax != "proto3"
                    or f.type == d.FieldDescriptorProto.TYPE_MESSAGE
                    or f.proto3_optional
                )
            )
        ]
        cf_fields = [f.name for f in desc.field]
        wo_fields = {
            oneof.name: [
                f.name
                for f in desc.field
                if f.HasField("oneof_index") and f.oneof_index == idx
            ]
            for idx, oneof in enumerate(desc.oneof_decl)
        }

        hf_fields.extend(wo_fields.keys())
        cf_fields.extend(wo_fields.keys())

        hf_fields_text = ",".join(sorted(f'"{name}",b"{name}"' for name in hf_fields))
        cf_fields_text = ",".join(sorted(f'"{name}",b"{name}"' for name in cf_fields))

        if not hf_fields and not cf_fields and not wo_fields:
            return

        if hf_fields:
            l(
                "def HasField(self, field_name: {}[{}]) -> {}: ...",
                self._import("typing_extensions", "Literal"),
                hf_fields_text,
                self._builtin("bool"),
            )
        if cf_fields:
            l(
                "def ClearField(self, field_name: {}[{}]) -> None: ...",
                self._import("typing_extensions", "Literal"),
                cf_fields_text,
            )

        for wo_field, members in sorted(wo_fields.items()):
            if len(wo_fields) > 1:
                l("@{}", self._import("typing", "overload"))
            l(
                "def WhichOneof(self, oneof_group: {}[{}]) -> {}[{}[{}]]: ...",
                self._import("typing_extensions", "Literal"),
                # Accepts both str and bytes
                f'"{wo_field}",b"{wo_field}"',
                self._import("typing", "Optional"),
                self._import("typing_extensions", "Literal"),
                # Returns `str`
                ",".join(f'"{m}"' for m in members),
            )

    def write_extensions(
        self,
        extensions: Sequence[d.FieldDescriptorProto],
        scl_prefix: SourceCodeLocation,
    ) -> None:
        """Emit *_FIELD_NUMBER constants and `_ExtensionFieldDescriptor` attributes."""
        l = self._write_line

        for ext in extensions:
            l(f"{ext.name.upper()}_FIELD_NUMBER: {self._builtin('int')}")

        for i, ext in enumerate(extensions):
            # A keyword-named extension cannot be declared as a stub attribute
            # (it would be a syntax error); skip it like keyword-named fields.
            # The original index `i` is kept so comment lookup stays correct.
            if ext.name in PYTHON_RESERVED:
                continue
            scl = scl_prefix + [i]

            l(
                "{}: {}[{}, {}]",
                ext.name,
                self._import(
                    "google.protobuf.internal.extension_dict",
                    "_ExtensionFieldDescriptor",
                ),
                self._import_message(ext.extendee),
                self.python_type(ext),
            )
            self._write_comments(scl)
            l("")

    def write_methods(
        self,
        service: d.ServiceDescriptorProto,
        class_name: str,
        is_abstract: bool,
        scl_prefix: SourceCodeLocation,
    ) -> None:
        """Emit the generic-service (`py_generic_services`) RPC method stubs."""
        l = self._write_line
        l(
            "DESCRIPTOR: {}",
            self._import("google.protobuf.descriptor", "ServiceDescriptor"),
        )
        methods = [
            (i, m)
            for i, m in enumerate(service.method)
            if m.name not in PYTHON_RESERVED
        ]
        if not methods:
            l("pass")
        for i, method in methods:
            if is_abstract:
                l("@{}", self._import("abc", "abstractmethod"))
            l(f"def {method.name}(")
            with self._indent():
                l(f"inst: {class_name},")
                l(
                    "rpc_controller: {},",
                    self._import("google.protobuf.service", "RpcController"),
                )
                l("request: {},", self._import_message(method.input_type))
                l(
                    "callback: {}[{}[[{}], None]]{},",
                    self._import("typing", "Optional"),
                    self._import("typing", "Callable"),
                    self._import_message(method.output_type),
                    "" if is_abstract else " = None",
                )

            scl_method = scl_prefix + [d.ServiceDescriptorProto.METHOD_FIELD_NUMBER, i]
            l(
                ") -> {}[{}]:{}",
                self._import("concurrent.futures", "Future"),
                self._import_message(method.output_type),
                " ..." if not self._has_comments(scl_method) else "",
            )
            if self._has_comments(scl_method):
                with self._indent():
                    self._write_comments(scl_method)
                    l("pass")

    def write_services(
        self,
        services: Iterable[d.ServiceDescriptorProto],
        scl_prefix: SourceCodeLocation,
    ) -> None:
        """Emit the abstract service class and its `_Stub` client per service."""
        l = self._write_line
        for i, service in enumerate(services):
            scl = scl_prefix + [i]
            class_name = (
                service.name
                if service.name not in PYTHON_RESERVED
                else "_r_" + service.name
            )
            # The service definition interface
            l(
                "class {}({}, metaclass={}):",
                class_name,
                self._import("google.protobuf.service", "Service"),
                self._import("abc", "ABCMeta"),
            )
            with self._indent():
                self._write_comments(scl)
                self.write_methods(
                    service, class_name, is_abstract=True, scl_prefix=scl
                )

            # The stub client
            stub_class_name = service.name + "_Stub"
            l("class {}({}):", stub_class_name, class_name)
            with self._indent():
                self._write_comments(scl)
                l(
                    "def __init__(self, rpc_channel: {}) -> None: ...",
                    self._import("google.protobuf.service", "RpcChannel"),
                )
                self.write_methods(
                    service, stub_class_name, is_abstract=False, scl_prefix=scl
                )

    def _import_casttype(self, casttype: str) -> str:
        """Import a `path/to/file.TypeInFile` casttype option and return a handle."""
        split = casttype.split(".")
        assert (
            len(split) == 2
        ), "mypy_protobuf.[casttype,keytype,valuetype] is expected to be of format path/to/file.TypeInFile"
        pkg = split[0].replace("/", ".")
        return self._import(pkg, split[1])

    def _map_key_value_types(
        self,
        map_field: d.FieldDescriptorProto,
        key_field: d.FieldDescriptorProto,
        value_field: d.FieldDescriptorProto,
    ) -> Tuple[str, str]:
        """Return (key type, value type) for a map field, honoring keytype/valuetype options."""
        key_casttype = map_field.options.Extensions[extensions_pb2.keytype]
        ktype = (
            self._import_casttype(key_casttype)
            if key_casttype
            else self.python_type(key_field)
        )
        value_casttype = map_field.options.Extensions[extensions_pb2.valuetype]
        vtype = (
            self._import_casttype(value_casttype)
            if value_casttype
            else self.python_type(value_field)
        )
        return ktype, vtype

    def _callable_type(self, method: d.MethodDescriptorProto) -> str:
        """Return the grpc *MultiCallable type matching the method's streaming mode."""
        if method.client_streaming:
            if method.server_streaming:
                return self._import("grpc", "StreamStreamMultiCallable")
            else:
                return self._import("grpc", "StreamUnaryMultiCallable")
        else:
            if method.server_streaming:
                return self._import("grpc", "UnaryStreamMultiCallable")
            else:
                return self._import("grpc", "UnaryUnaryMultiCallable")

    def _input_type(
        self, method: d.MethodDescriptorProto, use_stream_iterator: bool = True
    ) -> str:
        """Return the request type, wrapped in Iterator[...] for client streaming."""
        result = self._import_message(method.input_type)
        if use_stream_iterator and method.client_streaming:
            result = f"{self._import('typing', 'Iterator')}[{result}]"
        return result

    def _output_type(
        self, method: d.MethodDescriptorProto, use_stream_iterator: bool = True
    ) -> str:
        """Return the response type, wrapped in Iterator[...] for server streaming."""
        result = self._import_message(method.output_type)
        if use_stream_iterator and method.server_streaming:
            result = f"{self._import('typing', 'Iterator')}[{result}]"
        return result

    def write_grpc_methods(
        self, service: d.ServiceDescriptorProto, scl_prefix: SourceCodeLocation
    ) -> None:
        """Emit the abstract servicer methods of a gRPC service."""
        l = self._write_line
        methods = [
            (i, m)
            for i, m in enumerate(service.method)
            if m.name not in PYTHON_RESERVED
        ]
        if not methods:
            l("pass")
            l("")
        for i, method in methods:
            scl = scl_prefix + [d.ServiceDescriptorProto.METHOD_FIELD_NUMBER, i]

            l("@{}", self._import("abc", "abstractmethod"))
            l("def {}(self,", method.name)
            with self._indent():
                input_name = (
                    "request_iterator" if method.client_streaming else "request"
                )
                input_type = self._input_type(method)
                l(f"{input_name}: {input_type},")
                l("context: {},", self._import("grpc", "ServicerContext"))
            l(
                ") -> {}:{}",
                self._output_type(method),
                " ..." if not self._has_comments(scl) else "",
            )
            if self._has_comments(scl):
                with self._indent():
                    self._write_comments(scl)
                    l("pass")
            l("")

    def write_grpc_stub_methods(
        self, service: d.ServiceDescriptorProto, scl_prefix: SourceCodeLocation
    ) -> None:
        """Emit the typed *MultiCallable attributes of a gRPC client stub."""
        l = self._write_line
        methods = [
            (i, m)
            for i, m in enumerate(service.method)
            if m.name not in PYTHON_RESERVED
        ]
        if not methods:
            l("pass")
            l("")
        for i, method in methods:
            scl = scl_prefix + [d.ServiceDescriptorProto.METHOD_FIELD_NUMBER, i]

            l("{}: {}[", method.name, self._callable_type(method))
            with self._indent():
                l("{},", self._input_type(method, False))
                l("{}]", self._output_type(method, False))
            self._write_comments(scl)
            l("")

    def write_grpc_services(
        self,
        services: Iterable[d.ServiceDescriptorProto],
        scl_prefix: SourceCodeLocation,
    ) -> None:
        """Emit `<Svc>Stub`, `<Svc>Servicer` and `add_<Svc>Servicer_to_server` per service."""
        l = self._write_line
        for i, service in enumerate(services):
            if service.name in PYTHON_RESERVED:
                continue

            scl = scl_prefix + [i]

            # The stub client
            l(f"class {service.name}Stub:")
            with self._indent():
                self._write_comments(scl)
                l(
                    "def __init__(self, channel: {}) -> None: ...",
                    self._import("grpc", "Channel"),
                )
                self.write_grpc_stub_methods(service, scl)
            l("")

            # The service definition interface
            l(
                "class {}Servicer(metaclass={}):",
                service.name,
                self._import("abc", "ABCMeta"),
            )
            with self._indent():
                self._write_comments(scl)
                self.write_grpc_methods(service, scl)
            l("")
            l(
                "def add_{}Servicer_to_server(servicer: {}Servicer, server: {}) -> None: ...",
                service.name,
                service.name,
                self._import("grpc", "Server"),
            )
            l("")

    def python_type(
        self, field: d.FieldDescriptorProto, generic_container: bool = False
    ) -> str:
        """Return the Python type expression for `field`.

        generic_container
            if set, type the field with generic interfaces. Eg.
            - Iterable[int] rather than RepeatedScalarFieldContainer[int]
            - Mapping[k, v] rather than MessageMap[k, v]
            Can be useful for input types (eg constructor)
        """
        casttype = field.options.Extensions[extensions_pb2.casttype]
        if casttype:
            return self._import_casttype(casttype)

        mapping: Dict[d.FieldDescriptorProto.Type.V, Callable[[], str]] = {
            d.FieldDescriptorProto.TYPE_DOUBLE: lambda: self._builtin("float"),
            d.FieldDescriptorProto.TYPE_FLOAT: lambda: self._builtin("float"),
            d.FieldDescriptorProto.TYPE_INT64: lambda: self._builtin("int"),
            d.FieldDescriptorProto.TYPE_UINT64: lambda: self._builtin("int"),
            d.FieldDescriptorProto.TYPE_FIXED64: lambda: self._builtin("int"),
            d.FieldDescriptorProto.TYPE_SFIXED64: lambda: self._builtin("int"),
            d.FieldDescriptorProto.TYPE_SINT64: lambda: self._builtin("int"),
            d.FieldDescriptorProto.TYPE_INT32: lambda: self._builtin("int"),
            d.FieldDescriptorProto.TYPE_UINT32: lambda: self._builtin("int"),
            d.FieldDescriptorProto.TYPE_FIXED32: lambda: self._builtin("int"),
            d.FieldDescriptorProto.TYPE_SFIXED32: lambda: self._builtin("int"),
            d.FieldDescriptorProto.TYPE_SINT32: lambda: self._builtin("int"),
            d.FieldDescriptorProto.TYPE_BOOL: lambda: self._builtin("bool"),
            d.FieldDescriptorProto.TYPE_STRING: lambda: self._import("typing", "Text"),
            d.FieldDescriptorProto.TYPE_BYTES: lambda: self._builtin("bytes"),
            d.FieldDescriptorProto.TYPE_ENUM: lambda: self._import_message(
                field.type_name + ".ValueType"
            ),
            d.FieldDescriptorProto.TYPE_MESSAGE: lambda: self._import_message(
                field.type_name
            ),
            d.FieldDescriptorProto.TYPE_GROUP: lambda: self._import_message(
                field.type_name
            ),
        }

        assert field.type in mapping, "Unrecognized type: " + repr(field.type)
        field_type = mapping[field.type]()

        # For non-repeated fields, we're done!
        if field.label != d.FieldDescriptorProto.LABEL_REPEATED:
            return field_type

        # Scalar repeated fields go in RepeatedScalarFieldContainer
        if is_scalar(field):
            container = (
                self._import("typing", "Iterable")
                if generic_container
                else self._import(
                    "google.protobuf.internal.containers",
                    "RepeatedScalarFieldContainer",
                )
            )
            return f"{container}[{field_type}]"

        # non-scalar repeated map fields go in ScalarMap/MessageMap
        msg = self.descriptors.messages[field.type_name]
        if msg.options.map_entry:
            # map generates a special Entry wrapper message
            if generic_container:
                container = self._import("typing", "Mapping")
            elif is_scalar(msg.field[1]):
                container = self._import(
                    "google.protobuf.internal.containers", "ScalarMap"
                )
            else:
                container = self._import(
                    "google.protobuf.internal.containers", "MessageMap"
                )
            ktype, vtype = self._map_key_value_types(field, msg.field[0], msg.field[1])
            return f"{container}[{ktype}, {vtype}]"

        # non-scalar repeated fields go in RepeatedCompositeFieldContainer
        container = (
            self._import("typing", "Iterable")
            if generic_container
            else self._import(
                "google.protobuf.internal.containers",
                "RepeatedCompositeFieldContainer",
            )
        )
        return f"{container}[{field_type}]"

    def write(self) -> str:
        """Return the complete stub: import section followed by the written lines.

        Names from `public` dependencies are force-re-exported
        (`from x import y as y`).
        """
        for reexport_idx in self.fd.public_dependency:
            reexport_file = self.fd.dependency[reexport_idx]
            reexport_fd = self.descriptors.files[reexport_file]
            reexport_imp = (
                reexport_file[:-6].replace("-", "_").replace("/", ".") + "_pb2"
            )
            names = (
                [m.name for m in reexport_fd.message_type]
                + [m.name for m in reexport_fd.enum_type]
                + [v.name for m in reexport_fd.enum_type for v in m.value]
                + [m.name for m in reexport_fd.extension]
            )
            if reexport_fd.options.py_generic_services:
                names.extend(m.name for m in reexport_fd.service)

            if names:
                # n,n to force a reexport (from x import y as y)
                self.from_imports[reexport_imp].update((n, n) for n in names)

        import_lines = [f"import {pkg}" for pkg in sorted(self.imports)]

        for pkg, items in sorted(self.from_imports.items()):
            import_lines.append(f"from {pkg} import (")
            # The same name can be recorded both as a plain import (alias None)
            # and as a forced re-export (alias == name). Plain tuple sorting
            # would then compare None with str and raise TypeError, so map
            # None to "" for ordering purposes only.
            for name, reexport_name in sorted(
                items, key=lambda item: (item[0], item[1] or "")
            ):
                if reexport_name is None:
                    import_lines.append(f" {name},")
                else:
                    import_lines.append(f" {name} as {reexport_name},")
            import_lines.append(")\n")
        import_lines.append("")

        return "\n".join(import_lines + self.lines)
def is_scalar(fd: d.FieldDescriptorProto) -> bool:
    """Return True unless the field holds a composite (message or group) value."""
    composite_types = (
        d.FieldDescriptorProto.TYPE_MESSAGE,
        d.FieldDescriptorProto.TYPE_GROUP,
    )
    return fd.type not in composite_types
def _stub_output_name(fd: d.FileDescriptorProto, suffix: str) -> str:
    """Return the response path for the stub generated from `fd`.

    The `.proto` extension is dropped, "-" becomes "_" and "." becomes "/",
    then `suffix` (e.g. "_pb2.pyi") is appended. Shared by both generators so
    the two stub flavours always land next to each other.
    """
    assert fd.name.endswith(".proto")
    return fd.name[: -len(".proto")].replace("-", "_").replace(".", "/") + suffix


def generate_mypy_stubs(
    descriptors: Descriptors,
    response: plugin_pb2.CodeGeneratorResponse,
    quiet: bool,
    readable_stubs: bool,
    relax_strict_optional_primitives: bool,
) -> None:
    """Add a `*_pb2.pyi` stub to `response` for every file protoc asked for.

    `quiet` is accepted for interface compatibility and is not used here.
    """
    for name, fd in descriptors.to_generate.items():
        # Validate the descriptor before doing any rendering work.
        assert name == fd.name
        output_name = _stub_output_name(fd, "_pb2.pyi")

        pkg_writer = PkgWriter(
            fd,
            descriptors,
            readable_stubs,
            relax_strict_optional_primitives,
            grpc=False,
        )

        pkg_writer.write_module_attributes()
        pkg_writer.write_enums(
            fd.enum_type, "", [d.FileDescriptorProto.ENUM_TYPE_FIELD_NUMBER]
        )
        pkg_writer.write_messages(
            fd.message_type, "", [d.FileDescriptorProto.MESSAGE_TYPE_FIELD_NUMBER]
        )
        pkg_writer.write_extensions(
            fd.extension, [d.FileDescriptorProto.EXTENSION_FIELD_NUMBER]
        )
        if fd.options.py_generic_services:
            pkg_writer.write_services(
                fd.service, [d.FileDescriptorProto.SERVICE_FIELD_NUMBER]
            )

        output = response.file.add()
        output.name = output_name
        output.content = HEADER + pkg_writer.write()


def generate_mypy_grpc_stubs(
    descriptors: Descriptors,
    response: plugin_pb2.CodeGeneratorResponse,
    quiet: bool,
    readable_stubs: bool,
    relax_strict_optional_primitives: bool,
) -> None:
    """Add a `*_pb2_grpc.pyi` stub to `response` for every file protoc asked for.

    `quiet` is accepted for interface compatibility and is not used here.
    """
    for name, fd in descriptors.to_generate.items():
        # Validate the descriptor before doing any rendering work.
        assert name == fd.name
        output_name = _stub_output_name(fd, "_pb2_grpc.pyi")

        pkg_writer = PkgWriter(
            fd,
            descriptors,
            readable_stubs,
            relax_strict_optional_primitives,
            grpc=True,
        )
        pkg_writer.write_grpc_services(
            fd.service, [d.FileDescriptorProto.SERVICE_FIELD_NUMBER]
        )

        output = response.file.add()
        output.name = output_name
        output.content = HEADER + pkg_writer.write()
@contextmanager
def code_generation() -> Iterator[
    Tuple[plugin_pb2.CodeGeneratorRequest, plugin_pb2.CodeGeneratorResponse],
]:
    """Run one protoc plugin exchange around the `with` body.

    With `-V`/`--version` as the first argument, print the version and
    exit. Otherwise, parse a `CodeGeneratorRequest` from stdin and yield
    it with a fresh `CodeGeneratorResponse` that advertises proto3
    `optional` support. Once the body finishes, the serialized response is
    written to stdout.
    """
    argv = sys.argv
    if len(argv) > 1 and argv[1] in ("-V", "--version"):
        print("mypy-protobuf " + __version__)
        sys.exit(0)

    request = plugin_pb2.CodeGeneratorRequest()
    request.ParseFromString(sys.stdin.buffer.read())

    response = plugin_pb2.CodeGeneratorResponse()
    response.supported_features |= (
        plugin_pb2.CodeGeneratorResponse.FEATURE_PROTO3_OPTIONAL
    )

    yield request, response

    sys.stdout.buffer.write(response.SerializeToString())
def _generator_options(parameter: str) -> Tuple[bool, bool, bool]:
    """Return (quiet, readable_stubs, relax_strict_optional_primitives) from a protoc parameter string.

    Each flag is considered set when its name occurs anywhere in `parameter`.
    """
    return (
        "quiet" in parameter,
        "readable_stubs" in parameter,
        "relax_strict_optional_primitives" in parameter,
    )


def main() -> None:
    """Entry point of the protoc plugin that emits `*_pb2.pyi` stubs."""
    with code_generation() as (request, response):
        generate_mypy_stubs(
            Descriptors(request),
            response,
            *_generator_options(request.parameter),
        )


def grpc() -> None:
    """Entry point of the protoc plugin that emits `*_pb2_grpc.pyi` stubs."""
    with code_generation() as (request, response):
        generate_mypy_grpc_stubs(
            Descriptors(request),
            response,
            *_generator_options(request.parameter),
        )


if __name__ == "__main__":
    main()
|