Browse Source

fix(metrics-indexer): Merge results properly (#33424)

MeredithAnya 2 years ago
parent
commit
43efb11bf0

+ 20 - 8
src/sentry/sentry_metrics/indexer/postgres_v2.py

@@ -94,7 +94,7 @@ class KeyResults:
         for key_result in key_results:
             self.results[key_result.org_id].update({key_result.string: key_result.id})
 
-    def get_mapped_results(self) -> MutableMapping[int, MutableMapping[str, int]]:
+    def get_mapped_results(self) -> Mapping[int, Mapping[str, int]]:
         """
         Only return results that have org_ids with string/int mappings.
         """
@@ -135,6 +135,16 @@ class KeyResults:
         return cache_key_results
 
 
+def merge_results(
+    result_mappings: Sequence[Mapping[int, Mapping[str, int]]],
+) -> Mapping[int, Mapping[str, int]]:
+    """
+    Combine several org_id -> {string: id} mappings into one new mapping.
+
+    Entries for the same org_id are merged with `dict.update` rather than
+    replaced, so strings that show up in different input mappings all end up
+    in the result. If the same string appears for an org in more than one
+    input, the value from the later mapping in `result_mappings` wins.
+
+    The input mappings are only read; the per-org dicts in the result are
+    created here.
+    """
+    # defaultdict(dict) creates the per-org dict the first time an org_id is seen.
+    # NOTE(review): the defaultdict itself is returned, so indexing the result
+    # with an org_id that is not present inserts and returns an empty dict
+    # instead of raising KeyError -- confirm callers are fine with that.
+    new_results: MutableMapping[int, MutableMapping[str, int]] = defaultdict(dict)
+    for result_map in result_mappings:
+        for org_id, strings in result_map.items():
+            new_results[org_id].update(strings)
+    return new_results
+
+
 class PGStringIndexerV2(Service):
     """
     Provides integer IDs for metric names, tag keys and tag values
@@ -155,7 +165,7 @@ class PGStringIndexerV2(Service):
 
     def bulk_record(
         self, org_strings: MutableMapping[int, Set[str]]
-    ) -> MutableMapping[int, MutableMapping[str, int]]:
+    ) -> Mapping[int, Mapping[str, int]]:
         """
         Takes in a mapping with org_ids to sets of strings.
 
@@ -207,11 +217,11 @@ class PGStringIndexerV2(Service):
             [KeyResult.from_string(k, v) for k, v in cache_results.items() if v is not None]
         )
 
-        mapped_results = cache_key_results.get_mapped_results()
+        mapped_cache_results = cache_key_results.get_mapped_results()
         db_read_keys = cache_key_results.get_unmapped_keys(cache_keys)
 
         if db_read_keys.size == 0:
-            return mapped_results
+            return mapped_cache_results
 
         db_read_key_results = KeyResults()
         db_read_key_results.add_key_results(
@@ -222,7 +232,7 @@ class PGStringIndexerV2(Service):
         )
         new_results_to_cache = db_read_key_results.get_mapped_key_strings_to_ints()
 
