stk500v2.py 8.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222
  1. """
  2. STK500v2 protocol implementation for programming AVR chips.
  3. The STK500v2 protocol is used by the ArduinoMega2560 and a few other Arduino platforms to load firmware.
  4. This is a python 3 conversion of the code created by David Braam for the Cura project.
  5. """
  6. import struct
  7. import sys
  8. import time
  9. from serial import Serial
  10. from serial import SerialException
  11. from serial import SerialTimeoutException
  12. from UM.Logger import Logger
  13. from . import ispBase, intelHex
  14. class Stk500v2(ispBase.IspBase):
  15. def __init__(self):
  16. self.serial = None
  17. self.seq = 1
  18. self.last_addr = -1
  19. self.progress_callback = None
  20. def connect(self, port = "COM22", speed = 115200):
  21. if self.serial is not None:
  22. self.close()
  23. try:
  24. self.serial = Serial(str(port), speed, timeout=1, writeTimeout=10000)
  25. except SerialException:
  26. raise ispBase.IspError("Failed to open serial port")
  27. except:
  28. raise ispBase.IspError("Unexpected error while connecting to serial port:" + port + ":" + str(sys.exc_info()[0]))
  29. self.seq = 1
  30. #Reset the controller
  31. for n in range(0, 2):
  32. self.serial.setDTR(True)
  33. time.sleep(0.1)
  34. self.serial.setDTR(False)
  35. time.sleep(0.1)
  36. time.sleep(0.2)
  37. self.serial.flushInput()
  38. self.serial.flushOutput()
  39. try:
  40. if self.sendMessage([0x10, 0xc8, 0x64, 0x19, 0x20, 0x00, 0x53, 0x03, 0xac, 0x53, 0x00, 0x00]) != [0x10, 0x00]:
  41. raise ispBase.IspError("Failed to enter programming mode")
  42. self.sendMessage([0x06, 0x80, 0x00, 0x00, 0x00])
  43. if self.sendMessage([0xEE])[1] == 0x00:
  44. self._has_checksum = True
  45. else:
  46. self._has_checksum = False
  47. except ispBase.IspError:
  48. self.close()
  49. raise
  50. self.serial.timeout = 5
  51. def close(self):
  52. if self.serial is not None:
  53. self.serial.close()
  54. self.serial = None
  55. #Leave ISP does not reset the serial port, only resets the device, and returns the serial port after disconnecting it from the programming interface.
  56. # This allows you to use the serial port without opening it again.
  57. def leaveISP(self):
  58. if self.serial is not None:
  59. if self.sendMessage([0x11]) != [0x11, 0x00]:
  60. raise ispBase.IspError("Failed to leave programming mode")
  61. ret = self.serial
  62. self.serial = None
  63. return ret
  64. return None
  65. def isConnected(self):
  66. return self.serial is not None
  67. def hasChecksumFunction(self):
  68. return self._has_checksum
  69. def sendISP(self, data):
  70. recv = self.sendMessage([0x1D, 4, 4, 0, data[0], data[1], data[2], data[3]])
  71. return recv[2:6]
  72. def writeFlash(self, flash_data):
  73. #Set load addr to 0, in case we have more then 64k flash we need to enable the address extension
  74. page_size = self.chip["pageSize"] * 2
  75. flash_size = page_size * self.chip["pageCount"]
  76. Logger.log("d", "Writing flash")
  77. if flash_size > 0xFFFF:
  78. self.sendMessage([0x06, 0x80, 0x00, 0x00, 0x00])
  79. else:
  80. self.sendMessage([0x06, 0x00, 0x00, 0x00, 0x00])
  81. load_count = (len(flash_data) + page_size - 1) / page_size
  82. for i in range(0, int(load_count)):
  83. self.sendMessage([0x13, page_size >> 8, page_size & 0xFF, 0xc1, 0x0a, 0x40, 0x4c, 0x20, 0x00, 0x00] + flash_data[(i * page_size):(i * page_size + page_size)])
  84. if self.progress_callback is not None:
  85. if self._has_checksum:
  86. self.progress_callback(i + 1, load_count)
  87. else:
  88. self.progress_callback(i + 1, load_count * 2)
  89. def verifyFlash(self, flash_data):
  90. if self._has_checksum:
  91. self.sendMessage([0x06, 0x00, (len(flash_data) >> 17) & 0xFF, (len(flash_data) >> 9) & 0xFF, (len(flash_data) >> 1) & 0xFF])
  92. res = self.sendMessage([0xEE])
  93. checksum_recv = res[2] | (res[3] << 8)
  94. checksum = 0
  95. for d in flash_data:
  96. checksum += d
  97. checksum &= 0xFFFF
  98. if hex(checksum) != hex(checksum_recv):
  99. raise ispBase.IspError("Verify checksum mismatch: 0x%x != 0x%x" % (checksum & 0xFFFF, checksum_recv))
  100. else:
  101. #Set load addr to 0, in case we have more then 64k flash we need to enable the address extension
  102. flash_size = self.chip["pageSize"] * 2 * self.chip["pageCount"]
  103. if flash_size > 0xFFFF:
  104. self.sendMessage([0x06, 0x80, 0x00, 0x00, 0x00])
  105. else:
  106. self.sendMessage([0x06, 0x00, 0x00, 0x00, 0x00])
  107. load_count = (len(flash_data) + 0xFF) / 0x100
  108. for i in range(0, int(load_count)):
  109. recv = self.sendMessage([0x14, 0x01, 0x00, 0x20])[2:0x102]
  110. if self.progress_callback is not None:
  111. self.progress_callback(load_count + i + 1, load_count * 2)
  112. for j in range(0, 0x100):
  113. if i * 0x100 + j < len(flash_data) and flash_data[i * 0x100 + j] != recv[j]:
  114. raise ispBase.IspError("Verify error at: 0x%x" % (i * 0x100 + j))
  115. def sendMessage(self, data):
  116. message = struct.pack(">BBHB", 0x1B, self.seq, len(data), 0x0E)
  117. for c in data:
  118. message += struct.pack(">B", c)
  119. checksum = 0
  120. for c in message:
  121. checksum ^= c
  122. message += struct.pack(">B", checksum)
  123. try:
  124. self.serial.write(message)
  125. self.serial.flush()
  126. except SerialTimeoutException:
  127. raise ispBase.IspError("Serial send timeout")
  128. self.seq = (self.seq + 1) & 0xFF
  129. return self.recvMessage()
  130. def recvMessage(self):
  131. state = "Start"
  132. checksum = 0
  133. while True:
  134. s = self.serial.read()
  135. if len(s) < 1:
  136. raise ispBase.IspError("Timeout")
  137. b = struct.unpack(">B", s)[0]
  138. checksum ^= b
  139. if state == "Start":
  140. if b == 0x1B:
  141. state = "GetSeq"
  142. checksum = 0x1B
  143. elif state == "GetSeq":
  144. state = "MsgSize1"
  145. elif state == "MsgSize1":
  146. msg_size = b << 8
  147. state = "MsgSize2"
  148. elif state == "MsgSize2":
  149. msg_size |= b
  150. state = "Token"
  151. elif state == "Token":
  152. if b != 0x0E:
  153. state = "Start"
  154. else:
  155. state = "Data"
  156. data = []
  157. elif state == "Data":
  158. data.append(b)
  159. if len(data) == msg_size:
  160. state = "Checksum"
  161. elif state == "Checksum":
  162. if checksum != 0:
  163. state = "Start"
  164. else:
  165. return data
  166. def portList():
  167. ret = []
  168. import _winreg
  169. key=_winreg.OpenKey(_winreg.HKEY_LOCAL_MACHINE,"HARDWARE\\DEVICEMAP\\SERIALCOMM") #@UndefinedVariable
  170. i=0
  171. while True:
  172. try:
  173. values = _winreg.EnumValue(key, i) #@UndefinedVariable
  174. except:
  175. return ret
  176. if "USBSER" in values[0]:
  177. ret.append(values[1])
  178. i+=1
  179. return ret
  180. def runProgrammer(port, filename):
  181. """ Run an STK500v2 program on serial port 'port' and write 'filename' into flash. """
  182. programmer = Stk500v2()
  183. programmer.connect(port = port)
  184. programmer.programChip(intelHex.readHex(filename))
  185. programmer.close()
  186. def main():
  187. """ Entry point to call the stk500v2 programmer from the commandline. """
  188. import threading
  189. if sys.argv[1] == "AUTO":
  190. Logger.log("d", "portList(): ", repr(portList()))
  191. for port in portList():
  192. threading.Thread(target=runProgrammer, args=(port,sys.argv[2])).start()
  193. time.sleep(5)
  194. else:
  195. programmer = Stk500v2()
  196. programmer.connect(port = sys.argv[1])
  197. programmer.programChip(intelHex.readHex(sys.argv[2]))
  198. sys.exit(1)
  199. if __name__ == "__main__":
  200. main()