MarlinBinaryProtocol.py 15 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451
  1. #
  2. # MarlinBinaryProtocol.py
  3. # Supporting Firmware upload via USB/Serial, saving to the attached media.
  4. #
  5. import serial
  6. import math
  7. import time
  8. from collections import deque
  9. import threading
  10. import sys
  11. import datetime
  12. import random
  13. try:
  14. import heatshrink2 as heatshrink
  15. heatshrink_exists = True
  16. except ImportError:
  17. try:
  18. import heatshrink
  19. heatshrink_exists = True
  20. except ImportError:
  21. heatshrink_exists = False
  def millis():
      """Return the current time.perf_counter() reading in milliseconds."""
      seconds = time.perf_counter()
      return seconds * 1000
  class TimeOut(object):
      """Deadline that expires `milliseconds` after creation or the last reset()."""

      def __init__(self, milliseconds):
          self.duration = milliseconds
          self.reset()

      def reset(self):
          """Re-arm the deadline to now + duration."""
          now = millis()
          self.endtime = now + self.duration

      def timedout(self):
          """Return True once the current time has passed the deadline."""
          return self.endtime < millis()
  class ReadTimeout(Exception):
      """Raised when no response arrives before the response timeout expires."""
  class FatalError(Exception):
      """Raised when the device answers with a fatal-error ('fe') response."""
  class SycronisationError(Exception):
      """Raised when a packet id in an 'ok'/'rs' response does not match the expected sync id."""
  class PayloadOverflow(Exception):
      """Raised when a packet payload is larger than the device's maximum block size."""
  class ConnectionLost(Exception):
      """Raised when the serial link cannot be reopened or a packet is never acknowledged."""
  class Protocol(object):
      """Host side of Marlin's binary serial stream protocol.

      Opens a serial port, runs a background thread that reads lines from the
      device and hands them to registered handlers, and sends framed binary
      packets, retransmitting each one until the device acknowledges it.
      """

      device = None
      baud = None
      max_block_size = 0  # payload limit reported by the device in its 'ss' response
      port = None
      block_size = 0  # requested payload size, capped to max_block_size on sync
      packet_transit = None  # packet currently being sent, None when idle
      packet_status = None  # 0 while awaiting acknowledgement, 1 once acknowledged
      packet_ping = None
      errors = 0  # timeouts and resend requests seen so far
      packet_buffer = None
      simulate_errors = 0  # probability of deliberately damaging an outgoing packet
      sync = 0  # 8-bit id put in the next packet; must match the id in its 'ok'
      connected = False
      syncronised = False
      worker_thread = None
      response_timeout = 1000  # milliseconds

      def __init__(self, device, baud, bsize, simerr, timeout):
          """Open *device* and start the background receive thread.

          Args:
              device: serial port name passed to serial.Serial.
              baud: baud rate.
              bsize: requested payload block size.
              simerr: probability (clamped to 0..1) of simulating a
                  transmission error on each outgoing binary packet.
              timeout: response timeout in milliseconds.
          """
          # The handler table and response queue belong to each instance.
          # As class attributes they were shared by every Protocol object,
          # so handlers registered for one connection leaked into the next.
          self.applications = []
          self.responses = deque()
          print("pySerial Version:", serial.VERSION)
          self.port = serial.Serial(device, baudrate = baud, write_timeout = 0, timeout = 1)
          self.device = device
          self.baud = baud
          self.block_size = int(bsize)
          self.simulate_errors = max(min(simerr, 1.0), 0.0)
          self.connected = True
          self.response_timeout = timeout
          self.register(['ok', 'rs', 'ss', 'fe'], self.process_input)
          # Use a daemon thread. The read loop only stops when shutdown()
          # clears self.connected, so a caller that dies without calling
          # shutdown() must not leave the interpreter hanging.
          self.worker_thread = threading.Thread(target=Protocol.receive_worker, args=(self,), daemon=True)
          self.worker_thread.start()

      def receive_worker(self):
          """Read lines from the port and dispatch them; runs on the worker thread.

          Loops until shutdown() clears self.connected.  Each non-empty line
          goes to the first registered handler whose token prefixes it.
          """
          # Discard whatever the device sent before we started listening.
          while self.port.in_waiting:
              self.port.reset_input_buffer()

          def dispatch(data):
              for tokens, callback in self.applications:
                  for token in tokens:
                      if data.startswith(token):
                          callback((token, data[len(token):]))
                          return

          def reconnect():
              # Reopen the port, retrying up to 10 times one second apart.
              # ConnectionLost escapes the thread target and ends the worker;
              # a pending send() then gives up on its own timeout.
              print("Reconnecting..")
              self.port.close()
              for _ in range(10):
                  if not self.connected:
                      print("Connection closed")
                      return
                  try:
                      self.port = serial.Serial(self.device, baudrate = self.baud, write_timeout = 0, timeout = 1)
                      return
                  except (serial.SerialException, OSError):
                      time.sleep(1)
              raise ConnectionLost()

          while self.connected:
              try:
                  data = self.port.readline().decode('utf8').rstrip()
                  if data:
                      dispatch(data)
              except OSError:
                  reconnect()
              except UnicodeDecodeError:
                  # dodgy client output or datastream corruption
                  self.port.reset_input_buffer()

      def shutdown(self):
          """Stop the receive thread, wait for it, and close the port."""
          self.connected = False
          self.worker_thread.join()
          self.port.close()

      def process_input(self, data):
          """Queue a (token, payload) response for await_response()."""
          self.responses.append(data)

      def register(self, tokens, callback):
          """Route lines starting with any of *tokens* to *callback*.

          The callback receives a (token, remainder-of-line) tuple and is
          invoked on the receive thread.
          """
          self.applications.append((tokens, callback))

      def send(self, protocol, packet_type, data = b''):
          """Send one binary packet and block until the device acknowledges it.

          The packet is retransmitted after each response timeout or resend
          request.  Raises ConnectionLost if it is not acknowledged within 20
          response timeouts.  FatalError and SycronisationError raised by the
          response handlers propagate to the caller.
          """
          self.packet_transit = self.build_packet(protocol, packet_type, data)
          self.packet_status = 0
          self.transmit_attempt = 0
          timeout = TimeOut(self.response_timeout * 20)
          while self.packet_status == 0:
              try:
                  if timeout.timedout():
                      raise ConnectionLost()
                  self.transmit_packet(self.packet_transit)
                  self.await_response()
              except ReadTimeout:
                  self.errors += 1
          self.packet_transit = None

      def await_response(self):
          """Wait for at least one response, then handle every queued one.

          Raises ReadTimeout if nothing arrives within response_timeout ms.
          """
          timeout = TimeOut(self.response_timeout)
          while not self.responses:
              time.sleep(0.00001)
              if timeout.timedout():
                  raise ReadTimeout()

          handlers = {
              'ok': self.response_ok,
              'rs': self.response_resend,
              'ss': self.response_stream_sync,
              'fe': self.response_fatal_error,
          }
          while self.responses:
              token, data = self.responses.popleft()
              handlers[token](data)

      def send_ascii(self, data, send_and_forget = False):
          """Send *data* as a newline-terminated text line.

          Unless *send_and_forget* is set, waits for the next queued response
          and retransmits on timeout.  Gives up silently after 20 response
          timeouts or on a serial error.
          """
          self.packet_transit = bytearray(data, "utf8") + b'\n'
          self.packet_status = 0
          self.transmit_attempt = 0
          timeout = TimeOut(self.response_timeout * 20)
          while self.packet_status == 0:
              try:
                  if timeout.timedout():
                      return
                  self.port.write(self.packet_transit)
                  if send_and_forget:
                      self.packet_status = 1
                  else:
                      self.await_response_ascii()
              except ReadTimeout:
                  self.errors += 1
              except serial.SerialException:
                  return
          self.packet_transit = None

      def await_response_ascii(self):
          """Wait for any queued response and mark the text line as acknowledged.

          Raises ReadTimeout if nothing arrives within response_timeout ms.
          """
          timeout = TimeOut(self.response_timeout)
          while not self.responses:
              time.sleep(0.00001)
              if timeout.timedout():
                  raise ReadTimeout()
          self.responses.popleft()
          self.packet_status = 1

      def corrupt_array(self, data):
          """Flip the bits (XOR 0xAA) of one random byte of *data* in place and return it."""
          rid = random.randint(0, len(data) - 1)
          data[rid] ^= 0xAA
          return data

      def transmit_packet(self, packet):
          """Write *packet* to the port, damaging it first with probability simulate_errors."""
          packet = bytearray(packet)
          if self.simulate_errors > 0 and random.random() > (1.0 - self.simulate_errors):
              if random.random() > 0.9:
                  # random data drop: start inside the packet so that at
                  # least one byte is really removed
                  start = random.randint(0, len(packet) - 1)
                  end = start + random.randint(1, 10)
                  packet = packet[:start] + packet[end:]
              else:
                  # random single-byte corruption
                  packet = self.corrupt_array(packet)
          self.port.write(packet)
          self.transmit_attempt += 1

      def build_packet(self, protocol, packet_type, data = b''):
          """Frame *data* as a binary packet.

          Layout: 16-bit start token 0xB5AD, 8-bit sync id, 4-bit protocol id
          and 4-bit packet type, 16-bit payload length, 16-bit header
          checksum, then (only for a non-empty payload) the payload and a
          16-bit checksum over everything after the start token.  Multi-byte
          fields are little-endian.

          Raises PayloadOverflow if *data* exceeds max_block_size.
          """
          PACKET_TOKEN = 0xB5AD
          if len(data) > self.max_block_size:
              raise PayloadOverflow()
          packet_buffer = bytearray()
          packet_buffer += self.pack_int8(self.sync)  # 8bit sync id
          packet_buffer += self.pack_int4_2(protocol, packet_type)  # 4 bit protocol id, 4 bit packet type
          packet_buffer += self.pack_int16(len(data))  # 16bit packet length
          packet_buffer += self.pack_int16(self.build_checksum(packet_buffer))  # 16bit header checksum
          if len(data):
              packet_buffer += data
              packet_buffer += self.pack_int16(self.build_checksum(packet_buffer))
          # 16bit start token, not included in checksum
          return self.pack_int16(PACKET_TOKEN) + packet_buffer

      # checksum 16 fletchers
      def checksum(self, cs, value):
          """Fold one byte *value* into the 16-bit Fletcher checksum *cs*."""
          cs_low = ((cs & 0xFF) + value) % 255
          return ((((cs >> 8) + cs_low) % 255) << 8) | cs_low

      def build_checksum(self, buffer):
          """Return the 16-bit Fletcher checksum of *buffer*."""
          cs = 0
          for b in buffer:
              cs = self.checksum(cs, b)
          return cs

      def pack_int32(self, value):
          return value.to_bytes(4, byteorder='little')

      def pack_int16(self, value):
          return value.to_bytes(2, byteorder='little')

      def pack_int8(self, value):
          return value.to_bytes(1, byteorder='little')

      def pack_int4_2(self, vh, vl):
          """Pack two 4-bit values into one byte, *vh* in the high nibble."""
          value = ((vh & 0xF) << 4) | (vl & 0xF)
          return value.to_bytes(1, byteorder='little')

      def connect(self):
          """Switch the firmware to binary mode (M28B1) and send a sync packet."""
          print("Connecting: Switching Marlin to Binary Protocol...")
          self.send_ascii("M28B1")
          self.send(0, 1)

      def disconnect(self):
          """Send the protocol-0 type-2 packet and mark the link unsynchronised."""
          self.send(0, 2)
          self.syncronised = False

      def response_ok(self, data):
          """Handle 'ok<id>': acknowledge the packet in transit and advance sync."""
          try:
              packet_id = int(data)
          except ValueError:
              return
          if packet_id != self.sync:
              raise SycronisationError()
          self.sync = (self.sync + 1) % 256
          self.packet_status = 1

      def response_resend(self, data):
          """Handle 'rs<id>': count the error; send() then retransmits.

          Raises SycronisationError if, once synchronised, the requested id
          is not the packet in transit.  A garbled id is treated as a plain
          resend request instead of crashing the transfer.
          """
          self.errors += 1
          if not self.syncronised:
              print("Retrying syncronisation")
              return
          try:
              packet_id = int(data)
          except ValueError:
              return
          if packet_id != self.sync:
              raise SycronisationError()

      def response_stream_sync(self, data):
          """Handle 'ss<sync>,<max block size>,<protocol version>'."""
          sync, max_block_size, protocol_version = data.split(',')
          self.sync = int(sync)
          self.max_block_size = int(max_block_size)
          self.block_size = min(self.block_size, self.max_block_size)
          self.protocol_version = protocol_version
          self.packet_status = 1
          self.syncronised = True
          print("Connection synced [{0}], binary protocol version {1}, {2} byte payload buffer".format(self.sync, self.protocol_version, self.max_block_size))

      def response_fatal_error(self, data):
          """Handle 'fe': raise FatalError carrying the device's message."""
          raise FatalError(data)
  class FileTransferProtocol(object):
      """File upload to the device's storage, layered on a Protocol link.

      Uses the QUERY/OPEN/WRITE/CLOSE/ABORT packets of protocol id 1 and
      optionally heatshrink-compresses the payload.
      """

      protocol_id = 1

      class Packet(object):
          # Packet type ids used with protocol_id 1.
          QUERY = 0
          OPEN = 1
          CLOSE = 2
          WRITE = 3
          ABORT = 4

      def __init__(self, protocol, timeout = None):
          """Register for 'PFT:' responses on *protocol*.

          Args:
              protocol: the Protocol instance used for transport.
              timeout: response timeout in ms; defaults to the protocol's.
          """
          # Per-instance queue, created before registering so that the
          # receive thread never dispatches into a missing attribute.
          self.responses = deque()
          # 'PFT:invalid' was previously registered as 'PTF:invalid', so the
          # device's "no open file" reply was never delivered.
          protocol.register(['PFT:success', 'PFT:version:', 'PFT:fail', 'PFT:busy', 'PFT:ioerror', 'PFT:invalid'], self.process_input)
          self.protocol = protocol
          self.response_timeout = timeout or protocol.response_timeout

      def process_input(self, data):
          """Queue a (token, payload) response for await_response()."""
          self.responses.append(data)

      def await_response(self, timeout = None):
          """Return the next (token, payload) response.

          Raises ReadTimeout if none arrives within *timeout* ms (default:
          self.response_timeout).
          """
          timeout = TimeOut(timeout or self.response_timeout)
          while not self.responses:
              time.sleep(0.0001)
              if timeout.timedout():
                  raise ReadTimeout()
          return self.responses.popleft()

      def connect(self):
          """Query the device's file-transfer version and compression support.

          Sets self.version and self.compression.  Returns True on success
          and False if the device answers with anything but a version.
          """
          self.protocol.send(FileTransferProtocol.protocol_id, FileTransferProtocol.Packet.QUERY)
          token, data = self.await_response()
          if token != 'PFT:version:':
              return False
          self.version, _, compression = data.split(':')
          if compression != 'none':
              algorithm, window, lookahead = compression.split(',')
              self.compression = {'algorithm': algorithm, 'window': int(window), 'lookahead': int(lookahead)}
          else:
              self.compression = {'algorithm': 'none'}
          print("File Transfer version: {0}, compression: {1}".format(self.version, self.compression['algorithm']))
          return True

      def open(self, filename, compression, dummy):
          """Open *filename* for writing on the device.

          On 'PFT:busy' the stale transfer is aborted and the open retried.
          Raises Exception if the device refuses to open the file, and
          ReadTimeout if it does not succeed within 5 s of the last attempt.
          """
          payload = b'\1' if dummy else b'\0'  # dummy transfer
          payload += b'\1' if compression else b'\0'  # payload compression
          payload += bytearray(filename, 'utf8') + b'\0'  # target filename + null terminator
          timeout = TimeOut(5000)
          token = None
          self.protocol.send(FileTransferProtocol.protocol_id, FileTransferProtocol.Packet.OPEN, payload)
          while token != 'PFT:success' and not timeout.timedout():
              try:
                  token, data = self.await_response(1000)
                  if token == 'PFT:success':
                      print(filename, "opened")
                      return
                  elif token == 'PFT:busy':
                      print("Broken transfer detected, purging")
                      self.abort()
                      time.sleep(0.1)
                      self.protocol.send(FileTransferProtocol.protocol_id, FileTransferProtocol.Packet.OPEN, payload)
                      timeout.reset()
                  elif token == 'PFT:fail':
                      raise Exception("Can not open file on client")
              except ReadTimeout:
                  pass
          raise ReadTimeout()

      def write(self, data):
          """Send one block of file data."""
          self.protocol.send(FileTransferProtocol.protocol_id, FileTransferProtocol.Packet.WRITE, data)

      def close(self):
          """Close the file on the device; return True only on 'PFT:success'."""
          self.protocol.send(FileTransferProtocol.protocol_id, FileTransferProtocol.Packet.CLOSE)
          token, data = self.await_response(1000)
          if token == 'PFT:success':
              print("File closed")
              return True
          if token == 'PFT:ioerror':
              print("Client storage device IO error")
          elif token == 'PFT:invalid':
              print("No open file")
          return False

      def abort(self):
          """Abort the current transfer on the device."""
          self.protocol.send(FileTransferProtocol.protocol_id, FileTransferProtocol.Packet.ABORT)
          token, data = self.await_response()
          if token == 'PFT:success':
              print("Transfer Aborted")

      def _status_line(self, percent, kibs, cratio, compression):
          """Format the carriage-return progress line shown during copy()."""
          effective = "[{0:4.2f}KiB/s]".format(kibs * cratio) if compression else ""
          return "\r{0:2.0f}% {1:4.2f}KiB/s {2} Errors: {3}".format(percent, kibs, effective, self.protocol.errors)

      def copy(self, filename, dest_filename, compression, dummy):
          """Upload the local file *filename* to *dest_filename* on the device.

          Args:
              filename: local file to send.
              dest_filename: name of the file to create on the device.
              compression: request heatshrink compression; disabled with a
                  message if the device or this host lacks support.
              dummy: passed to open() as the dummy-transfer flag.

          Returns:
              True on success.  False if the version query failed, closing
              the file failed, or protocol errors occurred during this call.
          """
          # Only errors raised during this call abort the transfer.  Earlier
          # ones (e.g. sync retries while connecting) would otherwise abort
          # every upload after its first block.
          errors_before = self.protocol.errors
          if not self.connect():
              print("File transfer query failed")
              return False

          client_heatshrink = self.compression['algorithm'] == 'heatshrink'
          if compression and not (heatshrink_exists and client_heatshrink):
              if client_heatshrink:
                  hs = '2' if sys.version_info[0] > 2 else ''
                  print("Compression not supported by host. Use 'pip install heatshrink%s' to fix." % hs)
              else:
                  print("Compression not supported by client.")
              compression = False

          with open(filename, "rb") as source:
              data = source.read()
          filesize = len(data)

          self.open(dest_filename, compression, dummy)

          block_size = self.protocol.block_size
          cratio = 1.0
          if compression:
              data = heatshrink.encode(data, window_sz2=self.compression['window'], lookahead_sz2=self.compression['lookahead'])
              if data:
                  cratio = filesize / len(data)

          blocks = (len(data) + block_size - 1) // block_size  # ceiling division
          kibs = 0
          dump_pctg = 0
          start_time = millis()
          for i in range(blocks):
              start = block_size * i
              end = start + block_size
              self.write(data[start:end])
              # The last block may be short; count only bytes actually sent.
              sent = min(end, len(data))
              kibs = (sent / 1024) / (millis() + 1 - start_time) * 1000
              if (i / blocks) >= dump_pctg:
                  print(self._status_line((i / blocks) * 100, kibs, cratio, compression), end='')
                  dump_pctg += 0.1
              if self.protocol.errors > errors_before:
                  # Dump last status (errors may not be visible)
                  print(self._status_line((i / blocks) * 100, kibs, cratio, compression) + " - Aborting...", end='')
                  print("")  # New line to break the transfer speed line
                  self.close()
                  print("Transfer aborted due to protocol errors")
                  return False
          # no one likes transfers finishing at 99.8%
          print(self._status_line(100, kibs, cratio, compression))
          if not self.close():
              print("Transfer failed")
              return False
          print("Transfer complete")
          return True
  class EchoProtocol(object):
      """Print every 'echo:' line received from the device.

      Registers for the 'echo:' token on *protocol*; each match is printed
      as the (token, payload) tuple handed over by the dispatcher.
      """

      def __init__(self, protocol):
          self.protocol = protocol
          protocol.register(['echo:'], self.process_input)

      def process_input(self, data):
          print(data)