-        mapped_results.update(db_read_key_results.get_mapped_results())
+        mapped_db_read_results = db_read_key_results.get_mapped_results()
         db_write_keys = db_read_key_results.get_unmapped_keys(db_read_keys)
 
         metrics.incr(
@@ -238,7 +248,7 @@ class PGStringIndexerV2(Service):
 
         if db_write_keys.size == 0:
             indexer_cache.set_many(new_results_to_cache)
-            return mapped_results
+            return merge_results([mapped_cache_results, mapped_db_read_results])
 
         new_records = []
         for write_pair in db_write_keys.as_tuples():
@@ -264,9 +274,11 @@ class PGStringIndexerV2(Service):
         new_results_to_cache.update(db_write_key_results.get_mapped_key_strings_to_ints())
         indexer_cache.set_many(new_results_to_cache)
 
-        mapped_results.update(db_write_key_results.get_mapped_results())
+        mapped_db_write_results = db_write_key_results.get_mapped_results()
 
-        return mapped_results
+        return merge_results(
+            [mapped_cache_results, mapped_db_read_results, mapped_db_write_results]
+        )
 
     def record(self, org_id: int, string: str) -> int:
         """Store a string and return the integer ID generated for it"""

+ 64 - 0
tests/sentry/sentry_metrics/test_postgres_indexer.py

@@ -84,6 +84,14 @@ class PostgresIndexerV2Test(TestCase):
         # we should have no results for org_id 999
         assert not results.get(999)
 
+    def test_resolve_and_reverse_resolve(self) -> None:
+        """
+        Test `resolve` and `reverse_resolve` methods
+        """
+        org1_id = self.organization.id
+        org_strings = {org1_id: self.strings}
+        PGStringIndexerV2().bulk_record(org_strings=org_strings)
+
         # test resolve and reverse_resolve
         obj = StringIndexer.objects.get(string="hello")
         assert PGStringIndexerV2().resolve(org1_id, "hello") == obj.id
@@ -97,6 +105,62 @@ class PostgresIndexerV2Test(TestCase):
         assert PGStringIndexerV2().resolve(org1_id, "beep") is None
         assert PGStringIndexerV2().reverse_resolve(1234) is None
 
+    def test_already_created_plus_written_results(self) -> None:
+        """
+        Test that we correctly combine db read results with db write results
+        for the same organization.
+        """
+        org_id = 1234
+        # Create the rows up front so these three strings already exist
+        # before bulk_record is called.
+        v0 = StringIndexer.objects.create(organization_id=org_id, string="v1.2.0")
+        v1 = StringIndexer.objects.create(organization_id=org_id, string="v1.2.1")
+        v2 = StringIndexer.objects.create(organization_id=org_id, string="v1.2.2")
+
+        expected_mapping = {"v1.2.0": v0.id, "v1.2.1": v1.id, "v1.2.2": v2.id}
+
+        # First call: only pre-existing strings are requested, and the ids
+        # returned must be the ids of the rows created above.
+        results = PGStringIndexerV2().bulk_record(
+            org_strings={org_id: {"v1.2.0", "v1.2.1", "v1.2.2"}}
+        )
+        assert len(results[org_id]) == len(expected_mapping) == 3
+
+        for string, id in results[org_id].items():
+            assert expected_mapping[string] == id
+
+        # Second call: the same three strings plus one that has no row yet.
+        # The result for the org must contain all four, not just the new one.
+        results = PGStringIndexerV2().bulk_record(
+            org_strings={org_id: {"v1.2.0", "v1.2.1", "v1.2.2", "v1.2.3"}}
+        )
+
+        # The row for "v1.2.3" was not created by this test, so it is expected
+        # to exist only because of the bulk_record call above.
+        v3 = StringIndexer.objects.get(organization_id=org_id, string="v1.2.3")
+        expected_mapping["v1.2.3"] = v3.id
+
+        assert len(results[org_id]) == len(expected_mapping) == 4
+
+        for string, id in results[org_id].items():
+            assert expected_mapping[string] == id
+
+    def test_already_cached_plus_read_results(self) -> None:
+        """
+        Test that we correctly combine cached results with read results
+        for the same organization.
+        """
+        org_id = 8
+        # Seed the cache directly, using "<org_id>:<string>" keys. No db rows
+        # are created for these strings (asserted further down).
+        cached = {f"{org_id}:beep": 10, f"{org_id}:boop": 11}
+        indexer_cache.set_many(cached)
+
+        # Both strings are in the cache, so the ids seeded above must come back.
+        results = PGStringIndexerV2().bulk_record(org_strings={org_id: {"beep", "boop"}})
+        assert len(results[org_id]) == 2
+        assert results[org_id]["beep"] == 10
+        assert results[org_id]["boop"] == 11
+
+        # confirm we did not write to the db if results were already cached
+        assert not StringIndexer.objects.filter(organization_id=org_id, string__in=["beep", "boop"])
+
+        # "bam" gets a db row but was not added to the cache above, while
+        # beep/boop were only put in the cache, so this call has to return
+        # results from both sources for the same org.
+        bam = StringIndexer.objects.create(organization_id=org_id, string="bam")
+        results = PGStringIndexerV2().bulk_record(org_strings={org_id: {"beep", "boop", "bam"}})
+        assert len(results[org_id]) == 3
+        assert results[org_id]["beep"] == 10
+        assert results[org_id]["boop"] == 11
+        assert results[org_id]["bam"] == bam.id
+
+
     def test_get_db_records(self):
         """
         Make sure that calling `_get_db_records` doesn't populate the cache