# test_utils.py

import itertools

import pytest

from prompt_toolkit.utils import take_using_weights


  4. def test_using_weights():
  5. def take(generator, count):
  6. return list(itertools.islice(generator, 0, count))
  7. # Check distribution.
  8. data = take(take_using_weights(["A", "B", "C"], [5, 10, 20]), 35)
  9. assert data.count("A") == 5
  10. assert data.count("B") == 10
  11. assert data.count("C") == 20
  12. assert data == [
  13. "A",
  14. "B",
  15. "C",
  16. "C",
  17. "B",
  18. "C",
  19. "C",
  20. "A",
  21. "B",
  22. "C",
  23. "C",
  24. "B",
  25. "C",
  26. "C",
  27. "A",
  28. "B",
  29. "C",
  30. "C",
  31. "B",
  32. "C",
  33. "C",
  34. "A",
  35. "B",
  36. "C",
  37. "C",
  38. "B",
  39. "C",
  40. "C",
  41. "A",
  42. "B",
  43. "C",
  44. "C",
  45. "B",
  46. "C",
  47. "C",
  48. ]
  49. # Another order.
  50. data = take(take_using_weights(["A", "B", "C"], [20, 10, 5]), 35)
  51. assert data.count("A") == 20
  52. assert data.count("B") == 10
  53. assert data.count("C") == 5
  54. # Bigger numbers.
  55. data = take(take_using_weights(["A", "B", "C"], [20, 10, 5]), 70)
  56. assert data.count("A") == 40
  57. assert data.count("B") == 20
  58. assert data.count("C") == 10
  59. # Negative numbers.
  60. data = take(take_using_weights(["A", "B", "C"], [-20, 10, 0]), 70)
  61. assert data.count("A") == 0
  62. assert data.count("B") == 70
  63. assert data.count("C") == 0
  64. # All zero-weight items.
  65. with pytest.raises(ValueError):
  66. take(take_using_weights(["A", "B", "C"], [0, 0, 0]), 70)