stk500v2.py 8.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222
  1. """
  2. STK500v2 protocol implementation for programming AVR chips.
  3. The STK500v2 protocol is used by the ArduinoMega2560 and a few other Arduino platforms to load firmware.
  4. This is a python 3 conversion of the code created by David Braam for the Cura project.
  5. """
  6. import struct
  7. import sys
  8. import time
  9. from serial import Serial # type: ignore
  10. from serial import SerialException
  11. from serial import SerialTimeoutException
  12. from UM.Logger import Logger
  13. from . import ispBase, intelHex
  14. class Stk500v2(ispBase.IspBase):
  15. def __init__(self):
  16. self.serial = None
  17. self.seq = 1
  18. self.last_addr = -1
  19. self.progress_callback = None
  20. def connect(self, port = "COM22", speed = 115200):
  21. if self.serial is not None:
  22. self.close()
  23. try:
  24. self.serial = Serial(str(port), speed, timeout=1, writeTimeout=10000)
  25. except SerialException:
  26. raise ispBase.IspError("Failed to open serial port")
  27. except:
  28. raise ispBase.IspError("Unexpected error while connecting to serial port:" + port + ":" + str(sys.exc_info()[0]))
  29. self.seq = 1
  30. #Reset the controller
  31. for n in range(0, 2):
  32. self.serial.setDTR(True)
  33. time.sleep(0.1)
  34. self.serial.setDTR(False)
  35. time.sleep(0.1)
  36. time.sleep(0.2)
  37. self.serial.flushInput()
  38. self.serial.flushOutput()
  39. try:
  40. if self.sendMessage([0x10, 0xc8, 0x64, 0x19, 0x20, 0x00, 0x53, 0x03, 0xac, 0x53, 0x00, 0x00]) != [0x10, 0x00]:
  41. raise ispBase.IspError("Failed to enter programming mode")
  42. self.sendMessage([0x06, 0x80, 0x00, 0x00, 0x00])
  43. if self.sendMessage([0xEE])[1] == 0x00:
  44. self._has_checksum = True
  45. else:
  46. self._has_checksum = False
  47. except ispBase.IspError:
  48. self.close()
  49. raise
  50. self.serial.timeout = 5
  51. def close(self):
  52. if self.serial is not None:
  53. self.serial.close()
  54. self.serial = None
  55. #Leave ISP does not reset the serial port, only resets the device, and returns the serial port after disconnecting it from the programming interface.
  56. # This allows you to use the serial port without opening it again.
  57. def leaveISP(self):
  58. if self.serial is not None:
  59. if self.sendMessage([0x11]) != [0x11, 0x00]:
  60. raise ispBase.IspError("Failed to leave programming mode")
  61. ret = self.serial
  62. self.serial = None
  63. return ret
  64. return None
  65. def isConnected(self):
  66. return self.serial is not None
  67. def hasChecksumFunction(self):
  68. return self._has_checksum
  69. def sendISP(self, data):
  70. recv = self.sendMessage([0x1D, 4, 4, 0, data[0], data[1], data[2], data[3]])
  71. return recv[2:6]
  72. def writeFlash(self, flash_data):
  73. #Set load addr to 0, in case we have more then 64k flash we need to enable the address extension
  74. page_size = self.chip["pageSize"] * 2
  75. flash_size = page_size * self.chip["pageCount"]
  76. Logger.log("d", "Writing flash")
  77. if flash_size > 0xFFFF:
  78. self.sendMessage([0x06, 0x80, 0x00, 0x00, 0x00])
  79. else:
  80. self.sendMessage([0x06, 0x00, 0x00, 0x00, 0x00])
  81. load_count = (len(flash_data) + page_size - 1) / page_size
  82. for i in range(0, int(load_count)):
  83. self.sendMessage([0x13, page_size >> 8, page_size & 0xFF, 0xc1, 0x0a, 0x40, 0x4c, 0x20, 0x00, 0x00] + flash_data[(i * page_size):(i * page_size + page_size)])
  84. if self.progress_callback is not None:
  85. if self._has_checksum:
  86. self.progress_callback(i + 1, load_count)
  87. else:
  88. self.progress_callback(i + 1, load_count * 2)
  89. def verifyFlash(self, flash_data):
  90. if self._has_checksum:
  91. self.sendMessage([0x06, 0x00, (len(flash_data) >> 17) & 0xFF, (len(flash_data) >> 9) & 0xFF, (len(flash_data) >> 1) & 0xFF])
  92. res = self.sendMessage([0xEE])
  93. checksum_recv = res[2] | (res[3] << 8)
  94. checksum = 0
  95. for d in flash_data:
  96. checksum += d
  97. checksum &= 0xFFFF
  98. if hex(checksum) != hex(checksum_recv):
  99. raise ispBase.IspError("Verify checksum mismatch: 0x%x != 0x%x" % (checksum & 0xFFFF, checksum_recv))
  100. else:
  101. #Set load addr to 0, in case we have more then 64k flash we need to enable the address extension
  102. flash_size = self.chip["pageSize"] * 2 * self.chip["pageCount"]
  103. if flash_size > 0xFFFF:
  104. self.sendMessage([0x06, 0x80, 0x00, 0x00, 0x00])
  105. else:
  106. self.sendMessage([0x06, 0x00, 0x00, 0x00, 0x00])
  107. load_count = (len(flash_data) + 0xFF) / 0x100
  108. for i in range(0, int(load_count)):
  109. recv = self.sendMessage([0x14, 0x01, 0x00, 0x20])[2:0x102]
  110. if self.progress_callback is not None:
  111. self.progress_callback(load_count + i + 1, load_count * 2)
  112. for j in range(0, 0x100):
  113. if i * 0x100 + j < len(flash_data) and flash_data[i * 0x100 + j] != recv[j]:
  114. raise ispBase.IspError("Verify error at: 0x%x" % (i * 0x100 + j))
  115. def sendMessage(self, data):
  116. message = struct.pack(">BBHB", 0x1B, self.seq, len(data), 0x0E)
  117. for c in data:
  118. message += struct.pack(">B", c)
  119. checksum = 0
  120. for c in message:
  121. checksum ^= c
  122. message += struct.pack(">B", checksum)
  123. try:
  124. self.serial.write(message)
  125. self.serial.flush()
  126. except SerialTimeoutException:
  127. raise ispBase.IspError("Serial send timeout")
  128. self.seq = (self.seq + 1) & 0xFF
  129. return self.recvMessage()
  130. def recvMessage(self):
  131. state = "Start"
  132. checksum = 0
  133. while True:
  134. s = self.serial.read()
  135. if len(s) < 1:
  136. raise ispBase.IspError("Timeout")
  137. b = struct.unpack(">B", s)[0]
  138. checksum ^= b
  139. if state == "Start":
  140. if b == 0x1B:
  141. state = "GetSeq"
  142. checksum = 0x1B
  143. elif state == "GetSeq":
  144. state = "MsgSize1"
  145. elif state == "MsgSize1":
  146. msg_size = b << 8
  147. state = "MsgSize2"
  148. elif state == "MsgSize2":
  149. msg_size |= b
  150. state = "Token"
  151. elif state == "Token":
  152. if b != 0x0E:
  153. state = "Start"
  154. else:
  155. state = "Data"
  156. data = []
  157. elif state == "Data":
  158. data.append(b)
  159. if len(data) == msg_size:
  160. state = "Checksum"
  161. elif state == "Checksum":
  162. if checksum != 0:
  163. state = "Start"
  164. else:
  165. return data
  166. def portList():
  167. ret = []
  168. import _winreg # type: ignore
  169. key=_winreg.OpenKey(_winreg.HKEY_LOCAL_MACHINE,"HARDWARE\\DEVICEMAP\\SERIALCOMM") #@UndefinedVariable
  170. i=0
  171. while True:
  172. try:
  173. values = _winreg.EnumValue(key, i) #@UndefinedVariable
  174. except:
  175. return ret
  176. if "USBSER" in values[0]:
  177. ret.append(values[1])
  178. i+=1
  179. return ret
  180. def runProgrammer(port, filename):
  181. """ Run an STK500v2 program on serial port 'port' and write 'filename' into flash. """
  182. programmer = Stk500v2()
  183. programmer.connect(port = port)
  184. programmer.programChip(intelHex.readHex(filename))
  185. programmer.close()
  186. def main():
  187. """ Entry point to call the stk500v2 programmer from the commandline. """
  188. import threading
  189. if sys.argv[1] == "AUTO":
  190. Logger.log("d", "portList(): ", repr(portList()))
  191. for port in portList():
  192. threading.Thread(target=runProgrammer, args=(port,sys.argv[2])).start()
  193. time.sleep(5)
  194. else:
  195. programmer = Stk500v2()
  196. programmer.connect(port = sys.argv[1])
  197. programmer.programChip(intelHex.readHex(sys.argv[2]))
  198. sys.exit(1)
  199. if __name__ == "__main__":
  200. main()