123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326 |
- import pytest
- import multiprocessing
- import random
- import threading
- import time
- import six
- import library.python.func as func
def test_map0():
    """map0 applies the callable to a value but passes None straight through."""
    # None short-circuits regardless of which callable is supplied.
    for fn in (lambda x: x + 1, len):
        assert func.map0(fn, None) is None
    assert func.map0(lambda x: x + 1, 2) == 3
    assert func.map0(len, [1, 2]) == 2
def test_single():
    """single() returns the only element and rejects empty or multi-element input."""
    assert func.single([1]) == 1
    # Call single() directly inside pytest.raises: wrapping it in an `assert`
    # would let a wrong return value pass, because the resulting AssertionError
    # itself satisfies pytest.raises(Exception).
    for bad in ([], [1, 2]):
        with pytest.raises(Exception):
            func.single(bad)
def test_memoize():
    """memoize caches results per argument tuple, honours `limit`, and works on methods."""

    class Counter(object):
        # Global invocation counter: every real (uncached) call bumps it, so the
        # counter value embedded in a result shows whether the body actually ran.
        @staticmethod
        def inc():
            Counter._qty = getattr(Counter, '_qty', 0) + 1
            return Counter._qty

    @func.memoize()
    def t1(a):
        return a, Counter.inc()

    @func.memoize()
    def t2(a):
        return a, Counter.inc()

    @func.memoize()
    def t3(a):
        return a, Counter.inc()

    @func.memoize()
    def t4(a):
        return a, Counter.inc()

    @func.memoize()
    def t5(a, b, c):
        return a + b + c, Counter.inc()

    @func.memoize()
    def t6():
        return Counter.inc()

    @func.memoize(limit=2)
    def t7(a, _b):
        return a, Counter.inc()

    # Each (function, args) pair is invoked twice in a row: the first call runs
    # the body, the second must come from the cache with an identical result.
    repeated_calls = [
        (t1, (1,), (1, 1)),
        (t1, (2,), (2, 2)),
        (t2, (1,), (1, 3)),
        (t2, (2,), (2, 4)),
        (t3, (1,), (1, 5)),
        (t3, (2,), (2, 6)),
        (t4, (1,), (1, 7)),
        (t4, (2,), (2, 8)),
        (t5, (1, 2, 3), (6, 9)),
        (t5, (1, 2, 4), (7, 10)),
        (t6, (), 11),
    ]
    for memoized, args, expected in repeated_calls:
        for _ in range(2):
            assert expected == memoized(*args)

    # t7 holds at most two entries: the call with a third key drops the cached
    # result for (1, None), so the final t7(1, None) recomputes (counter 15).
    limited_calls = [
        ((1, None), (1, 12)),
        ((2, None), (2, 13)),
        ((1, None), (1, 12)),
        ((2, None), (2, 13)),
        ((3, None), (3, 14)),
        ((1, None), (1, 15)),
    ]
    for args, expected in limited_calls:
        assert expected == t7(*args)

    class ClassWithMemoizedMethod(object):
        def __init__(self):
            self.a = 0

        # NOTE(review): the positional True presumably enables method-aware
        # memoization — confirm against library.python.func.memoize.
        @func.memoize(True)
        def t(self, i):
            self.a += i
            return i

    obj = ClassWithMemoizedMethod()
    # The side effect on obj.a must happen only on the first call per argument.
    for arg, running_total in ((10, 10), (10, 10), (20, 30), (20, 30)):
        assert arg == obj.t(arg)
        assert running_total == obj.a
def test_first():
    """first() yields the first truthy element, or None when there is none."""
    falsy_then_truthy = [0, [], (), None, False, {}, 0.0, '1', 0]
    assert func.first(falsy_then_truthy) == '1'
    for without_truthy in ([], [0]):
        assert func.first(without_truthy) is None
def test_split():
    """split() partitions items into (matching, non-matching) lists by a predicate."""
    identity = lambda x: x
    cases = [
        ([1, 1], ([1, 1], [])),
        ([0, 0], ([], [0, 0])),
        ([], ([], [])),
        ([1, 0, 1], ([1, 1], [0])),
    ]
    for items, expected in cases:
        assert func.split(items, identity) == expected
def test_flatten_dict():
    """flatten_dict joins nested keys with '.' by default, or a custom separator."""
    default_separator_cases = [
        ({"a": 1, "b": 2}, {"a": 1, "b": 2}),
        ({"a": 1}, {"a": 1}),
        ({}, {}),
        ({"a": 1, "b": {"c": {"d": 2}}}, {"a": 1, "b.c.d": 2}),
    ]
    for nested, flat in default_separator_cases:
        assert func.flatten_dict(nested) == flat

    slash_flat = func.flatten_dict({"a": 1, "b": {"c": {"d": 2}}}, separator="/")
    assert slash_flat == {"a": 1, "b/c/d": 2}
def test_threadsafe_singleton():
    """Concurrent first instantiation of a func.Singleton class yields one instance."""

    class ShouldBeSingle(six.with_metaclass(func.Singleton, object)):
        def __new__(cls, *args, **kwargs):
            # Slow construction widens the window in which a racy metaclass
            # could create more than one instance.
            time.sleep(0.1)
            return super(ShouldBeSingle, cls).__new__(cls, *args, **kwargs)

    threads_count = 100
    results = [None] * threads_count

    def class_factory(results, i):
        # Presumably lets all threads reach the constructor at about the same time.
        time.sleep(0.1)
        results[i] = ShouldBeSingle()

    threads = [
        threading.Thread(target=class_factory, args=(results, i))
        for i in range(threads_count)
    ]
    for thread in threads:
        thread.start()
    for thread in threads:
        thread.join()

    # A worker that raised leaves its slot as None. Without this check, a run in
    # which every thread failed would still collapse to a single set element.
    assert all(isinstance(r, ShouldBeSingle) for r in results)
    assert len(set(results)) == 1
