Browse Source

Fix TEnumerator::TIterator
082782faf0e6e05d3aa1c56e096d33ea69282f57

pavook 7 months ago
parent
commit
c4fb2cdf41
2 changed files with 28 additions and 8 deletions
  1. 15 7
      library/cpp/iterator/enumerate.h
  2. 13 1
      library/cpp/iterator/ut/functools_ut.cpp

+ 15 - 7
library/cpp/iterator/enumerate.h

@@ -27,23 +27,30 @@ namespace NPrivate {
         struct TIterator {
             using difference_type = std::ptrdiff_t;
             using value_type = TValue;
-            using pointer = TValue*;
-            using reference = TValue&;
+            using pointer = void;
+            using reference = value_type;
             using iterator_category = std::input_iterator_tag;
 
-            TValue operator*() {
+            reference operator*() const {
                 return {Index_, *Iterator_};
             }
-            TValue operator*() const {
-                return {Index_, *Iterator_};
-            }
-            void operator++() {
+
+            TIterator& operator++() {
                 ++Index_;
                 ++Iterator_;
+                return *this;
+            }
+
+            TIterator operator++(int) {
+                TIterator result = *this;
+                ++(*this);
+                return result;
             }
+
             bool operator!=(const TSentinel& other) const {
                 return Iterator_ != other.Iterator_;
             }
+
             bool operator==(const TSentinel& other) const {
                 return Iterator_ == other.Iterator_;
             }
@@ -51,6 +58,7 @@ namespace NPrivate {
             std::size_t Index_;
             TIteratorState Iterator_;
         };
+
     public:
         using iterator = TIterator;
         using const_iterator = TIterator;

+ 13 - 1
library/cpp/iterator/ut/functools_ut.cpp

@@ -99,7 +99,7 @@ using namespace NFuncTools;
 
     struct TTestSentinel {};
     struct TTestIterator {
-        int operator*() {
+        int operator*() const {
             return X;
         }
         void operator++() {
@@ -139,6 +139,18 @@ using namespace NFuncTools;
             EXPECT_EQ(j, v.size());
         }
 
+        // Test correctness of iterator traits.
+        auto enumerated = Enumerate(a);
+        static_assert(std::ranges::input_range<decltype(enumerated)>);
+        static_assert(
+            std::is_same_v<decltype(enumerated.begin())::pointer,
+            std::iterator_traits<decltype(enumerated.begin())>::pointer>);
+
+        // Post-increment test.
+        auto it = enumerated.begin();
+        EXPECT_EQ(*(it++), (std::tuple{0, 1}));
+        EXPECT_EQ(*it, (std::tuple{1, 2}));
+
         TVector<size_t> d = {0, 0, 0};
         FOR_DISPATCH_2(i, x, Enumerate(d)) {
             x = i;