Browse Source

Extend support importlib.resources in Arcadia
5d391b8f31717f85fcd88c4ae5ce3b57a723e42e

shadchin 10 months ago
parent
commit
65c7668fad

+ 37 - 25
library/python/runtime_py3/sitecustomize.pyx

@@ -22,10 +22,35 @@ ResourceReader.register(__res._ResfsResourceReader)
 METADATA_NAME = re.compile("^Name: (.*)$", re.MULTILINE)
 
 
-class ArcadiaResource(Traversable):
-    def __init__(self, resfs_key):
-        self.resfs_key = resfs_key
+class ArcadiaTraversable(Traversable):
+    def __init__(self, resfs):
+        self._resfs = resfs
+        self._path = pathlib.Path(resfs)
 
+    def __eq__(self, other) -> bool:
+        if isinstance(other, ArcadiaTraversable):
+            return self._path == other._path
+        raise NotImplementedError
+
+    def __lt__(self, other) -> bool:
+        if isinstance(other, ArcadiaTraversable):
+            return self._path < other._path
+        raise NotImplementedError
+
+    def __hash__(self) -> int:
+        return hash(self._path)
+
+    @property
+    def name(self):
+        return self._path.name
+
+    @property
+    def suffix(self):
+        return self._path.suffix
+
+
+
+class ArcadiaResource(ArcadiaTraversable):
     def is_file(self):
         return True
 
@@ -33,9 +58,9 @@ class ArcadiaResource(Traversable):
         return False
 
     def open(self, mode="r", *args, **kwargs):
-        data = __res.find(self.resfs_key.encode("utf-8"))
+        data = __res.find(self._resfs.encode("utf-8"))
         if data is None:
-            raise FileNotFoundError(self.resfs_key)
+            raise FileNotFoundError(self._resfs)
 
         stream = io.BytesIO(data)
 
@@ -50,18 +75,11 @@ class ArcadiaResource(Traversable):
     def iterdir(self):
         return iter(())
 
-    @property
-    def name(self):
-        return os.path.basename(self.resfs_key)
+    def __repr__(self) -> str:
+        return f"ArcadiaResource({self._resfs!r})"
 
-    def __repr__(self):
-        return f"ArcadiaResource({self.resfs_key!r})"
-
-
-class ArcadiaResourceContainer(Traversable):
-    def __init__(self, prefix):
-        self.resfs_prefix = prefix
 
+class ArcadiaResourceContainer(ArcadiaTraversable):
     def is_dir(self):
         return True
 
@@ -70,19 +88,17 @@ class ArcadiaResourceContainer(Traversable):
 
     def iterdir(self):
         seen = set()
-        for key, path_without_prefix in __res.iter_keys(
-            self.resfs_prefix.encode("utf-8")
-        ):
+        for key, path_without_prefix in __res.iter_keys(self._resfs.encode("utf-8")):
             if b"/" in path_without_prefix:
                 subdir = path_without_prefix.split(b"/", maxsplit=1)[0].decode("utf-8")
                 if subdir not in seen:
                     seen.add(subdir)
-                    yield ArcadiaResourceContainer(f"{self.resfs_prefix}{subdir}/")
+                    yield ArcadiaResourceContainer(f"{self._resfs}{subdir}/")
             else:
                 yield ArcadiaResource(key.decode("utf-8"))
 
     def open(self, *args, **kwargs):
-        raise IsADirectoryError(self.resfs_prefix)
+        raise IsADirectoryError(self._resfs)
 
     @staticmethod
     def _flatten(compound_names):
@@ -104,12 +120,8 @@ class ArcadiaResourceContainer(Traversable):
 
         raise FileNotFoundError("/".join(self._flatten(descendants)))
 
-    @property
-    def name(self):
-        return os.path.basename(self.resfs_prefix[:-1])
-
     def __repr__(self):
-        return f"ArcadiaResourceContainer({self.resfs_prefix!r})"
+        return f"ArcadiaResourceContainer({self._resfs!r})"
 
 
 class ArcadiaDistribution(Distribution):

+ 11 - 0
library/python/runtime_py3/test/test_resources.py

@@ -111,3 +111,14 @@ def test_files_read_text(package, resource, expected):
 )
 def test_files_iterdir(package, expected):
     assert tuple(resource.name for resource in ir.files(package).iterdir()) == expected
+
+
+@pytest.mark.parametrize(
+    "package, expected",
+    (
+        ("resources", ("foo.txt", "submodule")),
+        ("resources.submodule", ("bar.txt",)),
+    ),
+)
+def test_files_iterdir_with_sort(package, expected):
+    assert tuple(resource.name for resource in sorted(ir.files(package).iterdir())) == expected