def test_memoize_thread_local():
    """With thread_local=True every thread starts from its own empty cache."""

    class Counter(object):
        def __init__(self, s):
            self.val = s

        def inc(self):
            self.val += 1
            return self.val

    @func.memoize(thread_local=True)
    def get_counter(start):
        return Counter(start)

    def th_inc():
        # Within one thread the same Counter is reused; a second thread must see
        # fresh counters rather than the ones already incremented here.
        assert get_counter(0).inc() == 1
        assert get_counter(0).inc() == 2
        assert get_counter(10).inc() == 11
        assert get_counter(10).inc() == 12

    th_inc()

    # Exceptions raised in a Thread target are not re-raised by join(), so a
    # failing assertion in the worker would otherwise be silently ignored.
    errors = []

    def th_target():
        try:
            th_inc()
        except Exception as e:
            errors.append(e)

    th = threading.Thread(target=th_target)
    th.start()
    th.join()
    assert not errors, errors
def test_memoize_not_thread_safe():
    """With thread_safe=False, misses for different keys run concurrently."""

    class Counter(object):
        def __init__(self, s):
            self.val = s

        def inc(self):
            self.val += 1
            return self.val

    @func.memoize(thread_safe=False)
    def io_job(n):
        time.sleep(0.1)
        return Counter(n)

    # Exceptions raised in a Thread target are not re-raised by join(); collect
    # them so a failing assertion in a worker actually fails the test.
    errors = []

    def worker(n):
        try:
            assert io_job(n).inc() == n + 1
            assert io_job(n).inc() == n + 2
            assert io_job(n * 10).inc() == n * 10 + 1
            assert io_job(n * 10).inc() == n * 10 + 2
            assert io_job(n).inc() == n + 3
        except Exception as e:
            errors.append(e)

    threads = [threading.Thread(target=worker, args=(i + 1,)) for i in range(5)]
    st = time.time()
    for thread in threads:
        thread.start()
    for thread in threads:
        thread.join()
    elapsed_time = time.time() - st

    assert not errors, errors
    # Each worker has two cache misses (2 x 0.1 s). Run one after another, the
    # five workers would need about 1 s, so < 0.5 s shows the misses overlapped.
    assert elapsed_time < 0.5
def test_memoize_not_thread_safe_concurrent():
    """Without thread safety, overlapping misses for one key share the cached value."""

    class Counter(object):
        def __init__(self, s):
            self.val = s

        def inc(self):
            self.val += 1
            return self.val

    @func.memoize(thread_safe=False)
    def io_job(n):
        time.sleep(0.1)
        return Counter(n)

    def worker():
        io_job(100).inc()

    th1 = threading.Thread(target=worker)
    th2 = threading.Thread(target=worker)
    th3 = threading.Thread(target=worker)

    th1.start()
    # Start th2 while th1 is still inside io_job(100), so th2 misses the cache too.
    time.sleep(0.05)
    th2.start()
    th1.join()
    # th1 cached Counter(100) and incremented it once.
    assert io_job(100).inc() == 100 + 2

    th3.start()
    # th3 should get the counter from the cache without waiting. Join it before
    # asserting so its increment is guaranteed to be visible, not merely likely.
    th3.join()
    assert io_job(100).inc() == 100 + 4

    th2.join()
    # th2 should end up incrementing th1's cached counter, not a fresh one.
    assert io_job(100).inc() == 100 + 6
def test_memoize_not_thread_safe_stress():
    """Many threads racing on one uncached key must all observe the same value."""

    @func.memoize(thread_safe=False)
    def job():
        # Spend some CPU time so the threads overlap; only the last draw is returned.
        for _ in range(1000):
            value = random.getrandbits(128)
        return value

    num_threads = min(multiprocessing.cpu_count() * 4, 64)
    results = [None] * num_threads

    def worker(n):
        results[n] = job()

    threads = [threading.Thread(target=worker, args=(i,)) for i in range(num_threads)]
    for thread in threads:
        thread.start()
    for thread in threads:
        thread.join()

    # A worker that raised leaves None behind; if every worker failed, the
    # single-value check below would otherwise pass on [None, None, ...].
    assert None not in results
    assert len(set(results)) == 1
def test_memoize_thread_safe():
    """With thread_safe=True, concurrent cache misses do not overlap."""

    class Counter(object):
        def __init__(self, s):
            self.val = s

        def inc(self):
            self.val += 1
            return self.val

    @func.memoize(thread_safe=True)
    def io_job(n):
        time.sleep(0.05)
        return Counter(n)

    # Exceptions raised in a Thread target are not re-raised by join(); collect
    # them so a failing assertion in a worker actually fails the test.
    errors = []

    def worker(n):
        try:
            assert io_job(n).inc() == n + 1
            assert io_job(n).inc() == n + 2
            assert io_job(n * 10).inc() == n * 10 + 1
            assert io_job(n * 10).inc() == n * 10 + 2
        except Exception as e:
            errors.append(e)

    threads = [threading.Thread(target=worker, args=(i + 1,)) for i in range(5)]
    st = time.time()
    for thread in threads:
        thread.start()
    for thread in threads:
        thread.join()
    elapsed_time = time.time() - st

    assert not errors, errors
    # 5 workers x 2 misses x 0.05 s = 0.5 s. Reaching it shows the misses were
    # presumably serialized by memoize rather than run in parallel.
    assert elapsed_time >= 0.5
if __name__ == '__main__':
    # Propagate pytest's exit status so a failing run is visible to the caller
    # instead of always exiting with status 0.
    raise SystemExit(pytest.main([__file__]))
|