MarlinBinaryProtocol.py 15 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447
  1. #
  2. # MarlinBinaryProtocol.py
  3. # Supporting Firmware upload via USB/Serial, saving to the attached media.
  4. #
  5. import serial
  6. import math
  7. import time
  8. from collections import deque
  9. import threading
  10. import sys
  11. import datetime
  12. import random
  13. try:
  14. import heatshrink
  15. heatshrink_exists = True
  16. except ImportError:
  17. heatshrink_exists = False
  def millis():
      """Return the current high-resolution timestamp in milliseconds (float).

      Built on time.perf_counter(), so only differences between two calls are
      meaningful; the absolute value has no fixed epoch.
      """
      seconds = time.perf_counter()
      return seconds * 1000


  class TimeOut(object):
      """Deadline that expires `duration` milliseconds after the last reset()."""

      def __init__(self, milliseconds):
          self.duration = milliseconds
          self.reset()

      def reset(self):
          """Restart the countdown from the current moment."""
          now = millis()
          self.endtime = now + self.duration

      def timedout(self):
          """Return True once the current time is strictly past the deadline."""
          deadline = self.endtime
          return deadline < millis()
  class ReadTimeout(Exception):
      """No response was queued before the wait deadline expired."""


  class FatalError(Exception):
      """The client answered with a fatal-error ('fe') reply."""


  class SycronisationError(Exception):
      """An 'ok'/'rs' reply carried a packet id different from the expected sync id."""


  class PayloadOverflow(Exception):
      """A packet payload is larger than the client's maximum block size."""


  class ConnectionLost(Exception):
      """The link could not be (re)established or stopped acknowledging packets."""
  class Protocol(object):
      """Marlin binary-protocol transport over a serial port.

      A background thread (receive_worker) reads newline-terminated text lines
      from the port and hands each one to the first callback registered via
      register() whose prefix token matches. The foreground side frames binary
      packets (build_packet), transmits them and waits for the 'ok', 'rs',
      'ss' and 'fe' replies that the reader thread queues in self.responses.
      """
      device = None
      baud = None
      max_block_size = 0
      port = None
      block_size = 0
      packet_transit = None
      packet_status = None
      packet_ping = None
      errors = 0
      packet_buffer = None
      simulate_errors = 0
      sync = 0
      connected = False
      syncronised = False
      worker_thread = None
      response_timeout = 1000  # milliseconds (passed to TimeOut)
      # `applications` and `responses` are created per instance in __init__.
      # As mutable class attributes they were shared by every Protocol instance.

      def __init__(self, device, baud, bsize, simerr, timeout):
          """Open `device` and start the background reader thread.

          :param device: serial port name, passed to serial.Serial
          :param baud: baud rate
          :param bsize: requested payload size in bytes; lowered to the client's
              limit when its 'ss' stream-sync reply arrives
          :param simerr: probability (clamped to 0.0..1.0) of deliberately
              damaging an outgoing binary packet; 0 disables error simulation
          :param timeout: per-response timeout in milliseconds
          """
          print("pySerial Version:", serial.VERSION)
          # Must exist before register() and before the reader thread runs.
          self.applications = []
          self.responses = deque()
          self.port = serial.Serial(device, baudrate=baud, write_timeout=0, timeout=1)
          self.device = device
          self.baud = baud
          self.block_size = int(bsize)
          self.simulate_errors = max(min(simerr, 1.0), 0.0)
          self.connected = True
          self.response_timeout = timeout
          self.register(['ok', 'rs', 'ss', 'fe'], self.process_input)
          self.worker_thread = threading.Thread(target=Protocol.receive_worker, args=(self,))
          self.worker_thread.start()

      def receive_worker(self):
          """Reader-thread loop: read lines and dispatch them until shutdown().

          :raises ConnectionLost: (in this thread) if the port cannot be reopened
          """
          # Discard anything buffered before the session started.
          while self.port.in_waiting:
              self.port.reset_input_buffer()

          def dispatch(data):
              # The first registered (tokens, callback) pair with a matching prefix
              # wins; the callback receives (token, remainder_of_line).
              for tokens, callback in self.applications:
                  for token in tokens:
                      if data.startswith(token):
                          callback((token, data[len(token):]))
                          return

          def reconnect():
              print("Reconnecting..")
              self.port.close()
              for _ in range(10):
                  try:
                      if self.connected:
                          self.port = serial.Serial(self.device, baudrate=self.baud, write_timeout=0, timeout=1)
                      else:
                          print("Connection closed")
                      return
                  except (OSError, serial.serialutil.SerialException):
                      # Device not available yet (e.g. re-enumerating); retry shortly.
                      # Narrowed from a bare `except:` so Ctrl-C is not swallowed.
                      time.sleep(1)
              raise ConnectionLost()

          while self.connected:
              try:
                  data = self.port.readline().decode('utf8').rstrip()
                  if data:
                      dispatch(data)
              except OSError:
                  reconnect()
              except UnicodeDecodeError:
                  # dodgy client output or datastream corruption
                  self.port.reset_input_buffer()

      def shutdown(self):
          """Stop the reader thread, wait for it to exit, then close the port."""
          self.connected = False
          self.worker_thread.join()
          self.port.close()

      def process_input(self, data):
          """Queue a (token, remainder) reply for await_response*()."""
          self.responses.append(data)

      def register(self, tokens, callback):
          """Route incoming lines starting with any of `tokens` to `callback`."""
          self.applications.append((tokens, callback))

      def send(self, protocol, packet_type, data=b""):
          """Send one binary packet and block until it is acknowledged.

          The packet is retransmitted after every ReadTimeout and after every
          'rs' (resend) reply, because only response_ok / response_stream_sync
          set packet_status.

          :raises ConnectionLost: if not acknowledged within 20 x response_timeout ms
          :raises PayloadOverflow: if `data` is larger than max_block_size
          :raises SycronisationError, FatalError: propagated from the reply handlers
          """
          self.packet_transit = self.build_packet(protocol, packet_type, data)
          self.packet_status = 0
          self.transmit_attempt = 0
          timeout = TimeOut(self.response_timeout * 20)
          while self.packet_status == 0:
              try:
                  if timeout.timedout():
                      raise ConnectionLost()
                  self.transmit_packet(self.packet_transit)
                  self.await_response()
              except ReadTimeout:
                  self.errors += 1
          self.packet_transit = None

      def await_response(self):
          """Wait for at least one reply, then run the handler for every queued reply.

          :raises ReadTimeout: if nothing arrives within response_timeout ms
          """
          timeout = TimeOut(self.response_timeout)
          while not self.responses:
              time.sleep(0.00001)
              if timeout.timedout():
                  raise ReadTimeout()
          handlers = {
              'ok': self.response_ok,
              'rs': self.response_resend,
              'ss': self.response_stream_sync,
              'fe': self.response_fatal_error,
          }
          while self.responses:
              token, data = self.responses.popleft()
              handlers[token](data)

      def send_ascii(self, data, send_and_forget=False):
          """Send `data` as one newline-terminated text line.

          Unless `send_and_forget`, waits for any single reply and retransmits
          after each ReadTimeout. Gives up silently (returns) after
          20 x response_timeout ms or on a serial error.
          """
          self.packet_transit = bytearray(data, "utf8") + b'\n'
          self.packet_status = 0
          self.transmit_attempt = 0
          timeout = TimeOut(self.response_timeout * 20)
          while self.packet_status == 0:
              try:
                  if timeout.timedout():
                      return
                  self.port.write(self.packet_transit)
                  if send_and_forget:
                      self.packet_status = 1
                  else:
                      self.await_response_ascii()
              except ReadTimeout:
                  self.errors += 1
              except serial.serialutil.SerialException:
                  return
          self.packet_transit = None

      def await_response_ascii(self):
          """Consume one queued reply of any kind and mark the line as acknowledged.

          :raises ReadTimeout: if nothing arrives within response_timeout ms
          """
          timeout = TimeOut(self.response_timeout)
          while not self.responses:
              time.sleep(0.00001)
              if timeout.timedout():
                  raise ReadTimeout()
          self.responses.popleft()  # content is not inspected; any reply counts
          self.packet_status = 1

      def corrupt_array(self, data):
          """Flip bits (XOR 0xAA) of one random byte of `data` in place and return it."""
          rid = random.randint(0, len(data) - 1)
          data[rid] ^= 0xAA
          return data

      def transmit_packet(self, packet):
          """Write a copy of `packet` to the port, possibly damaged for testing.

          With probability simulate_errors the copy either loses up to 10 bytes
          at a random offset (about 10% of those cases) or has one byte
          corrupted, to exercise the client's resend path.
          """
          packet = bytearray(packet)
          if self.simulate_errors > 0 and random.random() > (1.0 - self.simulate_errors):
              if random.random() > 0.9:
                  # random data drop
                  start = random.randint(0, len(packet))
                  end = start + random.randint(1, 10)
                  packet = packet[:start] + packet[end:]
              else:
                  # random single-byte corruption
                  packet = self.corrupt_array(packet)
          self.port.write(packet)
          self.transmit_attempt += 1

      def build_packet(self, protocol, packet_type, data=b""):
          """Frame `data` as a binary packet (all multi-byte fields little-endian).

          Layout: start token 0xB5AD (2 bytes, not checksummed) | sync id (1) |
          protocol id << 4 | packet type (1) | payload length (2) |
          header checksum (2) | payload | checksum of everything after the
          start token (2, only present when the payload is non-empty).

          :raises PayloadOverflow: if len(data) > max_block_size
          """
          PACKET_TOKEN = 0xB5AD
          if len(data) > self.max_block_size:
              raise PayloadOverflow()
          packet_buffer = bytearray()
          packet_buffer += self.pack_int8(self.sync)  # 8bit sync id
          packet_buffer += self.pack_int4_2(protocol, packet_type)  # 4 bit protocol id, 4 bit packet type
          packet_buffer += self.pack_int16(len(data))  # 16bit packet length
          packet_buffer += self.pack_int16(self.build_checksum(packet_buffer))  # 16bit header checksum
          if len(data):
              packet_buffer += data
              packet_buffer += self.pack_int16(self.build_checksum(packet_buffer))
          packet_buffer = self.pack_int16(PACKET_TOKEN) + packet_buffer  # 16bit start token, not included in checksum
          return packet_buffer

      # checksum 16 fletchers
      def checksum(self, cs, value):
          """Fold byte `value` into running Fletcher-16 checksum `cs` (high byte = sum of sums)."""
          cs_low = ((cs & 0xFF) + value) % 255
          return ((((cs >> 8) + cs_low) % 255) << 8) | cs_low

      def build_checksum(self, buffer):
          """Return the Fletcher-16 checksum of every byte in `buffer` (0 when empty)."""
          cs = 0
          for b in buffer:
              cs = self.checksum(cs, b)
          return cs

      def pack_int32(self, value):
          return value.to_bytes(4, byteorder='little')

      def pack_int16(self, value):
          return value.to_bytes(2, byteorder='little')

      def pack_int8(self, value):
          return value.to_bytes(1, byteorder='little')

      def pack_int4_2(self, vh, vl):
          """Pack two 4-bit values into one byte: `vh` in the high nibble, `vl` in the low."""
          value = ((vh & 0xF) << 4) | (vl & 0xF)
          return value.to_bytes(1, byteorder='little')

      def connect(self):
          """Switch the client to the binary protocol (ASCII 'M28B1'), then send control packet (0, 1)."""
          print("Connecting: Switching Marlin to Binary Protocol...")
          self.send_ascii("M28B1")
          self.send(0, 1)

      def disconnect(self):
          """Send control packet (0, 2) and mark the stream as unsynchronised."""
          self.send(0, 2)
          self.syncronised = False

      def response_ok(self, data):
          """Handle 'ok<id>': acknowledge the in-flight packet and advance the sync id.

          Lines whose remainder is not an integer are ignored.
          :raises SycronisationError: if <id> differs from the expected sync id
          """
          try:
              packet_id = int(data)
          except ValueError:
              return
          if packet_id != self.sync:
              raise SycronisationError()
          self.sync = (self.sync + 1) % 256
          self.packet_status = 1

      def response_resend(self, data):
          """Handle 'rs<id>': count an error; send() then retransmits the packet.

          :raises SycronisationError: if synchronised and <id> differs from the sync id
          """
          packet_id = int(data)
          self.errors += 1
          if not self.syncronised:
              print("Retrying syncronisation")
          elif packet_id != self.sync:
              raise SycronisationError()

      def response_stream_sync(self, data):
          """Handle 'ss<sync>,<max_block_size>,<version>': adopt the client's stream state."""
          sync, max_block_size, protocol_version = data.split(',')
          self.sync = int(sync)
          self.max_block_size = int(max_block_size)
          self.block_size = min(self.max_block_size, self.block_size)
          self.protocol_version = protocol_version
          self.packet_status = 1
          self.syncronised = True
          print("Connection synced [{0}], binary protocol version {1}, {2} byte payload buffer".format(self.sync, self.protocol_version, self.max_block_size))

      def response_fatal_error(self, data):
          """Handle 'fe': the client reported an unrecoverable error.

          :raises FatalError: always
          """
          raise FatalError()
  class FileTransferProtocol(object):
      """File upload ("PFT") application running on top of a Protocol.

      Packets are sent with protocol id 1. The client's text replies
      ('PFT:...') are routed to process_input() and queued per instance.
      """
      protocol_id = 1

      class Packet(object):
          """Packet types sent under FileTransferProtocol.protocol_id."""
          QUERY = 0
          OPEN = 1
          CLOSE = 2
          WRITE = 3
          ABORT = 4

      # `responses` is created per instance in __init__. As a class attribute it
      # was shared by every FileTransferProtocol instance.

      def __init__(self, protocol, timeout=None):
          """Register the PFT reply tokens with `protocol`.

          :param protocol: the Protocol transport used for sending
          :param timeout: reply timeout in milliseconds; falls back to
              protocol.response_timeout when None or 0
          """
          self.responses = deque()
          # 'PFT:invalid' must match the token close() checks for; it was
          # previously registered misspelt as 'PTF:invalid' and never routed here.
          protocol.register(['PFT:success', 'PFT:version:', 'PFT:fail', 'PFT:busy', 'PFT:ioerror', 'PFT:invalid'], self.process_input)
          self.protocol = protocol
          self.response_timeout = timeout or protocol.response_timeout

      def process_input(self, data):
          """Queue a (token, remainder) reply for await_response()."""
          self.responses.append(data)

      def await_response(self, timeout=None):
          """Return the oldest queued (token, remainder) reply.

          :param timeout: milliseconds to wait; defaults to self.response_timeout
          :raises ReadTimeout: if no reply arrives in time
          """
          timeout = TimeOut(timeout or self.response_timeout)
          while not self.responses:
              time.sleep(0.0001)
              if timeout.timedout():
                  raise ReadTimeout()
          return self.responses.popleft()

      def connect(self):
          """Query the client's file-transfer version and compression support.

          Sets self.version and self.compression ({'algorithm': ...} plus
          'window'/'lookahead' when compression is offered).
          :returns: True on success, False if the reply was not a version report
          :raises ReadTimeout: if the client does not reply in time
          """
          self.protocol.send(FileTransferProtocol.protocol_id, FileTransferProtocol.Packet.QUERY)
          token, data = self.await_response()
          if token != 'PFT:version:':
              return False
          self.version, _, compression = data.split(':')
          if compression != 'none':
              algorithm, window, lookahead = compression.split(',')
              self.compression = {'algorithm': algorithm, 'window': int(window), 'lookahead': int(lookahead)}
          else:
              self.compression = {'algorithm': 'none'}
          print("File Transfer version: {0}, compression: {1}".format(self.version, self.compression['algorithm']))
          return True

      def open(self, filename, compression, dummy):
          """Ask the client to open `filename` for writing.

          On 'PFT:busy' the stale transfer is aborted and the open is retried.
          :param compression: truthy if the payload will be heatshrink-compressed
          :param dummy: truthy for a dummy transfer
          :raises Exception: if the client replies 'PFT:fail'
          :raises ReadTimeout: if no success reply arrives within 5 s of the last (re)send
          """
          payload = b'\1' if dummy else b'\0'  # dummy transfer
          payload += b'\1' if compression else b'\0'  # payload compression
          payload += bytearray(filename, 'utf8') + b'\0'  # target filename + null terminator
          timeout = TimeOut(5000)
          token = None
          self.protocol.send(FileTransferProtocol.protocol_id, FileTransferProtocol.Packet.OPEN, payload)
          while token != 'PFT:success' and not timeout.timedout():
              try:
                  token, _ = self.await_response(1000)
                  if token == 'PFT:success':
                      print(filename, "opened")
                      return
                  elif token == 'PFT:busy':
                      print("Broken transfer detected, purging")
                      self.abort()
                      time.sleep(0.1)
                      self.protocol.send(FileTransferProtocol.protocol_id, FileTransferProtocol.Packet.OPEN, payload)
                      timeout.reset()
                  elif token == 'PFT:fail':
                      raise Exception("Can not open file on client")
              except ReadTimeout:
                  pass
          raise ReadTimeout()

      def write(self, data):
          """Send one payload block of the open file."""
          self.protocol.send(FileTransferProtocol.protocol_id, FileTransferProtocol.Packet.WRITE, data)

      def close(self):
          """Close the remote file.

          :returns: True on 'PFT:success', False on any other reply
          :raises ReadTimeout: if no reply arrives within 1 s
          """
          self.protocol.send(FileTransferProtocol.protocol_id, FileTransferProtocol.Packet.CLOSE)
          token, _ = self.await_response(1000)
          if token == 'PFT:success':
              print("File closed")
              return True
          elif token == 'PFT:ioerror':
              print("Client storage device IO error")
              return False
          elif token == 'PFT:invalid':
              print("No open file")
              return False
          return False

      def abort(self):
          """Abort the current remote transfer."""
          self.protocol.send(FileTransferProtocol.protocol_id, FileTransferProtocol.Packet.ABORT)
          token, _ = self.await_response()
          if token == 'PFT:success':
              print("Transfer Aborted")

      def copy(self, filename, dest_filename, compression, dummy):
          """Upload local `filename` to `dest_filename` on the client.

          :param compression: request heatshrink compression; used only when the
              heatshrink module is importable and the client reports heatshrink
          :param dummy: forwarded to open() as the dummy-transfer flag
          :returns: True when the client confirms the close, otherwise False
          """
          if not self.connect():
              print("File transfer version query failed")
              return False
          client_heatshrink = heatshrink_exists and self.compression['algorithm'] == 'heatshrink'
          compression_support = bool(compression and client_heatshrink)
          if compression and not client_heatshrink:
              print("Compression not supported by client")

          with open(filename, "rb") as source:  # `with` closes the handle
              data = source.read()
          filesize = len(data)
          self.open(dest_filename, compression_support, dummy)

          block_size = self.protocol.block_size
          cratio = 1.0  # original/compressed size, for the effective-speed display
          if compression_support:
              data = heatshrink.encode(data, window_sz2=self.compression['window'], lookahead_sz2=self.compression['lookahead'])
              if data:  # an empty file compresses to nothing; avoid dividing by zero
                  cratio = filesize / len(data)
          blocks = (len(data) + block_size - 1) // block_size  # integer ceiling division

          def status(percent, kibs):
              # One progress line; the leading "\r" redraws it in place.
              effective = "[{0:4.2f}KiB/s]".format(kibs * cratio) if compression_support else ""
              return "\r{0:2.0f}% {1:4.2f}KiB/s {2} Errors: {3}".format(percent, kibs, effective, self.protocol.errors)

          kibs = 0
          dump_pctg = 0
          start_time = millis()
          for i in range(blocks):
              start = block_size * i
              self.write(data[start:start + block_size])
              kibs = (((i + 1) * block_size) / 1024) / (millis() + 1 - start_time) * 1000
              if (i / blocks) >= dump_pctg:
                  print(status((i / blocks) * 100, kibs), end='')
                  dump_pctg += 0.1
              # NOTE(review): protocol.errors is cumulative for the whole session
              # (retries during sync also increment it), so any earlier error makes
              # this abort on the first block. Confirm that is intended.
              if self.protocol.errors > 0:
                  # Dump last status (errors may not be visible)
                  print(status((i / blocks) * 100, kibs) + " - Aborting...", end='')
                  print("")  # New line to break the transfer speed line
                  self.close()
                  print("Transfer aborted due to protocol errors")
                  return False
          print(status(100, kibs))  # no one likes transfers finishing at 99.8%
          if not self.close():
              print("Transfer failed")
              return False
          print("Transfer complete")
          return True
  class EchoProtocol(object):
      """Print every client line that starts with the 'echo:' token."""

      def __init__(self, protocol):
          self.protocol = protocol
          echo_tokens = ['echo:']
          protocol.register(echo_tokens, self.process_input)

      def process_input(self, data):
          """Print the (token, remainder) tuple handed over by the dispatcher."""
          print(data)