# test_evaluation_context.py (6.5 KB)
  1. from dataclasses import dataclass
  2. from typing import Any
  3. import pytest
  4. from flagpole.evaluation_context import ContextBuilder, EvaluationContext, EvaluationContextDict
  5. class TestEvaluationContext:
  6. # Identity fields tests are mainly upholding that our hashing strategy does
  7. # not change in the future, and that we calculate the id using the correct
  8. # context values and keys in order.
  9. def test_adds_identity_fields(self):
  10. eval_context = EvaluationContext({}, set())
  11. assert eval_context.id == 1245845410931227995499360226027473197403882391305
  12. eval_context = EvaluationContext({"foo": "bar", "baz": "barfoo"}, {"foo"})
  13. expected_id = 484477975355580460928302712356218993825269143262
  14. assert eval_context.id == expected_id
  15. # Assert that we skip the missing field but still generate the same
  16. # context ID.
  17. eval_context = EvaluationContext({"foo": "bar", "baz": "barfoo"}, {"foo", "whoops"})
  18. assert eval_context.id == expected_id
  19. eval_context = EvaluationContext({"foo": "bar", "baz": "barfoo"}, {"foo", "baz"})
  20. expected_id = 1249805218608667754842212156585681631068251083301
  21. assert eval_context.id == expected_id
  22. # Assert that we use all properties to generate the context when all
  23. # identity fields are missing.
  24. eval_context = EvaluationContext({"foo": "bar", "baz": "barfoo"}, {"whoops", "test"})
  25. assert eval_context.id == expected_id
  26. def test_no_identity_fields_included(self):
  27. eval_context = EvaluationContext({})
  28. assert eval_context.id == 1245845410931227995499360226027473197403882391305
  29. eval_context = EvaluationContext({"foo": "bar", "baz": "barfoo"})
  30. expected_id = 1249805218608667754842212156585681631068251083301
  31. assert eval_context.id == expected_id
  32. eval_context = EvaluationContext({"foo": "bar", "baz": "barfoo", "test": "property"})
  33. expected_id = 1395427532315258482176540981434194664973697472186
  34. assert eval_context.id == expected_id
  35. def test_get_has_data(self):
  36. eval_context = EvaluationContext({"foo": "bar", "baz": "barfoo"}, {"foo"})
  37. assert eval_context.has("foo") is True
  38. assert eval_context.get("foo") == "bar"
  39. assert eval_context.has("baz") is True
  40. assert eval_context.get("baz") == "barfoo"
  41. assert eval_context.has("bar") is False
  42. assert eval_context.get("bar") is None
  43. @dataclass
  44. class ContextData:
  45. foo: str | None = None
  46. baz: int | None = None
  47. buzz: dict | set | None = None
  48. class TestContextBuilder:
  49. def test_empty_context_builder(self):
  50. context_builder = ContextBuilder[ContextData]()
  51. context = context_builder.build()
  52. assert context.size() == 0
  53. def test_static_transformer(self):
  54. def static_transformer(_data: ContextData) -> dict[str, Any]:
  55. return dict(foo="bar", baz=1)
  56. eval_context = (
  57. ContextBuilder[ContextData]()
  58. .add_context_transformer(static_transformer)
  59. .build(ContextData())
  60. )
  61. assert eval_context.size() == 2
  62. assert eval_context.get("foo") == "bar"
  63. assert eval_context.get("baz") == 1
  64. def test_transformer_with_data(self):
  65. def transformer_with_data(data: ContextData) -> dict[str, Any]:
  66. return dict(foo="bar", baz=getattr(data, "baz", None))
  67. eval_context = (
  68. ContextBuilder[ContextData]()
  69. .add_context_transformer(transformer_with_data)
  70. .build(ContextData(baz=2))
  71. )
  72. assert eval_context.size() == 2
  73. assert eval_context.get("foo") == "bar"
  74. assert eval_context.get("baz") == 2
  75. def test_multiple_context_transformers(self):
  76. def transformer_one(data: ContextData) -> dict[str, Any]:
  77. return dict(foo="overwrite_me", baz=2, buzz=getattr(data, "buzz"))
  78. def transformer_two(_data: ContextData) -> dict[str, Any]:
  79. return dict(foo="bar")
  80. eval_context = (
  81. ContextBuilder[ContextData]()
  82. .add_context_transformer(transformer_one)
  83. .add_context_transformer(transformer_two)
  84. .build(ContextData(foo="bar", buzz={1, 2, 3}))
  85. )
  86. assert eval_context.size() == 3
  87. assert eval_context.get("foo") == "bar"
  88. assert eval_context.get("baz") == 2
  89. assert eval_context.get("buzz") == {1, 2, 3}
  90. def test_with_exception_handler(self):
  91. exc_message = "oh noooooo"
  92. def broken_transformer(_data: ContextData) -> EvaluationContextDict:
  93. raise Exception(exc_message)
  94. context_builder = ContextBuilder[ContextData]().add_context_transformer(broken_transformer)
  95. with pytest.raises(Exception) as exc:
  96. context_builder.build(ContextData())
  97. assert exc.match(exc_message)
  98. # Ensure builder doesn't raise an exception
  99. context_builder.add_exception_handler(lambda _exc: None)
  100. context_builder.build(ContextData())
  101. with pytest.raises(Exception):
  102. context_builder.add_exception_handler(lambda _exc: None)
  103. # This is nearly identical to the evaluation context around identity fields,
  104. # just to ensure we compile and pass the correct list
  105. def test_identity_fields_passing(self):
  106. def transformer_with_data(_data: ContextData) -> dict[str, Any]:
  107. return dict(foo="bar", baz="barfoo")
  108. eval_context = ContextBuilder[ContextData]().build(ContextData(baz=2))
  109. # This should be empty dictionary, empty identity fields list
  110. assert eval_context.id == 1245845410931227995499360226027473197403882391305
  111. eval_context = (
  112. ContextBuilder[ContextData]()
  113. .add_context_transformer(transformer_with_data, ["foo"])
  114. .build(ContextData(baz=2))
  115. )
  116. expected_context_id = 484477975355580460928302712356218993825269143262
  117. assert eval_context.id == expected_context_id
  118. # The full identity_fields list passed into the context should be
  119. # ["foo", "baz", "whoops"], but "whoops" will be filtered out by the
  120. # context since the field does not exist in the context dict.
  121. eval_context = (
  122. ContextBuilder[ContextData]()
  123. .add_context_transformer(transformer_with_data, ["foo"])
  124. .add_context_transformer(transformer_with_data, ["baz", "whoops"])
  125. .build(ContextData(baz=2))
  126. )
  127. expected_context_id = 1249805218608667754842212156585681631068251083301
  128. assert eval_context.id == expected_context_id