#!/usr/bin/env python
"""Protoc Plugin to generate mypy stubs."""

from __future__ import annotations

import sys
from collections import defaultdict
from contextlib import contextmanager
from typing import (
    Any,
    Callable,
    Dict,
    Iterable,
    Iterator,
    List,
    Optional,
    Set,
    Sequence,
    Tuple,
)

import google.protobuf.descriptor_pb2 as d
from google.protobuf.compiler import plugin_pb2 as plugin_pb2
from google.protobuf.internal.containers import RepeatedCompositeFieldContainer
from google.protobuf.internal.well_known_types import WKTBASES

from . import extensions_pb2

__version__ = "3.6.0"

# SourceCodeLocation is defined by `message Location` here
# https://github.com/protocolbuffers/protobuf/blob/master/src/google/protobuf/descriptor.proto
SourceCodeLocation = List[int]

# So phabricator doesn't think mypy_protobuf.py is generated
GENERATED = "@ge" + "nerated"
HEADER = f"""
{GENERATED} by mypy-protobuf. Do not edit manually!
isort:skip_file
"""

# See https://github.com/nipunn1313/mypy-protobuf/issues/73 for details
PYTHON_RESERVED = {
    "False",
    "None",
    "True",
    "and",
    "as",
    "async",
    "await",
    "assert",
    "break",
    "class",
    "continue",
    "def",
    "del",
    "elif",
    "else",
    "except",
    "finally",
    "for",
    "from",
    "global",
    "if",
    "import",
    "in",
    "is",
    "lambda",
    "nonlocal",
    "not",
    "or",
    "pass",
    "raise",
    "return",
    "try",
    "while",
    "with",
    "yield",
}

PROTO_ENUM_RESERVED = {
    "Name",
    "Value",
    "keys",
    "values",
    "items",
}


def _mangle_global_identifier(name: str) -> str:
    """
    Module level identifiers are mangled and aliased so that they can be disambiguated
    from fields/enum variants with the same name within the file.

    Eg: Enum variant `Name` or message field `Name` might conflict with a top level
    message or enum named `Name`, so mangle it with a global___ prefix for
    internal references. Note that this doesn't affect inner enums/messages
    because they get fully qualified when referenced within a file"""
    return f"global___{name}"


class Descriptors(object):
    def __init__(self, request: plugin_pb2.CodeGeneratorRequest) -> None:
        files = {f.name: f for f in request.proto_file}
        to_generate = {n: files[n] for n in request.file_to_generate}
        self.files: Dict[str, d.FileDescriptorProto] = files
        self.to_generate: Dict[str, d.FileDescriptorProto] = to_generate
        self.messages: Dict[str, d.DescriptorProto] = {}
        self.message_to_fd: Dict[str, d.FileDescriptorProto] = {}

        def _add_enums(
            enums: "RepeatedCompositeFieldContainer[d.EnumDescriptorProto]",
            prefix: str,
            _fd: d.FileDescriptorProto,
        ) -> None:
            for enum in enums:
                self.message_to_fd[prefix + enum.name] = _fd
                self.message_to_fd[prefix + enum.name + ".ValueType"] = _fd

        def _add_messages(
            messages: "RepeatedCompositeFieldContainer[d.DescriptorProto]",
            prefix: str,
            _fd: d.FileDescriptorProto,
        ) -> None:
            for message in messages:
                self.messages[prefix + message.name] = message
                self.message_to_fd[prefix + message.name] = _fd
                sub_prefix = prefix + message.name + "."
                _add_messages(message.nested_type, sub_prefix, _fd)
                _add_enums(message.enum_type, sub_prefix, _fd)

        for fd in request.proto_file:
            start_prefix = "." + fd.package + "." if fd.package else "."
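            # Descriptor type names (e.g. FieldDescriptorProto.type_name) are fully
            # qualified with a leading dot, so index every message and enum under the
            # ".package.Message" form that later lookups will use.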
            _add_messages(fd.message_type, start_prefix, fd)
            _add_enums(fd.enum_type, start_prefix, fd)


class PkgWriter(object):
    """Writes a single pyi file"""

    def __init__(
        self,
        fd: d.FileDescriptorProto,
        descriptors: Descriptors,
        readable_stubs: bool,
        relax_strict_optional_primitives: bool,
        grpc: bool,
    ) -> None:
        self.fd = fd
        self.descriptors = descriptors
        self.readable_stubs = readable_stubs
        self.relax_strict_optional_primitives = relax_strict_optional_primitives
        self.grpc = grpc
        self.lines: List[str] = []
        self.indent = ""

        # Set of {x}, where {x} corresponds to `import {x}`
        self.imports: Set[str] = set()
        # dictionary of x->(y,z) for `from {x} import {y} as {z}`
        # if {z} is None, then it shortens to `from {x} import {y}`
        self.from_imports: Dict[str, Set[Tuple[str, str | None]]] = defaultdict(set)
        self.typing_extensions_min: Optional[Tuple[int, int]] = None

        # Comments
        self.source_code_info_by_scl = {tuple(location.path): location for location in fd.source_code_info.location}

    def _import(self, path: str, name: str) -> str:
        """Imports a stdlib path and returns a handle to it
        eg. self._import("typing", "Literal") -> "Literal"
        """
        if path == "typing_extensions":
            stabilization = {
                "TypeAlias": (3, 10),
            }
            assert name in stabilization
            if not self.typing_extensions_min or self.typing_extensions_min < stabilization[name]:
                self.typing_extensions_min = stabilization[name]
            return "typing_extensions." + name

        imp = path.replace("/", ".")
        if self.readable_stubs:
            self.from_imports[imp].add((name, None))
            return name
        else:
            self.imports.add(imp)
            return imp + "." + name

    def _import_message(self, name: str) -> str:
        """Import a referenced message and return a handle"""
        message_fd = self.descriptors.message_to_fd[name]
        assert message_fd.name.endswith(".proto")

        # Strip off package name
        if message_fd.package:
            assert name.startswith("." + message_fd.package + ".")
            name = name[len("." + message_fd.package + ".") :]
        else:
            assert name.startswith(".")
            name = name[1:]

        # Use prepended "_r_" to disambiguate message names that alias python reserved keywords
        split = name.split(".")
        for i, part in enumerate(split):
            if part in PYTHON_RESERVED:
                split[i] = "_r_" + part
        name = ".".join(split)

        # Message defined in this file. Note: GRPC stubs in same .proto are generated into separate files
        if not self.grpc and message_fd.name == self.fd.name:
            return name if self.readable_stubs else _mangle_global_identifier(name)

        # Not in file. Must import
        # Python generated code ignores proto packages, so the only relevant factor is
        # whether it is in the file or not.
        import_name = self._import(message_fd.name[:-6].replace("-", "_") + "_pb2", split[0])

        remains = ".".join(split[1:])
        if not remains:
            return import_name
        # remains could either be a direct import of a nested enum or message
        # from another package.
        return import_name + "." + remains

    def _builtin(self, name: str) -> str:
        return self._import("builtins", name)

    @contextmanager
    def _indent(self) -> Iterator[None]:
        self.indent = self.indent + "    "
        yield
        self.indent = self.indent[:-4]

    def _write_line(self, line: str, *args: Any) -> None:
        if args:
            line = line.format(*args)
        if line == "":
            self.lines.append(line)
        else:
            self.lines.append(self.indent + line)

    def _break_text(self, text_block: str) -> List[str]:
        if text_block == "":
            return []
        return [line[1:] if line.startswith(" ") else line for line in text_block.rstrip().split("\n")]

    def _has_comments(self, scl: SourceCodeLocation) -> bool:
        sci_loc = self.source_code_info_by_scl.get(tuple(scl))
        return sci_loc is not None and bool(sci_loc.leading_detached_comments or sci_loc.leading_comments or sci_loc.trailing_comments)

    def _write_comments(self, scl: SourceCodeLocation) -> bool:
        """Return true if any comments were written"""
        if not self._has_comments(scl):
            return False

        sci_loc = self.source_code_info_by_scl.get(tuple(scl))
        assert sci_loc is not None

        leading_detached_lines = []
        leading_lines = []
        trailing_lines = []
        for leading_detached_comment in sci_loc.leading_detached_comments:
            leading_detached_lines = self._break_text(leading_detached_comment)
        if sci_loc.leading_comments is not None:
            leading_lines = self._break_text(sci_loc.leading_comments)
        # Trailing comments also go in the header - to make sure it gets into the docstring
        if sci_loc.trailing_comments is not None:
            trailing_lines = self._break_text(sci_loc.trailing_comments)

        lines = leading_detached_lines
        if leading_detached_lines and (leading_lines or trailing_lines):
            lines.append("")
        lines.extend(leading_lines)
        lines.extend(trailing_lines)
        lines = [
            # Escape triple-quotes that would otherwise end the docstring early.
            line.replace("\\", "\\\\").replace('"""', r"\"\"\"")
            for line in lines
        ]
        if len(lines) == 1:
            line = lines[0]
            if line.endswith(('"', "\\")):
                # Docstrings are terminated with triple-quotes, so if the documentation itself ends in a quote,
                # insert some whitespace to separate it from the closing quotes.
                # This is not necessary with multiline comments
                # because in that case we always insert a newline before the trailing triple-quotes.
                line = line + " "
            self._write_line(f'"""{line}"""')
        else:
            for i, line in enumerate(lines):
                if i == 0:
                    self._write_line(f'"""{line}')
                else:
                    self._write_line(f"{line}")
            self._write_line('"""')
        return True

    def write_enum_values(
        self,
        values: Iterable[Tuple[int, d.EnumValueDescriptorProto]],
        value_type: str,
        scl_prefix: SourceCodeLocation,
    ) -> None:
        for i, val in values:
            if val.name in PYTHON_RESERVED:
                continue

            scl = scl_prefix + [i]
            self._write_line(
                f"{val.name}: {value_type}  # {val.number}",
            )
            self._write_comments(scl)

    def write_module_attributes(self) -> None:
        wl = self._write_line
        fd_type = self._import("google.protobuf.descriptor", "FileDescriptor")
        wl(f"DESCRIPTOR: {fd_type}")
        wl("")

    def write_enums(
        self,
        enums: Iterable[d.EnumDescriptorProto],
        prefix: str,
        scl_prefix: SourceCodeLocation,
    ) -> None:
        wl = self._write_line
        for i, enum in enumerate(enums):
            class_name = enum.name if enum.name not in PYTHON_RESERVED else "_r_" + enum.name
            value_type_fq = prefix + class_name + ".ValueType"
            enum_helper_class = "_" + enum.name
            value_type_helper_fq = prefix + enum_helper_class + ".ValueType"
            etw_helper_class = "_" + enum.name + "EnumTypeWrapper"
            scl = scl_prefix + [i]

            wl(f"class {enum_helper_class}:")
            with self._indent():
                wl(
                    'ValueType = {}("ValueType", {})',
                    self._import("typing", "NewType"),
                    self._builtin("int"),
                )
                # Alias to the classic shorter definition "V"
                wl("V: {} = ValueType", self._import("typing_extensions", "TypeAlias"))
            wl("")
            wl(
                "class {}({}[{}], {}):",
                etw_helper_class,
                self._import("google.protobuf.internal.enum_type_wrapper", "_EnumTypeWrapper"),
                value_type_helper_fq,
                self._builtin("type"),
            )
            with self._indent():
                ed = self._import("google.protobuf.descriptor", "EnumDescriptor")
                wl(f"DESCRIPTOR: {ed}")
                self.write_enum_values(
                    [(i, v) for i, v in enumerate(enum.value) if v.name not in PROTO_ENUM_RESERVED],
                    value_type_helper_fq,
                    scl + [d.EnumDescriptorProto.VALUE_FIELD_NUMBER],
                )
            wl("")

            if self._has_comments(scl):
                wl(f"class {class_name}({enum_helper_class}, metaclass={etw_helper_class}):")
                with self._indent():
                    self._write_comments(scl)
                    wl("")
            else:
                wl(f"class {class_name}({enum_helper_class}, metaclass={etw_helper_class}): ...")
            if prefix == "":
                wl("")

            self.write_enum_values(
                enumerate(enum.value),
                value_type_fq,
                scl + [d.EnumDescriptorProto.VALUE_FIELD_NUMBER],
            )
            if prefix == "" and not self.readable_stubs:
                wl(f"{_mangle_global_identifier(class_name)} = {class_name}")
            wl("")

    def write_messages(
        self,
        messages: Iterable[d.DescriptorProto],
        prefix: str,
        scl_prefix: SourceCodeLocation,
    ) -> None:
        wl = self._write_line

        for i, desc in enumerate(messages):
            qualified_name = prefix + desc.name

            # Reproduce some hardcoded logic from the protobuf implementation - where
            # some specific "well_known_types" generated protos get additional
            # base classes
            addl_base = ""
            if self.fd.package + "." + desc.name in WKTBASES:
                # chop off the .proto - and import the well known type
                # eg `from google.protobuf.duration import Duration`
                well_known_type = WKTBASES[self.fd.package + "." + desc.name]
                addl_base = ", " + self._import(
                    "google.protobuf.internal.well_known_types",
                    well_known_type.__name__,
                )

            class_name = desc.name if desc.name not in PYTHON_RESERVED else "_r_" + desc.name
            message_class = self._import("google.protobuf.message", "Message")
            wl("@{}", self._import("typing", "final"))
            wl(f"class {class_name}({message_class}{addl_base}):")
            with self._indent():
                scl = scl_prefix + [i]
                if self._write_comments(scl):
                    wl("")

                desc_type = self._import("google.protobuf.descriptor", "Descriptor")
                wl(f"DESCRIPTOR: {desc_type}")
                wl("")

                # Nested enums/messages
                self.write_enums(
                    desc.enum_type,
                    qualified_name + ".",
                    scl + [d.DescriptorProto.ENUM_TYPE_FIELD_NUMBER],
                )
                self.write_messages(
                    desc.nested_type,
                    qualified_name + ".",
                    scl + [d.DescriptorProto.NESTED_TYPE_FIELD_NUMBER],
                )

                # integer constants for field numbers
                for f in desc.field:
                    wl(f"{f.name.upper()}_FIELD_NUMBER: {self._builtin('int')}")

                for idx, field in enumerate(desc.field):
                    if field.name in PYTHON_RESERVED:
                        continue
                    field_type = self.python_type(field)

                    if is_scalar(field) and field.label != d.FieldDescriptorProto.LABEL_REPEATED:
                        # Scalar non repeated fields are r/w
                        wl(f"{field.name}: {field_type}")
                        self._write_comments(scl + [d.DescriptorProto.FIELD_FIELD_NUMBER, idx])

                for idx, field in enumerate(desc.field):
                    if field.name in PYTHON_RESERVED:
                        continue
                    field_type = self.python_type(field)

                    if not (is_scalar(field) and field.label != d.FieldDescriptorProto.LABEL_REPEATED):
                        # r/o Getters for non-scalar fields and scalar-repeated fields
                        scl_field = scl + [d.DescriptorProto.FIELD_FIELD_NUMBER, idx]
                        wl("@property")
                        body = " ..." if not self._has_comments(scl_field) else ""
                        wl(f"def {field.name}(self) -> {field_type}:{body}")
                        if self._has_comments(scl_field):
                            with self._indent():
                                self._write_comments(scl_field)
                wl("")

                self.write_extensions(desc.extension, scl + [d.DescriptorProto.EXTENSION_FIELD_NUMBER])

                # Constructor
                wl("def __init__(")
                with self._indent():
                    if any(f.name == "self" for f in desc.field):
                        wl("self_,  # pyright: ignore[reportSelfClsParameterName]")
                    else:
                        wl("self,")
                with self._indent():
                    constructor_fields = [f for f in desc.field if f.name not in PYTHON_RESERVED]
                    if len(constructor_fields) > 0:
                        # Keyword-only args - positional construction is not allowed
                        # See https://github.com/nipunn1313/mypy-protobuf/issues/71
                        wl("*,")
                    for field in constructor_fields:
                        field_type = self.python_type(field, generic_container=True)
                        if self.fd.syntax == "proto3" and is_scalar(field) and field.label != d.FieldDescriptorProto.LABEL_REPEATED and not self.relax_strict_optional_primitives and not field.proto3_optional:
                            wl(f"{field.name}: {field_type} = ...,")
                        else:
                            wl(f"{field.name}: {field_type} | None = ...,")
                wl(") -> None: ...")

                self.write_stringly_typed_fields(desc)

            if prefix == "" and not self.readable_stubs:
                wl("")
                wl(f"{_mangle_global_identifier(class_name)} = {class_name}")
            wl("")

    def write_stringly_typed_fields(self, desc: d.DescriptorProto) -> None:
        """Type the stringly-typed methods as a Union[Literal, Literal ...]"""
        wl = self._write_line
        # HasField, ClearField, WhichOneof accept both bytes/str
        # HasField only supports singular. ClearField supports repeated as well
        # In proto3, HasField only supports message fields and optional fields
        # HasField always supports oneof fields
        hf_fields = [f.name for f in desc.field if f.HasField("oneof_index") or (f.label != d.FieldDescriptorProto.LABEL_REPEATED and (self.fd.syntax != "proto3" or f.type == d.FieldDescriptorProto.TYPE_MESSAGE or f.proto3_optional))]
        cf_fields = [f.name for f in desc.field]
        wo_fields = {oneof.name: [f.name for f in desc.field if f.HasField("oneof_index") and f.oneof_index == idx] for idx, oneof in enumerate(desc.oneof_decl)}

        hf_fields.extend(wo_fields.keys())
        cf_fields.extend(wo_fields.keys())

        hf_fields_text = ", ".join(sorted(f'"{name}", b"{name}"' for name in hf_fields))
        cf_fields_text = ", ".join(sorted(f'"{name}", b"{name}"' for name in cf_fields))

        if not hf_fields and not cf_fields and not wo_fields:
            return

        if hf_fields:
            wl(
                "def HasField(self, field_name: {}[{}]) -> {}: ...",
                self._import("typing", "Literal"),
                hf_fields_text,
                self._builtin("bool"),
            )
        if cf_fields:
            wl(
                "def ClearField(self, field_name: {}[{}]) -> None: ...",
                self._import("typing", "Literal"),
                cf_fields_text,
            )

        for wo_field, members in sorted(wo_fields.items()):
            if len(wo_fields) > 1:
                wl("@{}", self._import("typing", "overload"))
            wl(
                "def WhichOneof(self, oneof_group: {}[{}]) -> {}[{}] | None: ...",
                self._import("typing", "Literal"),
                # Accepts both str and bytes
                f'"{wo_field}", b"{wo_field}"',
                self._import("typing", "Literal"),
                # Returns `str`
                ", ".join(f'"{m}"' for m in members),
            )

    def write_extensions(
        self,
        extensions: Sequence[d.FieldDescriptorProto],
        scl_prefix: SourceCodeLocation,
    ) -> None:
        wl = self._write_line

        for ext in extensions:
            wl(f"{ext.name.upper()}_FIELD_NUMBER: {self._builtin('int')}")

        for i, ext in enumerate(extensions):
            scl = scl_prefix + [i]

            wl(
                "{}: {}[{}, {}]",
                ext.name,
                self._import(
                    "google.protobuf.internal.extension_dict",
                    "_ExtensionFieldDescriptor",
                ),
                self._import_message(ext.extendee),
                self.python_type(ext),
            )
            self._write_comments(scl)

    def write_methods(
        self,
        service: d.ServiceDescriptorProto,
        class_name: str,
        is_abstract: bool,
        scl_prefix: SourceCodeLocation,
    ) -> None:
        wl = self._write_line
        wl(
            "DESCRIPTOR: {}",
            self._import("google.protobuf.descriptor", "ServiceDescriptor"),
        )
        methods = [(i, m) for i, m in enumerate(service.method) if m.name not in PYTHON_RESERVED]
        if not methods:
            wl("...")
        for i, method in methods:
            if is_abstract:
                wl("@{}", self._import("abc", "abstractmethod"))
            wl(f"def {method.name}(")
            with self._indent():
                wl(f"inst: {class_name},  # pyright: ignore[reportSelfClsParameterName]")
                wl(
                    "rpc_controller: {},",
                    self._import("google.protobuf.service", "RpcController"),
                )
                wl("request: {},", self._import_message(method.input_type))
                wl(
                    "callback: {}[[{}], None] | None{},",
                    self._import("collections.abc", "Callable"),
                    self._import_message(method.output_type),
                    "" if is_abstract else " = ...",
                )

            scl_method = scl_prefix + [d.ServiceDescriptorProto.METHOD_FIELD_NUMBER, i]
            wl(
                ") -> {}[{}]:{}",
                self._import("concurrent.futures", "Future"),
                self._import_message(method.output_type),
                " ..." if not self._has_comments(scl_method) else "",
            )
            if self._has_comments(scl_method):
                with self._indent():
                    if not self._write_comments(scl_method):
                        wl("...")
            wl("")

    def write_services(
        self,
        services: Iterable[d.ServiceDescriptorProto],
        scl_prefix: SourceCodeLocation,
    ) -> None:
        wl = self._write_line
        for i, service in enumerate(services):
            scl = scl_prefix + [i]
            class_name = service.name if service.name not in PYTHON_RESERVED else "_r_" + service.name

            # The service definition interface
            wl(
                "class {}({}, metaclass={}):",
                class_name,
                self._import("google.protobuf.service", "Service"),
                self._import("abc", "ABCMeta"),
            )
            with self._indent():
                if self._write_comments(scl):
                    wl("")
                self.write_methods(service, class_name, is_abstract=True, scl_prefix=scl)

            # The stub client
            stub_class_name = service.name + "_Stub"
            wl("class {}({}):", stub_class_name, class_name)
            with self._indent():
                if self._write_comments(scl):
                    wl("")
                wl(
                    "def __init__(self, rpc_channel: {}) -> None: ...",
                    self._import("google.protobuf.service", "RpcChannel"),
                )
                self.write_methods(service, stub_class_name, is_abstract=False, scl_prefix=scl)

    def _import_casttype(self, casttype: str) -> str:
        split = casttype.split(".")
        assert len(split) == 2, "mypy_protobuf.[casttype,keytype,valuetype] is expected to be of format path/to/file.TypeInFile"
        pkg = split[0].replace("/", ".")
        return self._import(pkg, split[1])

    def _map_key_value_types(
        self,
        map_field: d.FieldDescriptorProto,
        key_field: d.FieldDescriptorProto,
        value_field: d.FieldDescriptorProto,
    ) -> Tuple[str, str]:
        oldstyle_keytype = map_field.options.Extensions[extensions_pb2.keytype]
        if oldstyle_keytype:
            print(f"Warning: Map Field {map_field.name}: (mypy_protobuf.keytype) is deprecated. Prefer (mypy_protobuf.options).keytype", file=sys.stderr)
        key_casttype = map_field.options.Extensions[extensions_pb2.options].keytype or oldstyle_keytype
        ktype = self._import_casttype(key_casttype) if key_casttype else self.python_type(key_field)

        oldstyle_valuetype = map_field.options.Extensions[extensions_pb2.valuetype]
        if oldstyle_valuetype:
            print(f"Warning: Map Field {map_field.name}: (mypy_protobuf.valuetype) is deprecated. Prefer (mypy_protobuf.options).valuetype", file=sys.stderr)
        value_casttype = map_field.options.Extensions[extensions_pb2.options].valuetype or map_field.options.Extensions[extensions_pb2.valuetype]
        vtype = self._import_casttype(value_casttype) if value_casttype else self.python_type(value_field)
        return ktype, vtype

    def _callable_type(self, method: d.MethodDescriptorProto, is_async: bool = False) -> str:
        module = "grpc.aio" if is_async else "grpc"
        if method.client_streaming:
            if method.server_streaming:
                return self._import(module, "StreamStreamMultiCallable")
            else:
                return self._import(module, "StreamUnaryMultiCallable")
        else:
            if method.server_streaming:
                return self._import(module, "UnaryStreamMultiCallable")
            else:
                return self._import(module, "UnaryUnaryMultiCallable")

    def _input_type(self, method: d.MethodDescriptorProto) -> str:
        result = self._import_message(method.input_type)
        return result

    def _servicer_input_type(self, method: d.MethodDescriptorProto) -> str:
        result = self._import_message(method.input_type)
        if method.client_streaming:
            # See write_grpc_async_hacks().
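            # _MaybeAsyncIterator derives from both Iterator and AsyncIterator, so
            # sync servicers (Iterator[Req]) and async servicers (AsyncIterator[Req])
            # both typecheck against this request parameter.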
result = f"_MaybeAsyncIterator[{result}]" return result def _output_type(self, method: d.MethodDescriptorProto) -> str: result = self._import_message(method.output_type) return result def _servicer_output_type(self, method: d.MethodDescriptorProto) -> str: result = self._import_message(method.output_type) if method.server_streaming: # Union[Iterator[Resp], AsyncIterator[Resp]] is subtyped by Iterator[Resp] and AsyncIterator[Resp]. # So both can be used in the covariant function return position. iterator = f"{self._import('collections.abc', 'Iterator')}[{result}]" aiterator = f"{self._import('collections.abc', 'AsyncIterator')}[{result}]" result = f"{self._import('typing', 'Union')}[{iterator}, {aiterator}]" else: # Union[Resp, Awaitable[Resp]] is subtyped by Resp and Awaitable[Resp]. # So both can be used in the covariant function return position. # Awaitable[Resp] is equivalent to async def. awaitable = f"{self._import('collections.abc', 'Awaitable')}[{result}]" result = f"{self._import('typing', 'Union')}[{result}, {awaitable}]" return result def write_grpc_async_hacks(self) -> None: wl = self._write_line # _MaybeAsyncIterator[Req] is supertyped by Iterator[Req] and AsyncIterator[Req]. # So both can be used in the contravariant function parameter position. wl('_T = {}("_T")', self._import("typing", "TypeVar")) wl("") wl( "class _MaybeAsyncIterator({}[_T], {}[_T], metaclass={}): ...", self._import("collections.abc", "AsyncIterator"), self._import("collections.abc", "Iterator"), self._import("abc", "ABCMeta"), ) wl("") # _ServicerContext is supertyped by grpc.ServicerContext and grpc.aio.ServicerContext # So both can be used in the contravariant function parameter position. wl( "class _ServicerContext({}, {}): # type: ignore[misc, type-arg]", self._import("grpc", "ServicerContext"), self._import("grpc.aio", "ServicerContext"), ) with self._indent(): wl("...") wl("") def write_grpc_methods(self, service: d.ServiceDescriptorProto, scl_prefix: SourceCodeLocation) -> None: wl = self._write_line methods = [(i, m) for i, m in enumerate(service.method) if m.name not in PYTHON_RESERVED] if not methods: wl("...") wl("") for i, method in methods: scl = scl_prefix + [d.ServiceDescriptorProto.METHOD_FIELD_NUMBER, i] wl("@{}", self._import("abc", "abstractmethod")) wl("def {}(", method.name) with self._indent(): wl("self,") input_name = "request_iterator" if method.client_streaming else "request" input_type = self._servicer_input_type(method) wl(f"{input_name}: {input_type},") wl("context: _ServicerContext,") wl( ") -> {}:{}", self._servicer_output_type(method), " ..." 
if not self._has_comments(scl) else "", ) if self._has_comments(scl): with self._indent(): if not self._write_comments(scl): wl("...") wl("") def write_grpc_stub_methods(self, service: d.ServiceDescriptorProto, scl_prefix: SourceCodeLocation, is_async: bool = False) -> None: wl = self._write_line methods = [(i, m) for i, m in enumerate(service.method) if m.name not in PYTHON_RESERVED] if not methods: wl("...") wl("") for i, method in methods: scl = scl_prefix + [d.ServiceDescriptorProto.METHOD_FIELD_NUMBER, i] wl("{}: {}[", method.name, self._callable_type(method, is_async=is_async)) with self._indent(): wl("{},", self._input_type(method)) wl("{},", self._output_type(method)) wl("]") self._write_comments(scl) wl("") def write_grpc_services( self, services: Iterable[d.ServiceDescriptorProto], scl_prefix: SourceCodeLocation, ) -> None: wl = self._write_line for i, service in enumerate(services): if service.name in PYTHON_RESERVED: continue scl = scl_prefix + [i] # The stub client wl( "class {}Stub:", service.name, ) with self._indent(): if self._write_comments(scl): wl("") # To support casting into FooAsyncStub, allow both Channel and aio.Channel here. channel = f"{self._import('typing', 'Union')}[{self._import('grpc', 'Channel')}, {self._import('grpc.aio', 'Channel')}]" wl("def __init__(self, channel: {}) -> None: ...", channel) self.write_grpc_stub_methods(service, scl) # The (fake) async stub client wl( "class {}AsyncStub:", service.name, ) with self._indent(): if self._write_comments(scl): wl("") # No __init__ since this isn't a real class (yet), and requires manual casting to work. self.write_grpc_stub_methods(service, scl, is_async=True) # The service definition interface wl( "class {}Servicer(metaclass={}):", service.name, self._import("abc", "ABCMeta"), ) with self._indent(): if self._write_comments(scl): wl("") self.write_grpc_methods(service, scl) server = self._import("grpc", "Server") aserver = self._import("grpc.aio", "Server") wl( "def add_{}Servicer_to_server(servicer: {}Servicer, server: {}) -> None: ...", service.name, service.name, f"{self._import('typing', 'Union')}[{server}, {aserver}]", ) wl("") def python_type(self, field: d.FieldDescriptorProto, generic_container: bool = False) -> str: """ generic_container if set, type the field with generic interfaces. Eg. - Iterable[int] rather than RepeatedScalarFieldContainer[int] - Mapping[k, v] rather than MessageMap[k, v] Can be useful for input types (eg constructor) """ oldstyle_casttype = field.options.Extensions[extensions_pb2.casttype] if oldstyle_casttype: print(f"Warning: Field {field.name}: (mypy_protobuf.casttype) is deprecated. 
Prefer (mypy_protobuf.options).casttype", file=sys.stderr) casttype = field.options.Extensions[extensions_pb2.options].casttype or oldstyle_casttype if casttype: return self._import_casttype(casttype) mapping: Dict[d.FieldDescriptorProto.Type.V, Callable[[], str]] = { d.FieldDescriptorProto.TYPE_DOUBLE: lambda: self._builtin("float"), d.FieldDescriptorProto.TYPE_FLOAT: lambda: self._builtin("float"), d.FieldDescriptorProto.TYPE_INT64: lambda: self._builtin("int"), d.FieldDescriptorProto.TYPE_UINT64: lambda: self._builtin("int"), d.FieldDescriptorProto.TYPE_FIXED64: lambda: self._builtin("int"), d.FieldDescriptorProto.TYPE_SFIXED64: lambda: self._builtin("int"), d.FieldDescriptorProto.TYPE_SINT64: lambda: self._builtin("int"), d.FieldDescriptorProto.TYPE_INT32: lambda: self._builtin("int"), d.FieldDescriptorProto.TYPE_UINT32: lambda: self._builtin("int"), d.FieldDescriptorProto.TYPE_FIXED32: lambda: self._builtin("int"), d.FieldDescriptorProto.TYPE_SFIXED32: lambda: self._builtin("int"), d.FieldDescriptorProto.TYPE_SINT32: lambda: self._builtin("int"), d.FieldDescriptorProto.TYPE_BOOL: lambda: self._builtin("bool"), d.FieldDescriptorProto.TYPE_STRING: lambda: self._builtin("str"), d.FieldDescriptorProto.TYPE_BYTES: lambda: self._builtin("bytes"), d.FieldDescriptorProto.TYPE_ENUM: lambda: self._import_message(field.type_name + ".ValueType"), d.FieldDescriptorProto.TYPE_MESSAGE: lambda: self._import_message(field.type_name), d.FieldDescriptorProto.TYPE_GROUP: lambda: self._import_message(field.type_name), } assert field.type in mapping, "Unrecognized type: " + repr(field.type) field_type = mapping[field.type]() # For non-repeated fields, we're done! if field.label != d.FieldDescriptorProto.LABEL_REPEATED: return field_type # Scalar repeated fields go in RepeatedScalarFieldContainer if is_scalar(field): container = ( self._import("collections.abc", "Iterable") if generic_container else self._import( "google.protobuf.internal.containers", "RepeatedScalarFieldContainer", ) ) return f"{container}[{field_type}]" # non-scalar repeated map fields go in ScalarMap/MessageMap msg = self.descriptors.messages[field.type_name] if msg.options.map_entry: # map generates a special Entry wrapper message if generic_container: container = self._import("collections.abc", "Mapping") elif is_scalar(msg.field[1]): container = self._import("google.protobuf.internal.containers", "ScalarMap") else: container = self._import("google.protobuf.internal.containers", "MessageMap") ktype, vtype = self._map_key_value_types(field, msg.field[0], msg.field[1]) return f"{container}[{ktype}, {vtype}]" # non-scalar repetated fields go in RepeatedCompositeFieldContainer container = ( self._import("collections.abc", "Iterable") if generic_container else self._import( "google.protobuf.internal.containers", "RepeatedCompositeFieldContainer", ) ) return f"{container}[{field_type}]" def write(self) -> str: # save current module content, so that imports and module docstring can be inserted saved_lines = self.lines self.lines = [] # module docstring may exist as comment before syntax (optional) or package name if not self._write_comments([d.FileDescriptorProto.PACKAGE_FIELD_NUMBER]): self._write_comments([d.FileDescriptorProto.SYNTAX_FIELD_NUMBER]) if self.lines: assert self.lines[0].startswith('"""') self.lines[0] = f'"""{HEADER}{self.lines[0][3:]}' self._write_line("") else: self._write_line(f'"""{HEADER}"""\n') for reexport_idx in self.fd.public_dependency: reexport_file = self.fd.dependency[reexport_idx] reexport_fd = 
self.descriptors.files[reexport_file] reexport_imp = reexport_file[:-6].replace("-", "_").replace("/", ".") + "_pb2" names = [m.name for m in reexport_fd.message_type] + [m.name for m in reexport_fd.enum_type] + [v.name for m in reexport_fd.enum_type for v in m.value] + [m.name for m in reexport_fd.extension] if reexport_fd.options.py_generic_services: names.extend(m.name for m in reexport_fd.service) if names: # n,n to force a reexport (from x import y as y) self.from_imports[reexport_imp].update((n, n) for n in names) if self.typing_extensions_min: self.imports.add("sys") for pkg in sorted(self.imports): self._write_line(f"import {pkg}") if self.typing_extensions_min: self._write_line("") self._write_line(f"if sys.version_info >= {self.typing_extensions_min}:") self._write_line(" import typing as typing_extensions") self._write_line("else:") self._write_line(" import typing_extensions") for pkg, items in sorted(self.from_imports.items()): self._write_line(f"from {pkg} import (") for name, reexport_name in sorted(items): if reexport_name is None: self._write_line(f" {name},") else: self._write_line(f" {name} as {reexport_name},") self._write_line(")") self._write_line("") # restore module content self.lines += saved_lines content = "\n".join(self.lines) if not content.endswith("\n"): content = content + "\n" return content def is_scalar(fd: d.FieldDescriptorProto) -> bool: return not (fd.type == d.FieldDescriptorProto.TYPE_MESSAGE or fd.type == d.FieldDescriptorProto.TYPE_GROUP) def generate_mypy_stubs( descriptors: Descriptors, response: plugin_pb2.CodeGeneratorResponse, quiet: bool, readable_stubs: bool, relax_strict_optional_primitives: bool, ) -> None: for name, fd in descriptors.to_generate.items(): pkg_writer = PkgWriter( fd, descriptors, readable_stubs, relax_strict_optional_primitives, grpc=False, ) pkg_writer.write_module_attributes() pkg_writer.write_enums(fd.enum_type, "", [d.FileDescriptorProto.ENUM_TYPE_FIELD_NUMBER]) pkg_writer.write_messages(fd.message_type, "", [d.FileDescriptorProto.MESSAGE_TYPE_FIELD_NUMBER]) pkg_writer.write_extensions(fd.extension, [d.FileDescriptorProto.EXTENSION_FIELD_NUMBER]) if fd.options.py_generic_services: pkg_writer.write_services(fd.service, [d.FileDescriptorProto.SERVICE_FIELD_NUMBER]) assert name == fd.name assert fd.name.endswith(".proto") output = response.file.add() output.name = fd.name[:-6].replace("-", "_").replace(".", "/") + "_pb2.pyi" output.content = pkg_writer.write() def generate_mypy_grpc_stubs( descriptors: Descriptors, response: plugin_pb2.CodeGeneratorResponse, quiet: bool, readable_stubs: bool, relax_strict_optional_primitives: bool, ) -> None: for name, fd in descriptors.to_generate.items(): pkg_writer = PkgWriter( fd, descriptors, readable_stubs, relax_strict_optional_primitives, grpc=True, ) pkg_writer.write_grpc_async_hacks() pkg_writer.write_grpc_services(fd.service, [d.FileDescriptorProto.SERVICE_FIELD_NUMBER]) assert name == fd.name assert fd.name.endswith(".proto") output = response.file.add() output.name = fd.name[:-6].replace("-", "_").replace(".", "/") + "_pb2_grpc.pyi" output.content = pkg_writer.write() @contextmanager def code_generation() -> Iterator[Tuple[plugin_pb2.CodeGeneratorRequest, plugin_pb2.CodeGeneratorResponse],]: if len(sys.argv) > 1 and sys.argv[1] in ("-V", "--version"): print("mypy-protobuf " + __version__) sys.exit(0) # Read request message from stdin data = sys.stdin.buffer.read() # Parse request request = plugin_pb2.CodeGeneratorRequest() request.ParseFromString(data) # Create response 
    response = plugin_pb2.CodeGeneratorResponse()

    # Declare support for optional proto3 fields
    response.supported_features |= plugin_pb2.CodeGeneratorResponse.FEATURE_PROTO3_OPTIONAL

    yield request, response

    # Serialise response message
    output = response.SerializeToString()

    # Write to stdout
    sys.stdout.buffer.write(output)


def main() -> None:
    # Generate mypy
    with code_generation() as (request, response):
        generate_mypy_stubs(
            Descriptors(request),
            response,
            "quiet" in request.parameter,
            "readable_stubs" in request.parameter,
            "relax_strict_optional_primitives" in request.parameter,
        )


def grpc() -> None:
    # Generate grpc mypy
    with code_generation() as (request, response):
        generate_mypy_grpc_stubs(
            Descriptors(request),
            response,
            "quiet" in request.parameter,
            "readable_stubs" in request.parameter,
            "relax_strict_optional_primitives" in request.parameter,
        )


if __name__ == "__main__":
    main()
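

# Typical protoc invocation (illustrative; assumes mypy-protobuf is installed so the
# protoc-gen-mypy / protoc-gen-mypy_grpc entry points wrapping main() and grpc() are
# on PATH, and that out/ exists):
#
#   protoc --python_out=out/ --mypy_out=out/ path/to/example.proto
#   protoc --python_out=out/ --mypy_grpc_out=out/ path/to/example.proto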