|
@@ -1,5 +1,8 @@
|
|
|
import pytest
|
|
|
+import multiprocessing
|
|
|
+import random
|
|
|
import threading
|
|
|
+import time
|
|
|
|
|
|
import library.python.func as func
|
|
|
|
|
@@ -158,5 +161,139 @@ def test_memoize_thread_local():
|
|
|
th.join()
|
|
|
|
|
|
|
|
|
+def test_memoize_not_thread_safe():
|
|
|
+ class Counter(object):
|
|
|
+ def __init__(self, s):
|
|
|
+ self.val = s
|
|
|
+
|
|
|
+ def inc(self):
|
|
|
+ self.val += 1
|
|
|
+ return self.val
|
|
|
+
|
|
|
+ @func.memoize(thread_safe=False)
|
|
|
+ def io_job(n):
|
|
|
+ time.sleep(0.1)
|
|
|
+ return Counter(n)
|
|
|
+
|
|
|
+ def worker(n):
|
|
|
+ assert io_job(n).inc() == n + 1
|
|
|
+ assert io_job(n).inc() == n + 2
|
|
|
+ assert io_job(n*10).inc() == n*10 + 1
|
|
|
+ assert io_job(n*10).inc() == n*10 + 2
|
|
|
+ assert io_job(n).inc() == n + 3
|
|
|
+
|
|
|
+ threads = []
|
|
|
+ for i in range(5):
|
|
|
+ threads.append(threading.Thread(target=worker, args=(i+1,)))
|
|
|
+
|
|
|
+ st = time.time()
|
|
|
+
|
|
|
+ for thread in threads:
|
|
|
+ thread.start()
|
|
|
+ for thread in threads:
|
|
|
+ thread.join()
|
|
|
+
|
|
|
+ elapsed_time = time.time() - st
|
|
|
+ assert elapsed_time < 0.5
|
|
|
+
|
|
|
+
|
|
|
+def test_memoize_not_thread_safe_concurrent():
|
|
|
+ class Counter(object):
|
|
|
+ def __init__(self, s):
|
|
|
+ self.val = s
|
|
|
+
|
|
|
+ def inc(self):
|
|
|
+ self.val += 1
|
|
|
+ return self.val
|
|
|
+
|
|
|
+ @func.memoize(thread_safe=False)
|
|
|
+ def io_job(n):
|
|
|
+ time.sleep(0.1)
|
|
|
+ return Counter(n)
|
|
|
+
|
|
|
+ def worker():
|
|
|
+ io_job(100).inc()
|
|
|
+
|
|
|
+ th1 = threading.Thread(target=worker)
|
|
|
+ th2 = threading.Thread(target=worker)
|
|
|
+ th3 = threading.Thread(target=worker)
|
|
|
+
|
|
|
+ th1.start()
|
|
|
+ time.sleep(0.05)
|
|
|
+ th2.start()
|
|
|
+
|
|
|
+ th1.join()
|
|
|
+ assert io_job(100).inc() == 100 + 2
|
|
|
+
|
|
|
+ th3.start()
|
|
|
+ # th3 instantly got counter from memory
|
|
|
+ assert io_job(100).inc() == 100 + 4
|
|
|
+
|
|
|
+ th2.join()
|
|
|
+ # th2 shoud increase th1 counter
|
|
|
+ assert io_job(100).inc() == 100 + 6
|
|
|
+
|
|
|
+
|
|
|
+def test_memoize_not_thread_safe_stress():
|
|
|
+ @func.memoize(thread_safe=False)
|
|
|
+ def job():
|
|
|
+ for _ in range(1000):
|
|
|
+ hash = random.getrandbits(128)
|
|
|
+ return hash
|
|
|
+
|
|
|
+ def worker(n):
|
|
|
+ hash = job()
|
|
|
+ results[n] = hash
|
|
|
+
|
|
|
+ num_threads = min(multiprocessing.cpu_count()*4, 64)
|
|
|
+ threads = []
|
|
|
+ results = [None for _ in range(num_threads)]
|
|
|
+
|
|
|
+ for i in range(num_threads):
|
|
|
+ thread = threading.Thread(target=worker, args=(i,))
|
|
|
+ threads.append(thread)
|
|
|
+
|
|
|
+ for thread in threads:
|
|
|
+ thread.start()
|
|
|
+ for thread in threads:
|
|
|
+ thread.join()
|
|
|
+ assert len(set(results)) == 1
|
|
|
+
|
|
|
+
|
|
|
+def test_memoize_thread_safe():
|
|
|
+ class Counter(object):
|
|
|
+ def __init__(self, s):
|
|
|
+ self.val = s
|
|
|
+
|
|
|
+ def inc(self):
|
|
|
+ self.val += 1
|
|
|
+ return self.val
|
|
|
+
|
|
|
+ @func.memoize(thread_safe=True)
|
|
|
+ def io_job(n):
|
|
|
+ time.sleep(0.05)
|
|
|
+ return Counter(n)
|
|
|
+
|
|
|
+ def worker(n):
|
|
|
+ assert io_job(n).inc() == n + 1
|
|
|
+ assert io_job(n).inc() == n + 2
|
|
|
+ assert io_job(n*10).inc() == n*10 + 1
|
|
|
+ assert io_job(n*10).inc() == n*10 + 2
|
|
|
+
|
|
|
+ threads = []
|
|
|
+ for i in range(5):
|
|
|
+ threads.append(threading.Thread(target=worker, args=(i+1,)))
|
|
|
+
|
|
|
+ st = time.time()
|
|
|
+
|
|
|
+ for thread in threads:
|
|
|
+ thread.start()
|
|
|
+ for thread in threads:
|
|
|
+ thread.join()
|
|
|
+
|
|
|
+ elapsed_time = time.time() - st
|
|
|
+ assert elapsed_time >= 0.5
|
|
|
+
|
|
|
+
|
|
|
if __name__ == '__main__':
|
|
|
pytest.main([__file__])
|