Просмотр исходного кода

Fix: memoize multithreading optimization

Fix: memoize multithreading optimization
l4m3r 1 год назад
Родитель
Commit
a8c9782fb7
3 измененных файлов с 160 добавлено и 7 удалено
  1. 21 7
      library/python/func/__init__.py
  2. 137 0
      library/python/func/ut/test_func.py
  3. 2 0
      library/python/func/ya.make

+ 21 - 7
library/python/func/__init__.py

@@ -1,6 +1,8 @@
 import functools
 import threading
 import collections
+import contextlib
+import six
 
 
 def map0(func, value):
@@ -76,20 +78,32 @@ class lazy_classproperty(object):
         return getattr(owner, attr_name)
 
 
-def memoize(limit=0, thread_local=False):
+class nullcontext(object):
+    def __enter__(self):
+        pass
+
+    def __exit__(self, exc_type, exc_value, traceback):
+        pass
+
+
+def memoize(limit=0, thread_local=False, thread_safe=True):
     assert limit >= 0
+    assert limit <= 0 or thread_safe, 'memoize() is not thread safe enough to work in limiting and non-thread safe mode'
 
     def decorator(func):
         memory = {}
-        lock = threading.Lock()
+
+        if six.PY3:
+            lock = contextlib.nullcontext()
+        else:
+            lock = nullcontext()
+        lock = threading.Lock() if thread_safe else lock
 
         if limit:
             keys = collections.deque()
 
             def get(args):
-                try:
-                    return memory[args]
-                except KeyError:
+                if args not in memory:
                     with lock:
                         if args not in memory:
                             fargs = args[-1]
@@ -97,7 +111,7 @@ def memoize(limit=0, thread_local=False):
                             keys.append(args)
                             if len(keys) > limit:
                                 del memory[keys.popleft()]
-                        return memory[args]
+                return memory[args]
 
         else:
 
@@ -106,7 +120,7 @@ def memoize(limit=0, thread_local=False):
                     with lock:
                         if args not in memory:
                             fargs = args[-1]
-                            memory[args] = func(*fargs)
+                            memory.setdefault(args, func(*fargs))
                 return memory[args]
 
         if thread_local:

+ 137 - 0
library/python/func/ut/test_func.py

@@ -1,5 +1,8 @@
 import pytest
+import multiprocessing
+import random
 import threading
+import time
 
 import library.python.func as func
 
@@ -158,5 +161,139 @@ def test_memoize_thread_local():
     th.join()
 
 
+def test_memoize_not_thread_safe():
+    class Counter(object):
+        def __init__(self, s):
+            self.val = s
+
+        def inc(self):
+            self.val += 1
+            return self.val
+
+    @func.memoize(thread_safe=False)
+    def io_job(n):
+        time.sleep(0.1)
+        return Counter(n)
+
+    def worker(n):
+        assert io_job(n).inc() == n + 1
+        assert io_job(n).inc() == n + 2
+        assert io_job(n*10).inc() == n*10 + 1
+        assert io_job(n*10).inc() == n*10 + 2
+        assert io_job(n).inc() == n + 3
+
+    threads = []
+    for i in range(5):
+        threads.append(threading.Thread(target=worker, args=(i+1,)))
+
+    st = time.time()
+
+    for thread in threads:
+        thread.start()
+    for thread in threads:
+        thread.join()
+
+    elapsed_time = time.time() - st
+    assert elapsed_time < 0.5
+
+
+def test_memoize_not_thread_safe_concurrent():
+    class Counter(object):
+        def __init__(self, s):
+            self.val = s
+
+        def inc(self):
+            self.val += 1
+            return self.val
+
+    @func.memoize(thread_safe=False)
+    def io_job(n):
+        time.sleep(0.1)
+        return Counter(n)
+
+    def worker():
+        io_job(100).inc()
+
+    th1 = threading.Thread(target=worker)
+    th2 = threading.Thread(target=worker)
+    th3 = threading.Thread(target=worker)
+
+    th1.start()
+    time.sleep(0.05)
+    th2.start()
+
+    th1.join()
+    assert io_job(100).inc() == 100 + 2
+
+    th3.start()
+    # th3 instantly got counter from memory
+    assert io_job(100).inc() == 100 + 4
+
+    th2.join()
+    # th2 should increase th1 counter
+    assert io_job(100).inc() == 100 + 6
+
+
+def test_memoize_not_thread_safe_stress():
+    @func.memoize(thread_safe=False)
+    def job():
+        for _ in range(1000):
+            hash = random.getrandbits(128)
+        return hash
+
+    def worker(n):
+        hash = job()
+        results[n] = hash
+
+    num_threads = min(multiprocessing.cpu_count()*4, 64)
+    threads = []
+    results = [None for _ in range(num_threads)]
+
+    for i in range(num_threads):
+        thread = threading.Thread(target=worker, args=(i,))
+        threads.append(thread)
+
+    for thread in threads:
+        thread.start()
+    for thread in threads:
+        thread.join()
+    assert len(set(results)) == 1
+
+
+def test_memoize_thread_safe():
+    class Counter(object):
+        def __init__(self, s):
+            self.val = s
+
+        def inc(self):
+            self.val += 1
+            return self.val
+
+    @func.memoize(thread_safe=True)
+    def io_job(n):
+        time.sleep(0.05)
+        return Counter(n)
+
+    def worker(n):
+        assert io_job(n).inc() == n + 1
+        assert io_job(n).inc() == n + 2
+        assert io_job(n*10).inc() == n*10 + 1
+        assert io_job(n*10).inc() == n*10 + 2
+
+    threads = []
+    for i in range(5):
+        threads.append(threading.Thread(target=worker, args=(i+1,)))
+
+    st = time.time()
+
+    for thread in threads:
+        thread.start()
+    for thread in threads:
+        thread.join()
+
+    elapsed_time = time.time() - st
+    assert elapsed_time >= 0.5
+
+
 if __name__ == '__main__':
     pytest.main([__file__])

+ 2 - 0
library/python/func/ya.make

@@ -2,6 +2,8 @@ PY23_LIBRARY()
 
 PY_SRCS(__init__.py)
 
+PEERDIR(contrib/python/six)
+
 END()
 
 RECURSE_FOR_TESTS(