_multicast.py 5.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161
  1. # -*- test-case-name: twisted.test.test_udp -*-
  2. # Copyright (c) Twisted Matrix Laboratories.
  3. # See LICENSE for details.
  4. from __future__ import annotations
  5. import socket
  6. import struct
  7. from typing import Any
  8. from twisted.internet.abstract import isIPAddress, isIPv6Address
  9. from twisted.internet.defer import Deferred, succeed
  10. from twisted.internet.error import MulticastJoinError
  11. from twisted.internet.interfaces import IReactorCore
  12. def _maybeResolve(reactor: IReactorCore, addr: str) -> Deferred[str]:
  13. if isIPv6Address(addr) or isIPAddress(addr):
  14. return succeed(addr)
  15. return reactor.resolve(addr)
  16. class MulticastMixin:
  17. """
  18. Implement multicast functionality.
  19. """
  20. addressFamily: socket.AddressFamily
  21. reactor: Any
  22. socket: socket.socket
  23. def _addrpack(self, addr: str) -> bytes:
  24. """
  25. Pack an IP address literal into bytes, according to the address family
  26. of this transport.
  27. """
  28. try:
  29. return socket.inet_pton(self.addressFamily, addr)
  30. except OSError:
  31. raise MulticastJoinError(
  32. f"invalid address literal for {socket.AddressFamily(self.addressFamily).name}: {addr!r}"
  33. )
  34. @property
  35. def _ipproto(self) -> int:
  36. return (
  37. socket.IPPROTO_IP
  38. if self.addressFamily == socket.AF_INET
  39. else socket.IPPROTO_IPV6
  40. )
  41. @property
  42. def _multiloop(self) -> int:
  43. return (
  44. socket.IP_MULTICAST_LOOP
  45. if self.addressFamily == socket.AF_INET
  46. else socket.IPV6_MULTICAST_LOOP
  47. )
  48. @property
  49. def _multiif(self) -> int:
  50. return (
  51. socket.IP_MULTICAST_IF
  52. if self.addressFamily == socket.AF_INET
  53. else socket.IPV6_MULTICAST_IF
  54. )
  55. @property
  56. def _joingroup(self) -> int:
  57. return (
  58. socket.IP_ADD_MEMBERSHIP
  59. if self.addressFamily == socket.AF_INET
  60. else socket.IPV6_JOIN_GROUP
  61. )
  62. @property
  63. def _leavegroup(self) -> int:
  64. return (
  65. socket.IP_DROP_MEMBERSHIP
  66. if self.addressFamily == socket.AF_INET
  67. else socket.IPV6_LEAVE_GROUP
  68. )
  69. def getOutgoingInterface(self) -> str | int:
  70. blen = 0x4 if self.addressFamily == socket.AF_INET else 0x10
  71. ipproto = self._ipproto
  72. multiif = self._multiif
  73. i = self.socket.getsockopt(ipproto, multiif, blen)
  74. from sys import byteorder
  75. if self.addressFamily == socket.AF_INET6:
  76. return int.from_bytes(i, byteorder)
  77. return socket.inet_ntop(self.addressFamily, i)
  78. def setOutgoingInterface(self, addr: str | int) -> Deferred[int]:
  79. """
  80. @see: L{IMulticastTransport.setOutgoingInterface}
  81. """
  82. async def asynchronously() -> int:
  83. i: bytes | int
  84. if self.addressFamily == socket.AF_INET:
  85. assert isinstance(
  86. addr, str
  87. ), "IPv4 interfaces are specified as addresses"
  88. i = self._addrpack(await _maybeResolve(self.reactor, addr))
  89. else:
  90. assert isinstance(
  91. addr, int
  92. ), "IPv6 interfaces are specified as integers"
  93. i = addr
  94. self.socket.setsockopt(self._ipproto, self._multiif, i)
  95. return 1
  96. return Deferred.fromCoroutine(asynchronously())
  97. def getLoopbackMode(self) -> bool:
  98. return bool(self.socket.getsockopt(self._ipproto, self._multiloop))
  99. def setLoopbackMode(self, mode: int) -> None:
  100. # mode = struct.pack("b", bool(mode))
  101. a = self._ipproto
  102. b = self._multiloop
  103. self.socket.setsockopt(a, b, int(bool(mode)))
  104. def getTTL(self) -> int:
  105. return self.socket.getsockopt(socket.IPPROTO_IP, socket.IP_MULTICAST_TTL)
  106. def setTTL(self, ttl: int) -> None:
  107. bttl = struct.pack("B", ttl)
  108. self.socket.setsockopt(socket.IPPROTO_IP, socket.IP_MULTICAST_TTL, bttl)
  109. def _joinleave(self, addr: str, interface: str, join: bool) -> Deferred[None]:
  110. cmd = self._joingroup if join else self._leavegroup
  111. if not interface:
  112. interface = "0.0.0.0" if self.addressFamily == socket.AF_INET else "::"
  113. async def impl() -> None:
  114. resaddr = await _maybeResolve(self.reactor, addr)
  115. resif = await _maybeResolve(self.reactor, interface)
  116. packaddr = self._addrpack(resaddr)
  117. packif = self._addrpack(resif)
  118. try:
  119. self.socket.setsockopt(self._ipproto, cmd, packaddr + packif)
  120. except OSError as e:
  121. raise MulticastJoinError(addr, interface, *e.args) from e
  122. return Deferred.fromCoroutine(impl())
  123. def joinGroup(self, addr: str, interface: str = "") -> Deferred[None]:
  124. """
  125. @see: L{IMulticastTransport.joinGroup}
  126. """
  127. return self._joinleave(addr, interface, True)
  128. def leaveGroup(self, addr: str, interface: str = "") -> Deferred[None]:
  129. """
  130. @see: L{IMulticastTransport.leaveGroup}
  131. """
  132. return self._joinleave(addr, interface, False)