# test_utils.py
  1. from __future__ import annotations
  2. import itertools
  3. import pytest
  4. from prompt_toolkit.utils import take_using_weights
  5. def test_using_weights():
  6. def take(generator, count):
  7. return list(itertools.islice(generator, 0, count))
  8. # Check distribution.
  9. data = take(take_using_weights(["A", "B", "C"], [5, 10, 20]), 35)
  10. assert data.count("A") == 5
  11. assert data.count("B") == 10
  12. assert data.count("C") == 20
  13. assert data == [
  14. "A",
  15. "B",
  16. "C",
  17. "C",
  18. "B",
  19. "C",
  20. "C",
  21. "A",
  22. "B",
  23. "C",
  24. "C",
  25. "B",
  26. "C",
  27. "C",
  28. "A",
  29. "B",
  30. "C",
  31. "C",
  32. "B",
  33. "C",
  34. "C",
  35. "A",
  36. "B",
  37. "C",
  38. "C",
  39. "B",
  40. "C",
  41. "C",
  42. "A",
  43. "B",
  44. "C",
  45. "C",
  46. "B",
  47. "C",
  48. "C",
  49. ]
  50. # Another order.
  51. data = take(take_using_weights(["A", "B", "C"], [20, 10, 5]), 35)
  52. assert data.count("A") == 20
  53. assert data.count("B") == 10
  54. assert data.count("C") == 5
  55. # Bigger numbers.
  56. data = take(take_using_weights(["A", "B", "C"], [20, 10, 5]), 70)
  57. assert data.count("A") == 40
  58. assert data.count("B") == 20
  59. assert data.count("C") == 10
  60. # Negative numbers.
  61. data = take(take_using_weights(["A", "B", "C"], [-20, 10, 0]), 70)
  62. assert data.count("A") == 0
  63. assert data.count("B") == 70
  64. assert data.count("C") == 0
  65. # All zero-weight items.
  66. with pytest.raises(ValueError):
  67. take(take_using_weights(["A", "B", "C"], [0, 0, 0]), 70)