test_func.py 7.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326
  1. import pytest
  2. import multiprocessing
  3. import random
  4. import threading
  5. import time
  6. import six
  7. import library.python.func as func
  8. def test_map0():
  9. assert None is func.map0(lambda x: x + 1, None)
  10. assert 3 == func.map0(lambda x: x + 1, 2)
  11. assert None is func.map0(len, None)
  12. assert 2 == func.map0(len, [1, 2])
  13. def test_single():
  14. assert 1 == func.single([1])
  15. with pytest.raises(Exception):
  16. assert 1 == func.single([])
  17. with pytest.raises(Exception):
  18. assert 1 == func.single([1, 2])
  19. def test_memoize():
  20. class Counter(object):
  21. @staticmethod
  22. def inc():
  23. Counter._qty = getattr(Counter, '_qty', 0) + 1
  24. return Counter._qty
  25. @func.memoize()
  26. def t1(a):
  27. return a, Counter.inc()
  28. @func.memoize()
  29. def t2(a):
  30. return a, Counter.inc()
  31. @func.memoize()
  32. def t3(a):
  33. return a, Counter.inc()
  34. @func.memoize()
  35. def t4(a):
  36. return a, Counter.inc()
  37. @func.memoize()
  38. def t5(a, b, c):
  39. return a + b + c, Counter.inc()
  40. @func.memoize()
  41. def t6():
  42. return Counter.inc()
  43. @func.memoize(limit=2)
  44. def t7(a, _b):
  45. return a, Counter.inc()
  46. assert (1, 1) == t1(1)
  47. assert (1, 1) == t1(1)
  48. assert (2, 2) == t1(2)
  49. assert (2, 2) == t1(2)
  50. assert (1, 3) == t2(1)
  51. assert (1, 3) == t2(1)
  52. assert (2, 4) == t2(2)
  53. assert (2, 4) == t2(2)
  54. assert (1, 5) == t3(1)
  55. assert (1, 5) == t3(1)
  56. assert (2, 6) == t3(2)
  57. assert (2, 6) == t3(2)
  58. assert (1, 7) == t4(1)
  59. assert (1, 7) == t4(1)
  60. assert (2, 8) == t4(2)
  61. assert (2, 8) == t4(2)
  62. assert (6, 9) == t5(1, 2, 3)
  63. assert (6, 9) == t5(1, 2, 3)
  64. assert (7, 10) == t5(1, 2, 4)
  65. assert (7, 10) == t5(1, 2, 4)
  66. assert 11 == t6()
  67. assert 11 == t6()
  68. assert (1, 12) == t7(1, None)
  69. assert (2, 13) == t7(2, None)
  70. assert (1, 12) == t7(1, None)
  71. assert (2, 13) == t7(2, None)
  72. # removed result for (1, None)
  73. assert (3, 14) == t7(3, None)
  74. assert (1, 15) == t7(1, None)
  75. class ClassWithMemoizedMethod(object):
  76. def __init__(self):
  77. self.a = 0
  78. @func.memoize(True)
  79. def t(self, i):
  80. self.a += i
  81. return i
  82. obj = ClassWithMemoizedMethod()
  83. assert 10 == obj.t(10)
  84. assert 10 == obj.a
  85. assert 10 == obj.t(10)
  86. assert 10 == obj.a
  87. assert 20 == obj.t(20)
  88. assert 30 == obj.a
  89. assert 20 == obj.t(20)
  90. assert 30 == obj.a
  91. def test_first():
  92. assert func.first([0, [], (), None, False, {}, 0.0, '1', 0]) == '1'
  93. assert func.first([]) is None
  94. assert func.first([0]) is None
  95. def test_split():
  96. assert func.split([1, 1], lambda x: x) == ([1, 1], [])
  97. assert func.split([0, 0], lambda x: x) == ([], [0, 0])
  98. assert func.split([], lambda x: x) == ([], [])
  99. assert func.split([1, 0, 1], lambda x: x) == ([1, 1], [0])
  100. def test_flatten_dict():
  101. assert func.flatten_dict({"a": 1, "b": 2}) == {"a": 1, "b": 2}
  102. assert func.flatten_dict({"a": 1}) == {"a": 1}
  103. assert func.flatten_dict({}) == {}
  104. assert func.flatten_dict({"a": 1, "b": {"c": {"d": 2}}}) == {"a": 1, "b.c.d": 2}
  105. assert func.flatten_dict({"a": 1, "b": {"c": {"d": 2}}}, separator="/") == {"a": 1, "b/c/d": 2}
  106. def test_threadsafe_singleton():
  107. class ShouldBeSingle(six.with_metaclass(func.Singleton, object)):
  108. def __new__(cls, *args, **kwargs):
  109. time.sleep(0.1)
  110. return super(ShouldBeSingle, cls).__new__(cls, *args, **kwargs)
  111. threads_count = 100
  112. threads = [None] * threads_count
  113. results = [None] * threads_count
  114. def class_factory(results, i):
  115. time.sleep(0.1)
  116. results[i] = ShouldBeSingle()
  117. for i in range(threads_count):
  118. threads[i] = threading.Thread(target=class_factory, args=(results, i))
  119. for i in range(threads_count):
  120. threads[i].start()
  121. for i in range(threads_count):
  122. threads[i].join()
  123. assert len(set(results)) == 1
  124. def test_memoize_thread_local():
  125. class Counter(object):
  126. def __init__(self, s):
  127. self.val = s
  128. def inc(self):
  129. self.val += 1
  130. return self.val
  131. @func.memoize(thread_local=True)
  132. def get_counter(start):
  133. return Counter(start)
  134. def th_inc():
  135. assert get_counter(0).inc() == 1
  136. assert get_counter(0).inc() == 2
  137. assert get_counter(10).inc() == 11
  138. assert get_counter(10).inc() == 12
  139. th_inc()
  140. th = threading.Thread(target=th_inc)
  141. th.start()
  142. th.join()
  143. def test_memoize_not_thread_safe():
  144. class Counter(object):
  145. def __init__(self, s):
  146. self.val = s
  147. def inc(self):
  148. self.val += 1
  149. return self.val
  150. @func.memoize(thread_safe=False)
  151. def io_job(n):
  152. time.sleep(0.1)
  153. return Counter(n)
  154. def worker(n):
  155. assert io_job(n).inc() == n + 1
  156. assert io_job(n).inc() == n + 2
  157. assert io_job(n*10).inc() == n*10 + 1
  158. assert io_job(n*10).inc() == n*10 + 2
  159. assert io_job(n).inc() == n + 3
  160. threads = []
  161. for i in range(5):
  162. threads.append(threading.Thread(target=worker, args=(i+1,)))
  163. st = time.time()
  164. for thread in threads:
  165. thread.start()
  166. for thread in threads:
  167. thread.join()
  168. elapsed_time = time.time() - st
  169. assert elapsed_time < 0.5
  170. def test_memoize_not_thread_safe_concurrent():
  171. class Counter(object):
  172. def __init__(self, s):
  173. self.val = s
  174. def inc(self):
  175. self.val += 1
  176. return self.val
  177. @func.memoize(thread_safe=False)
  178. def io_job(n):
  179. time.sleep(0.1)
  180. return Counter(n)
  181. def worker():
  182. io_job(100).inc()
  183. th1 = threading.Thread(target=worker)
  184. th2 = threading.Thread(target=worker)
  185. th3 = threading.Thread(target=worker)
  186. th1.start()
  187. time.sleep(0.05)
  188. th2.start()
  189. th1.join()
  190. assert io_job(100).inc() == 100 + 2
  191. th3.start()
  192. # th3 instantly got counter from memory
  193. assert io_job(100).inc() == 100 + 4
  194. th2.join()
  195. # th2 shoud increase th1 counter
  196. assert io_job(100).inc() == 100 + 6
  197. def test_memoize_not_thread_safe_stress():
  198. @func.memoize(thread_safe=False)
  199. def job():
  200. for _ in range(1000):
  201. hash = random.getrandbits(128)
  202. return hash
  203. def worker(n):
  204. hash = job()
  205. results[n] = hash
  206. num_threads = min(multiprocessing.cpu_count()*4, 64)
  207. threads = []
  208. results = [None for _ in range(num_threads)]
  209. for i in range(num_threads):
  210. thread = threading.Thread(target=worker, args=(i,))
  211. threads.append(thread)
  212. for thread in threads:
  213. thread.start()
  214. for thread in threads:
  215. thread.join()
  216. assert len(set(results)) == 1
  217. def test_memoize_thread_safe():
  218. class Counter(object):
  219. def __init__(self, s):
  220. self.val = s
  221. def inc(self):
  222. self.val += 1
  223. return self.val
  224. @func.memoize(thread_safe=True)
  225. def io_job(n):
  226. time.sleep(0.05)
  227. return Counter(n)
  228. def worker(n):
  229. assert io_job(n).inc() == n + 1
  230. assert io_job(n).inc() == n + 2
  231. assert io_job(n*10).inc() == n*10 + 1
  232. assert io_job(n*10).inc() == n*10 + 2
  233. threads = []
  234. for i in range(5):
  235. threads.append(threading.Thread(target=worker, args=(i+1,)))
  236. st = time.time()
  237. for thread in threads:
  238. thread.start()
  239. for thread in threads:
  240. thread.join()
  241. elapsed_time = time.time() - st
  242. assert elapsed_time >= 0.5
  243. if __name__ == '__main__':
  244. pytest.main([__file__])