feat(nodestore-bigtable): Add support for zstd (#22876)

* create mock bigtable api for bigtable tests

* add support for zstd to bigtable nodestorage

* disable cache in test to actually test zstd

* remove useless condition

* flag zstd payloads separately

* stricter validation of compression arg
Markus Unterwaditzer · 4 years ago · commit 35373b47ee

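At its core, the change records which codec compressed the data column in a per-row flag byte, so reads never depend on the current compression setting. A condensed sketch of the flag scheme introduced in backend.py below (constants copied from the diff; the standalone assertions are illustrative only):

    _FLAG_COMPRESSED_ZLIB = 1 << 0  # pre-existing zlib bit
    _FLAG_COMPRESSED_ZSTD = 1 << 1  # new, separate zstd bit

    # Writing a zstd payload sets only the zstd bit...
    flags = _FLAG_COMPRESSED_ZSTD

    # ...and the read path dispatches on whichever bit is set.
    assert flags & _FLAG_COMPRESSED_ZSTD
    assert not flags & _FLAG_COMPRESSED_ZLIB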
requirements-base.txt  (+1 -0)

@@ -64,6 +64,7 @@ ua-parser>=0.10.0,<0.11.0
 unidiff>=0.5.4
 urllib3==1.24.2
 uwsgi>2.0.0,<2.1.0
+zstandard>=0.14.1,<=0.15
 
 # msgpack>=1.0.0 correctly encodes / decodes byte types in python3 as the
 # msgpack bin type. It could NOT do this in python 2. However, 1.0 also drops
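The new pin covers the one-shot ZstdCompressor / ZstdDecompressor APIs the backend below relies on; a quick round-trip check (illustrative only):

    import zstandard

    compressed = zstandard.ZstdCompressor().compress(b"payload")
    # Frames produced by compress() embed the original content size, so
    # decompress() needs no explicit max_output_size hint.
    assert zstandard.ZstdDecompressor().decompress(compressed) == b"payload"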

src/sentry/nodestore/bigtable/backend.py  (+55 -10)

@@ -3,7 +3,8 @@ from __future__ import absolute_import, print_function
 import os
 import struct
 from threading import Lock
-from zlib import compress as zlib_compress, decompress as zlib_decompress
+import zstandard
+import zlib
 
 from google.cloud import bigtable
 from google.cloud.bigtable.row_set import RowSet
@@ -32,6 +33,39 @@ _connection_lock = Lock()
 _connection_cache = {}
 
 
+def _compress_data(orig_data, data, compression):
+    flags = 0
+
+    if callable(compression):
+        compression = compression(orig_data)
+
+    if compression == "zstd":
+        flags |= BigtableNodeStorage._FLAG_COMPRESSED_ZSTD
+        cctx = zstandard.ZstdCompressor()
+        data = cctx.compress(data)
+    elif compression is True or compression == "zlib":
+        flags |= BigtableNodeStorage._FLAG_COMPRESSED_ZLIB
+        data = zlib.compress(data)
+    elif compression is False:
+        pass
+    else:
+        raise ValueError("invalid argument for compression: {!r}".format(compression))
+
+    return data, flags
+
+
+def _decompress_data(data, flags):
+    # If a compression flag is set, decompress the data
+    # with the matching algorithm.
+    if flags & BigtableNodeStorage._FLAG_COMPRESSED_ZLIB:
+        return zlib.decompress(data)
+    elif flags & BigtableNodeStorage._FLAG_COMPRESSED_ZSTD:
+        cctx = zstandard.ZstdDecompressor()
+        return cctx.decompress(data)
+    else:
+        return data
+
+
 def get_connection(project, instance, table, options):
     key = (project, instance, table)
     try:
@@ -56,6 +90,19 @@ class BigtableNodeStorage(NodeStorage):
     """
     A Bigtable-based backend for storing node data.
 
+    :param project: Passed to bigtable client
+    :param instance: Passed to bigtable client
+    :param table: Passed to bigtable client
+    :param automatic_expiry: Whether to set bigtable GC rule.
+    :param default_ttl: How many days keys should be stored (and considered
+        valid for reading + returning)
+    :param compression: A boolean whether to enable zlib-compression, the
+        string "zstd" to use zstd instead, or a callable that takes `data`
+        (event JSON as dict) and returns either of those values.
+
+        Can take a callable so we can opt projects in and out of zstd while we
+        do the migration.
+
     >>> BigtableNodeStorage(
     ...     project='some-project',
     ...     instance='sentry',
@@ -72,7 +119,8 @@ class BigtableNodeStorage(NodeStorage):
     flags_column = b"f"
     data_column = b"0"
 
-    _FLAG_COMPRESSED = 1 << 0
+    _FLAG_COMPRESSED_ZLIB = 1 << 0
+    _FLAG_COMPRESSED_ZSTD = 1 << 1
 
     def __init__(
         self,
@@ -160,11 +208,7 @@ class BigtableNodeStorage(NodeStorage):
         if self.flags_column in columns:
             flags = struct.unpack("B", columns[self.flags_column][0].value)[0]
 
-        # Check for a compression flag on, if so
-        # decompress the data.
-        if flags & self._FLAG_COMPRESSED:
-            data = zlib_decompress(data)
-
+        data = _decompress_data(data, flags)
         return json_loads(data)
 
     def set(self, id, data, ttl=None):
@@ -173,6 +217,7 @@ class BigtableNodeStorage(NodeStorage):
         self._set_cache_item(id, data)
 
     def encode_row(self, id, data, ttl=None):
+        orig_data = data
         data = json_dumps(data).encode("utf-8")
 
         row = self.connection.row(id)
@@ -209,9 +254,9 @@ class BigtableNodeStorage(NodeStorage):
 # The only flag we're tracking now is whether compression
 # is on or not for the data column.
         flags = 0
-        if self.compression:
-            flags |= self._FLAG_COMPRESSED
-            data = zlib_compress(data)
+
+        data, compression_flag = _compress_data(orig_data, data, self.compression)
+        flags |= compression_flag
 
         # Only need to write the column at all if any flags
         # are enabled. And if so, pack it into a single byte.
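For illustration only (not part of this commit), here is a minimal sketch of how the callable form of `compression` documented above could be used to opt individual projects into zstd during the migration. The `ZSTD_ROLLOUT_PROJECTS` set and the `"project"` key are hypothetical names, not existing Sentry settings:

    from sentry.nodestore.bigtable.backend import BigtableNodeStorage

    ZSTD_ROLLOUT_PROJECTS = {42, 1337}  # hypothetical opt-in list

    def choose_compression(event_data):
        # `event_data` is the event JSON as a dict, as described in the
        # docstring for the `compression` parameter.
        if event_data.get("project") in ZSTD_ROLLOUT_PROJECTS:
            return "zstd"  # opted-in projects write zstd-compressed payloads
        return True  # everything else keeps writing zlib

    nodestore = BigtableNodeStorage(
        project="some-project",
        instance="sentry",
        table="sentry:nodestore",
        compression=choose_compression,
    )

Because reads pick the decompressor from the per-row flags column rather than from this setting, payloads written with either algorithm stay readable after the switch.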

tests/sentry/nodestore/bigtable/backend/tests.py  (+178 -104)

@@ -3,110 +3,184 @@ from __future__ import absolute_import
 import pytest
 
 from sentry.nodestore.bigtable.backend import BigtableNodeStorage
-from sentry.testutils import TestCase
+from sentry.utils.cache import memoize
 from sentry.utils.compat import mock
 
 
-@pytest.mark.skip(reason="Bigtable is not available in CI")
-class BigtableNodeStorageTest(TestCase):
-    def setUp(self):
-        self.ns = BigtableNodeStorage(project="test")
-        self.ns.bootstrap()
-
-    def test_get(self):
-        node_id = "node_id"
-        data = {"foo": "bar"}
-        self.ns.set(node_id, data)
-        assert self.ns.get(node_id) == data
-
-    def test_get_multi(self):
-        nodes = [("a" * 32, {"foo": "a"}), ("b" * 32, {"foo": "b"})]
-
-        self.ns.set(nodes[0][0], nodes[0][1])
-        self.ns.set(nodes[1][0], nodes[1][1])
-
-        result = self.ns.get_multi([nodes[0][0], nodes[1][0]])
-        assert result == dict((n[0], n[1]) for n in nodes)
-
-    def test_set(self):
-        node_id = "d2502ebbd7df41ceba8d3275595cac33"
-        data = {"foo": "bar"}
-        self.ns.set(node_id, data)
-        assert self.ns.get(node_id) == data
-
-    def test_delete(self):
-        node_id = "d2502ebbd7df41ceba8d3275595cac33"
-        data = {"foo": "bar"}
-        self.ns.set(node_id, data)
-        assert self.ns.get(node_id) == data
-        self.ns.delete(node_id)
-        assert not self.ns.get(node_id)
-
-    def test_delete_multi(self):
-        nodes = [("node_1", {"foo": "a"}), ("node_2", {"foo": "b"})]
-
-        for n in nodes:
-            self.ns.set(n[0], n[1])
-
-        self.ns.delete_multi([nodes[0][0], nodes[1][0]])
-        assert not self.ns.get(nodes[0][0])
-        assert not self.ns.get(nodes[1][0])
-
-    def test_compression(self):
-        self.ns.compression = True
-        self.test_get()
-
-    def test_cache(self):
-        node_1 = ("a" * 32, {"foo": "a"})
-        node_2 = ("b" * 32, {"foo": "b"})
-        node_3 = ("c" * 32, {"foo": "c"})
-
-        for node_id, data in [node_1, node_2, node_3]:
-            self.ns.set(node_id, data)
-
-        # Get / get multi populates cache
-        assert self.ns.get(node_1[0]) == node_1[1]
-        assert self.ns.get_multi([node_2[0], node_3[0]]) == {
-            node_2[0]: node_2[1],
-            node_3[0]: node_3[1],
-        }
-        with mock.patch.object(self.ns.connection, "read_row") as mock_read_row:
-            assert self.ns.get(node_1[0]) == node_1[1]
-            assert self.ns.get(node_2[0]) == node_2[1]
-            assert self.ns.get(node_3[0]) == node_3[1]
-            assert mock_read_row.call_count == 0
-
-        with mock.patch.object(self.ns.connection, "read_rows") as mock_read_rows:
-            assert self.ns.get_multi([node_1[0], node_2[0], node_3[0]])
-            assert mock_read_rows.call_count == 0
-
-        # Manually deleted item should still retrievable from cache
-        row = self.ns.connection.row(node_1[0])
-        row.delete()
-        row.commit()
-        assert self.ns.get(node_1[0]) == node_1[1]
-        assert self.ns.get_multi([node_1[0], node_2[0]]) == {
-            node_1[0]: node_1[1],
-            node_2[0]: node_2[1],
-        }
-
-        # Deletion clears cache
-        self.ns.delete(node_1[0])
-        assert self.ns.get_multi([node_1[0], node_2[0]]) == {node_1[0]: None, node_2[0]: node_2[1]}
-        self.ns.delete_multi([node_1[0], node_2[0]])
-        assert self.ns.get_multi([node_1[0], node_2[0]]) == {node_1[0]: None, node_2[0]: None}
-
-        # Setting the item updates cache
-        new_value = {"event_id": "d" * 32}
-        self.ns.set(node_1[0], new_value)
-        with mock.patch.object(self.ns.connection, "read_row") as mock_read_row:
-            assert self.ns.get(node_1[0]) == new_value
-            assert mock_read_row.call_count == 0
-
-        # Missing rows are never cached
-        assert self.ns.get("node_4") is None
-        with mock.patch.object(self.ns.connection, "read_row") as mock_read_row:
-            mock_read_row.return_value = None
-            self.ns.get("node_4")
-            self.ns.get("node_4")
-            assert mock_read_row.call_count == 2
+class MockedBigtableNodeStorage(BigtableNodeStorage):
+    class Cell(object):
+        def __init__(self, value, timestamp):
+            self.value = value
+            self.timestamp = timestamp
+
+    class Row(object):
+        def __init__(self, connection, row_key):
+            self.row_key = row_key.encode("utf8")
+            self.connection = connection
+
+        def delete(self):
+            self.connection._table.pop(self.row_key, None)
+
+        def set_cell(self, family, col, value, timestamp):
+            assert family == "x"
+            self.connection._table.setdefault(self.row_key, {})[col] = [
+                MockedBigtableNodeStorage.Cell(value, timestamp)
+            ]
+
+        def commit(self):
+            # commits not implemented, changes are applied immediately
+            pass
+
+        @property
+        def cells(self):
+            return {"x": dict(self.connection._table.get(self.row_key) or ())}
+
+    class Connection(object):
+        def __init__(self):
+            self._table = {}
+
+        def row(self, key):
+            return MockedBigtableNodeStorage.Row(self, key)
+
+        def read_row(self, key):
+            return MockedBigtableNodeStorage.Row(self, key)
+
+        def read_rows(self, row_set):
+            assert not row_set.row_ranges, "unsupported"
+            return [self.read_row(key) for key in row_set.row_keys]
+
+        def mutate_rows(self, rows):
+            # commits not implemented, changes are applied immediately
+            pass
+
+    @memoize
+    def connection(self):
+        return MockedBigtableNodeStorage.Connection()
+
+    def bootstrap(self):
+        pass
+
+
+@pytest.fixture(params=[MockedBigtableNodeStorage, BigtableNodeStorage])
+def ns(request):
+    if request.param is BigtableNodeStorage:
+        pytest.skip("Bigtable is not available in CI")
+
+    ns = request.param(project="test")
+    ns.bootstrap()
+    return ns
+
+
+@pytest.mark.parametrize(
+    "compression,expected_prefix",
+    [(True, (b"\x78\x01", b"\x78\x9c", b"\x78\xda")), (False, b"{"), ("zstd", b"\x28\xb5\x2f\xfd")],
+    ids=["zlib", "ident", "zstd"],
+)
+def test_get(ns, compression, expected_prefix):
+    ns.compression = compression
+    node_id = "node_id"
+    data = {"foo": "bar"}
+    ns.set(node_id, data)
+
+    # Make sure this value does not get used during read. We may have various
+    # forms of compression in bigtable.
+    ns.compression = lambda: 1 / 0
+    # Do not use cache as that entirely bypasses what we want to test here.
+    ns.cache = None
+    assert ns.get(node_id) == data
+
+    raw_data = ns.connection.read_row("node_id").cells["x"][b"0"][0].value
+    assert raw_data.startswith(expected_prefix)
+
+
+def test_get_multi(ns):
+    nodes = [("a" * 32, {"foo": "a"}), ("b" * 32, {"foo": "b"})]
+
+    ns.set(nodes[0][0], nodes[0][1])
+    ns.set(nodes[1][0], nodes[1][1])
+
+    result = ns.get_multi([nodes[0][0], nodes[1][0]])
+    assert result == dict((n[0], n[1]) for n in nodes)
+
+
+def test_set(ns):
+    node_id = "d2502ebbd7df41ceba8d3275595cac33"
+    data = {"foo": "bar"}
+    ns.set(node_id, data)
+    assert ns.get(node_id) == data
+
+
+def test_delete(ns):
+    node_id = "d2502ebbd7df41ceba8d3275595cac33"
+    data = {"foo": "bar"}
+    ns.set(node_id, data)
+    assert ns.get(node_id) == data
+    ns.delete(node_id)
+    assert not ns.get(node_id)
+
+
+def test_delete_multi(ns):
+    nodes = [("node_1", {"foo": "a"}), ("node_2", {"foo": "b"})]
+
+    for n in nodes:
+        ns.set(n[0], n[1])
+
+    ns.delete_multi([nodes[0][0], nodes[1][0]])
+    assert not ns.get(nodes[0][0])
+    assert not ns.get(nodes[1][0])
+
+
+def test_cache(ns):
+    node_1 = ("a" * 32, {"foo": "a"})
+    node_2 = ("b" * 32, {"foo": "b"})
+    node_3 = ("c" * 32, {"foo": "c"})
+
+    for node_id, data in [node_1, node_2, node_3]:
+        ns.set(node_id, data)
+
+    # Get / get multi populates cache
+    assert ns.get(node_1[0]) == node_1[1]
+    assert ns.get_multi([node_2[0], node_3[0]]) == {
+        node_2[0]: node_2[1],
+        node_3[0]: node_3[1],
+    }
+    with mock.patch.object(ns.connection, "read_row") as mock_read_row:
+        assert ns.get(node_1[0]) == node_1[1]
+        assert ns.get(node_2[0]) == node_2[1]
+        assert ns.get(node_3[0]) == node_3[1]
+        assert mock_read_row.call_count == 0
+
+    with mock.patch.object(ns.connection, "read_rows") as mock_read_rows:
+        assert ns.get_multi([node_1[0], node_2[0], node_3[0]])
+        assert mock_read_rows.call_count == 0
+
+    # Manually deleted item should still be retrievable from cache
+    row = ns.connection.row(node_1[0])
+    row.delete()
+    row.commit()
+    assert ns.get(node_1[0]) == node_1[1]
+    assert ns.get_multi([node_1[0], node_2[0]]) == {
+        node_1[0]: node_1[1],
+        node_2[0]: node_2[1],
+    }
+
+    # Deletion clears cache
+    ns.delete(node_1[0])
+    assert ns.get_multi([node_1[0], node_2[0]]) == {node_1[0]: None, node_2[0]: node_2[1]}
+    ns.delete_multi([node_1[0], node_2[0]])
+    assert ns.get_multi([node_1[0], node_2[0]]) == {node_1[0]: None, node_2[0]: None}
+
+    # Setting the item updates cache
+    new_value = {"event_id": "d" * 32}
+    ns.set(node_1[0], new_value)
+    with mock.patch.object(ns.connection, "read_row") as mock_read_row:
+        assert ns.get(node_1[0]) == new_value
+        assert mock_read_row.call_count == 0
+
+    # Missing rows are never cached
+    assert ns.get("node_4") is None
+    with mock.patch.object(ns.connection, "read_row") as mock_read_row:
+        mock_read_row.return_value = None
+        ns.get("node_4")
+        ns.get("node_4")
+        assert mock_read_row.call_count == 2
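For reference, the expected_prefix values in test_get come from each format's magic bytes. A small standalone sketch (illustrative only, using default compressor settings as the backend does):

    import zlib
    import zstandard

    payload = b'{"foo": "bar"}'

    # zlib output starts with a two-byte header: 0x78 for the default 32K
    # window, followed by 0x01, 0x9c or 0xda depending on the level -- the
    # three prefixes accepted by the test.
    assert zlib.compress(payload)[:1] == b"\x78"

    # Every zstd frame begins with the little-endian magic number 0xFD2FB528,
    # i.e. the bytes 28 b5 2f fd the test expects.
    assert zstandard.ZstdCompressor().compress(payload)[:4] == b"\x28\xb5\x2f\xfd"

    # With compression disabled, the stored value is the raw JSON document,
    # which starts with b"{".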