MarlinBinaryProtocol.py 15 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446
  1. #
  2. # MarlinBinaryProtocol.py
# Supports firmware (and other file) upload via USB/serial, saving the file to the attached storage media.
  4. #
  5. import serial, math, time, threading, sys, datetime, random
  6. from collections import deque
  7. try:
  8. import heatshrink2 as heatshrink
  9. heatshrink_exists = True
  10. except ImportError:
  11. try:
  12. import heatshrink
  13. heatshrink_exists = True
  14. except ImportError:
  15. heatshrink_exists = False
  16. def millis():
  17. return time.perf_counter() * 1000
  18. class TimeOut(object):
  19. def __init__(self, milliseconds):
  20. self.duration = milliseconds
  21. self.reset()
  22. def reset(self):
  23. self.endtime = millis() + self.duration
  24. def timedout(self):
  25. return millis() > self.endtime
  26. class ReadTimeout(Exception):
  27. pass
  28. class FatalError(Exception):
  29. pass
  30. class SycronisationError(Exception):
  31. pass
  32. class PayloadOverflow(Exception):
  33. pass
  34. class ConnectionLost(Exception):
  35. pass
  36. class Protocol(object):
  37. device = None
  38. baud = None
  39. max_block_size = 0
  40. port = None
  41. block_size = 0
  42. packet_transit = None
  43. packet_status = None
  44. packet_ping = None
  45. errors = 0
  46. packet_buffer = None
  47. simulate_errors = 0
  48. sync = 0
  49. connected = False
  50. syncronised = False
  51. worker_thread = None
  52. response_timeout = 1000
  53. applications = []
  54. responses = deque()
  55. def __init__(self, device, baud, bsize, simerr, timeout):
  56. print("pySerial Version:", serial.VERSION)
  57. self.port = serial.Serial(device, baudrate = baud, write_timeout = 0, timeout = 1)
  58. self.device = device
  59. self.baud = baud
  60. self.block_size = int(bsize)
  61. self.simulate_errors = max(min(simerr, 1.0), 0.0)
  62. self.connected = True
  63. self.response_timeout = timeout
  64. self.register(['ok', 'rs', 'ss', 'fe'], self.process_input)
  65. self.worker_thread = threading.Thread(target=Protocol.receive_worker, args=(self,))
  66. self.worker_thread.start()
  67. def receive_worker(self):
  68. while self.port.in_waiting:
  69. self.port.reset_input_buffer()
  70. def dispatch(data):
  71. for tokens, callback in self.applications:
  72. for token in tokens:
  73. if token == data[:len(token)]:
  74. callback((token, data[len(token):]))
  75. return
  76. def reconnect():
  77. print("Reconnecting..")
  78. self.port.close()
  79. for x in range(10):
  80. try:
  81. if self.connected:
  82. self.port = serial.Serial(self.device, baudrate = self.baud, write_timeout = 0, timeout = 1)
  83. return
  84. else:
  85. print("Connection closed")
  86. return
  87. except:
  88. time.sleep(1)
  89. raise ConnectionLost()
  90. while self.connected:
  91. try:
  92. data = self.port.readline().decode('utf8').rstrip()
  93. if len(data):
  94. #print(data)
  95. dispatch(data)
  96. except OSError:
  97. reconnect()
  98. except UnicodeDecodeError:
  99. # dodgy client output or datastream corruption
  100. self.port.reset_input_buffer()
  101. def shutdown(self):
  102. self.connected = False
  103. self.worker_thread.join()
  104. self.port.close()
  105. def process_input(self, data):
  106. #print(data)
  107. self.responses.append(data)
  108. def register(self, tokens, callback):
  109. self.applications.append((tokens, callback))
  110. def send(self, protocol, packet_type, data = bytearray()):
  111. self.packet_transit = self.build_packet(protocol, packet_type, data)
  112. self.packet_status = 0
  113. self.transmit_attempt = 0
  114. timeout = TimeOut(self.response_timeout * 20)
  115. while self.packet_status == 0:
  116. try:
  117. if timeout.timedout():
  118. raise ConnectionLost()
  119. self.transmit_packet(self.packet_transit)
  120. self.await_response()
  121. except ReadTimeout:
  122. self.errors += 1
  123. #print("Packetloss detected..")
  124. self.packet_transit = None
  125. def await_response(self):
  126. timeout = TimeOut(self.response_timeout)
  127. while not len(self.responses):
  128. time.sleep(0.00001)
  129. if timeout.timedout():
  130. raise ReadTimeout()
  131. while len(self.responses):
  132. token, data = self.responses.popleft()
  133. switch = {'ok' : self.response_ok, 'rs': self.response_resend, 'ss' : self.response_stream_sync, 'fe' : self.response_fatal_error}
  134. switch[token](data)
  135. def send_ascii(self, data, send_and_forget = False):
  136. self.packet_transit = bytearray(data, "utf8") + b'\n'
  137. self.packet_status = 0
  138. self.transmit_attempt = 0
  139. timeout = TimeOut(self.response_timeout * 20)
  140. while self.packet_status == 0:
  141. try:
  142. if timeout.timedout():
  143. return
  144. self.port.write(self.packet_transit)
  145. if send_and_forget:
  146. self.packet_status = 1
  147. else:
  148. self.await_response_ascii()
  149. except ReadTimeout:
  150. self.errors += 1
  151. #print("Packetloss detected..")
  152. except serial.SerialException:
  153. return
  154. self.packet_transit = None
  155. def await_response_ascii(self):
  156. timeout = TimeOut(self.response_timeout)
  157. while not len(self.responses):
  158. time.sleep(0.00001)
  159. if timeout.timedout():
  160. raise ReadTimeout()
  161. token, data = self.responses.popleft()
  162. self.packet_status = 1
  163. def corrupt_array(self, data):
  164. rid = random.randint(0, len(data) - 1)
  165. data[rid] ^= 0xAA
  166. return data
  167. def transmit_packet(self, packet):
  168. packet = bytearray(packet)
  169. if (self.simulate_errors > 0 and random.random() > (1.0 - self.simulate_errors)):
  170. if random.random() > 0.9:
  171. #random data drop
  172. start = random.randint(0, len(packet))
  173. end = start + random.randint(1, 10)
  174. packet = packet[:start] + packet[end:]
  175. #print("Dropping {0} bytes".format(end - start))
  176. else:
  177. #random corruption
  178. packet = self.corrupt_array(packet)
  179. #print("Single byte corruption")
  180. self.port.write(packet)
  181. self.transmit_attempt += 1
  182. def build_packet(self, protocol, packet_type, data = bytearray()):
  183. PACKET_TOKEN = 0xB5AD
  184. if len(data) > self.max_block_size:
  185. raise PayloadOverflow()
  186. packet_buffer = bytearray()
  187. packet_buffer += self.pack_int8(self.sync) # 8bit sync id
  188. packet_buffer += self.pack_int4_2(protocol, packet_type) # 4 bit protocol id, 4 bit packet type
  189. packet_buffer += self.pack_int16(len(data)) # 16bit packet length
  190. packet_buffer += self.pack_int16(self.build_checksum(packet_buffer)) # 16bit header checksum
  191. if len(data):
  192. packet_buffer += data
  193. packet_buffer += self.pack_int16(self.build_checksum(packet_buffer))
  194. packet_buffer = self.pack_int16(PACKET_TOKEN) + packet_buffer # 16bit start token, not included in checksum
  195. return packet_buffer
  196. # checksum 16 fletchers
  197. def checksum(self, cs, value):
  198. cs_low = (((cs & 0xFF) + value) % 255)
  199. return ((((cs >> 8) + cs_low) % 255) << 8) | cs_low
  200. def build_checksum(self, buffer):
  201. cs = 0
  202. for b in buffer:
  203. cs = self.checksum(cs, b)
  204. return cs
  205. def pack_int32(self, value):
  206. return value.to_bytes(4, byteorder='little')
  207. def pack_int16(self, value):
  208. return value.to_bytes(2, byteorder='little')
  209. def pack_int8(self, value):
  210. return value.to_bytes(1, byteorder='little')
  211. def pack_int4_2(self, vh, vl):
  212. value = ((vh & 0xF) << 4) | (vl & 0xF)
  213. return value.to_bytes(1, byteorder='little')
  214. def connect(self):
  215. print("Connecting: Switching Marlin to Binary Protocol...")
  216. self.send_ascii("M28B1")
  217. self.send(0, 1)
  218. def disconnect(self):
  219. self.send(0, 2)
  220. self.syncronised = False
  221. def response_ok(self, data):
  222. try:
  223. packet_id = int(data)
  224. except ValueError:
  225. return
  226. if packet_id != self.sync:
  227. raise SycronisationError()
  228. self.sync = (self.sync + 1) % 256
  229. self.packet_status = 1
  230. def response_resend(self, data):
  231. packet_id = int(data)
  232. self.errors += 1
  233. if not self.syncronised:
  234. print("Retrying syncronisation")
  235. elif packet_id != self.sync:
  236. raise SycronisationError()
  237. def response_stream_sync(self, data):
  238. sync, max_block_size, protocol_version = data.split(',')
  239. self.sync = int(sync)
  240. self.max_block_size = int(max_block_size)
  241. self.block_size = self.max_block_size if self.max_block_size < self.block_size else self.block_size
  242. self.protocol_version = protocol_version
  243. self.packet_status = 1
  244. self.syncronised = True
  245. print("Connection synced [{0}], binary protocol version {1}, {2} byte payload buffer".format(self.sync, self.protocol_version, self.max_block_size))
  246. def response_fatal_error(self, data):
  247. raise FatalError()
  248. class FileTransferProtocol(object):
  249. protocol_id = 1
  250. class Packet(object):
  251. QUERY = 0
  252. OPEN = 1
  253. CLOSE = 2
  254. WRITE = 3
  255. ABORT = 4
  256. responses = deque()
  257. def __init__(self, protocol, timeout = None):
  258. protocol.register(['PFT:success', 'PFT:version:', 'PFT:fail', 'PFT:busy', 'PFT:ioerror', 'PTF:invalid'], self.process_input)
  259. self.protocol = protocol
  260. self.response_timeout = timeout or protocol.response_timeout
  261. def process_input(self, data):
  262. #print(data)
  263. self.responses.append(data)
  264. def await_response(self, timeout = None):
  265. timeout = TimeOut(timeout or self.response_timeout)
  266. while not len(self.responses):
  267. time.sleep(0.0001)
  268. if timeout.timedout():
  269. raise ReadTimeout()
  270. return self.responses.popleft()
  271. def connect(self):
  272. self.protocol.send(FileTransferProtocol.protocol_id, FileTransferProtocol.Packet.QUERY)
  273. token, data = self.await_response()
  274. if token != 'PFT:version:':
  275. return False
  276. self.version, _, compression = data.split(':')
  277. if compression != 'none':
  278. algorithm, window, lookahead = compression.split(',')
  279. self.compression = {'algorithm': algorithm, 'window': int(window), 'lookahead': int(lookahead)}
  280. else:
  281. self.compression = {'algorithm': 'none'}
  282. print("File Transfer version: {0}, compression: {1}".format(self.version, self.compression['algorithm']))
  283. def open(self, filename, compression, dummy):
  284. payload = b'\1' if dummy else b'\0' # dummy transfer
  285. payload += b'\1' if compression else b'\0' # payload compression
  286. payload += bytearray(filename, 'utf8') + b'\0'# target filename + null terminator
  287. timeout = TimeOut(5000)
  288. token = None
  289. self.protocol.send(FileTransferProtocol.protocol_id, FileTransferProtocol.Packet.OPEN, payload)
  290. while token != 'PFT:success' and not timeout.timedout():
  291. try:
  292. token, data = self.await_response(1000)
  293. if token == 'PFT:success':
  294. print(filename,"opened")
  295. return
  296. elif token == 'PFT:busy':
  297. print("Broken transfer detected, purging")
  298. self.abort()
  299. time.sleep(0.1)
  300. self.protocol.send(FileTransferProtocol.protocol_id, FileTransferProtocol.Packet.OPEN, payload)
  301. timeout.reset()
  302. elif token == 'PFT:fail':
  303. raise Exception("Can not open file on client")
  304. except ReadTimeout:
  305. pass
  306. raise ReadTimeout()
  307. def write(self, data):
  308. self.protocol.send(FileTransferProtocol.protocol_id, FileTransferProtocol.Packet.WRITE, data)
  309. def close(self):
  310. self.protocol.send(FileTransferProtocol.protocol_id, FileTransferProtocol.Packet.CLOSE)
  311. token, data = self.await_response(1000)
  312. if token == 'PFT:success':
  313. print("File closed")
  314. return True
  315. elif token == 'PFT:ioerror':
  316. print("Client storage device IO error")
  317. return False
  318. elif token == 'PFT:invalid':
  319. print("No open file")
  320. return False
  321. def abort(self):
  322. self.protocol.send(FileTransferProtocol.protocol_id, FileTransferProtocol.Packet.ABORT)
  323. token, data = self.await_response()
  324. if token == 'PFT:success':
  325. print("Transfer Aborted")
  326. def copy(self, filename, dest_filename, compression, dummy):
  327. self.connect()
  328. has_heatshrink = heatshrink_exists and self.compression['algorithm'] == 'heatshrink'
  329. if compression and not has_heatshrink:
  330. hs = '2' if sys.version_info[0] > 2 else ''
  331. print("Compression not supported by client. Use 'pip install heatshrink%s' to fix." % hs)
  332. compression = False
  333. data = open(filename, "rb").read()
  334. filesize = len(data)
  335. self.open(dest_filename, compression, dummy)
  336. block_size = self.protocol.block_size
  337. if compression:
  338. data = heatshrink.encode(data, window_sz2=self.compression['window'], lookahead_sz2=self.compression['lookahead'])
  339. cratio = filesize / len(data)
  340. blocks = math.floor((len(data) + block_size - 1) / block_size)
  341. kibs = 0
  342. dump_pctg = 0
  343. start_time = millis()
  344. for i in range(blocks):
  345. start = block_size * i
  346. end = start + block_size
  347. self.write(data[start:end])
  348. kibs = (( (i+1) * block_size) / 1024) / (millis() + 1 - start_time) * 1000
  349. if (i / blocks) >= dump_pctg:
  350. print("\r{0:2.0f}% {1:4.2f}KiB/s {2} Errors: {3}".format((i / blocks) * 100, kibs, "[{0:4.2f}KiB/s]".format(kibs * cratio) if compression else "", self.protocol.errors), end='')
  351. dump_pctg += 0.1
  352. if self.protocol.errors > 0:
  353. # Dump last status (errors may not be visible)
  354. print("\r{0:2.0f}% {1:4.2f}KiB/s {2} Errors: {3} - Aborting...".format((i / blocks) * 100, kibs, "[{0:4.2f}KiB/s]".format(kibs * cratio) if compression else "", self.protocol.errors), end='')
  355. print("") # New line to break the transfer speed line
  356. self.close()
  357. print("Transfer aborted due to protocol errors")
  358. #raise Exception("Transfer aborted due to protocol errors")
  359. return False
  360. print("\r{0:2.0f}% {1:4.2f}KiB/s {2} Errors: {3}".format(100, kibs, "[{0:4.2f}KiB/s]".format(kibs * cratio) if compression else "", self.protocol.errors)) # no one likes transfers finishing at 99.8%
  361. if not self.close():
  362. print("Transfer failed")
  363. return False
  364. print("Transfer complete")
  365. return True
  366. class EchoProtocol(object):
  367. def __init__(self, protocol):
  368. protocol.register(['echo:'], self.process_input)
  369. self.protocol = protocol
  370. def process_input(self, data):
  371. print(data)