#pragma once #include #include #include template class TDisjointIntervalTree { private: static_assert(std::is_integral::value, "expect std::is_integral::value"); using TTree = TMap; // [key, value) using TIterator = typename TTree::iterator; using TConstIterator = typename TTree::const_iterator; using TReverseIterator = typename TTree::reverse_iterator; using TThis = TDisjointIntervalTree; TTree Tree; size_t NumElements; public: TDisjointIntervalTree() : NumElements() { } void Insert(const T t) { InsertInterval(t, t + 1); } // we assume that none of elements from [begin, end) belong to tree. void InsertInterval(const T begin, const T end) { InsertIntervalImpl(begin, end); NumElements += (size_t)(end - begin); } bool Has(const T t) const { return const_cast(this)->FindContaining(t) != Tree.end(); } bool Intersects(const T begin, const T end) { if (Empty()) { return false; } TIterator l = Tree.lower_bound(begin); if (l != Tree.end()) { if (l->first < end) { return true; } else if (l != Tree.begin()) { --l; return l->second > begin; } else { return false; } } else { auto last = Tree.rbegin(); return begin < last->second; } } TConstIterator FindContaining(const T t) const { return const_cast(this)->FindContaining(t); } // Erase element. Returns true when element has been deleted, otherwise false. bool Erase(const T t) { TIterator n = FindContaining(t); if (n == Tree.end()) { return false; } --NumElements; T& begin = const_cast(n->first); T& end = const_cast(n->second); // Optimization hack. if (t == begin) { if (++begin == end) { // OK to change key since intervals do not intersect. Tree.erase(n); return true; } } else if (t == end - 1) { --end; } else { const T e = end; end = t; InsertIntervalImpl(t + 1, e); } Y_ASSERT(begin < end); return true; } // Erase interval. Returns number of elements removed from set. size_t EraseInterval(const T begin, const T end) { Y_ASSERT(begin < end); if (Empty()) { return 0; } size_t elementsRemoved = 0; TIterator completelyRemoveBegin = Tree.lower_bound(begin); if ((completelyRemoveBegin != Tree.end() && completelyRemoveBegin->first > begin && completelyRemoveBegin != Tree.begin()) || completelyRemoveBegin == Tree.end()) { // Look at the interval. It could contain [begin, end). TIterator containingBegin = completelyRemoveBegin; --containingBegin; if (containingBegin->first < begin && begin < containingBegin->second) { // Contains begin. if (containingBegin->second > end) { // Contains end. const T prevEnd = containingBegin->second; Y_ASSERT(containingBegin->second - begin <= NumElements); Y_ASSERT(containingBegin->second - containingBegin->first > end - begin); containingBegin->second = begin; InsertIntervalImpl(end, prevEnd); elementsRemoved = end - begin; NumElements -= elementsRemoved; return elementsRemoved; } else { elementsRemoved += containingBegin->second - begin; containingBegin->second = begin; } } } TIterator completelyRemoveEnd = completelyRemoveBegin != Tree.end() ? Tree.lower_bound(end) : Tree.end(); if (completelyRemoveEnd != Tree.begin() && (completelyRemoveEnd == Tree.end() || completelyRemoveEnd->first != end)) { TIterator containingEnd = completelyRemoveEnd; --containingEnd; if (containingEnd->second > end) { T& leftBorder = const_cast(containingEnd->first); Y_ASSERT(leftBorder < end); --completelyRemoveEnd; // Don't remove the whole interval. // Optimization hack. elementsRemoved += end - leftBorder; leftBorder = end; // OK to change key since intervals do not intersect. } } for (TIterator i = completelyRemoveBegin; i != completelyRemoveEnd; ++i) { elementsRemoved += i->second - i->first; } Tree.erase(completelyRemoveBegin, completelyRemoveEnd); Y_ASSERT(elementsRemoved <= NumElements); NumElements -= elementsRemoved; return elementsRemoved; } void Swap(TDisjointIntervalTree& rhv) { Tree.swap(rhv.Tree); std::swap(NumElements, rhv.NumElements); } void Clear() { Tree.clear(); NumElements = 0; } bool Empty() const { return Tree.empty(); } size_t GetNumElements() const { return NumElements; } size_t GetNumIntervals() const { return Tree.size(); } T Min() const { Y_ASSERT(!Empty()); return Tree.begin()->first; } T Max() const { Y_ASSERT(!Empty()); return Tree.rbegin()->second; } TConstIterator begin() const { return Tree.begin(); } TConstIterator end() const { return Tree.end(); } private: void InsertIntervalImpl(const T begin, const T end) { Y_ASSERT(begin < end); Y_ASSERT(!Intersects(begin, end)); TIterator l = Tree.lower_bound(begin); TIterator p = Tree.end(); if (l != Tree.begin()) { p = l; --p; } #ifndef NDEBUG TIterator u = Tree.upper_bound(begin); Y_DEBUG_ABORT_UNLESS(u == Tree.end() || u->first >= end, "Trying to add [%" PRIu64 ", %" PRIu64 ") which intersects with existing [%" PRIu64 ", %" PRIu64 ")", begin, end, u->first, u->second); Y_DEBUG_ABORT_UNLESS(l == Tree.end() || l == u, "Trying to add [%" PRIu64 ", %" PRIu64 ") which intersects with existing [%" PRIu64 ", %" PRIu64 ")", begin, end, l->first, l->second); Y_DEBUG_ABORT_UNLESS(p == Tree.end() || p->second <= begin, "Trying to add [%" PRIu64 ", %" PRIu64 ") which intersects with existing [%" PRIu64 ", %" PRIu64 ")", begin, end, p->first, p->second); #endif // try to extend interval if (p != Tree.end() && p->second == begin) { p->second = end; //Try to merge 2 intervals - p and next one if possible auto next = p; // Next is not Tree.end() here. ++next; if (next != Tree.end() && next->first == end) { p->second = next->second; Tree.erase(next); } // Maybe new interval extends right interval } else if (l != Tree.end() && end == l->first) { T& leftBorder = const_cast(l->first); // Optimization hack. leftBorder = begin; // OK to change key since intervals do not intersect. } else { Tree.insert(std::make_pair(begin, end)); } } TIterator FindContaining(const T t) { TIterator l = Tree.lower_bound(t); if (l != Tree.end()) { if (l->first == t) { return l; } Y_ASSERT(l->first > t); if (l == Tree.begin()) { return Tree.end(); } --l; Y_ASSERT(l->first != t); if (l->first < t && t < l->second) { return l; } } else if (!Tree.empty()) { // l is larger than Begin of any interval, but maybe it belongs to last interval? TReverseIterator last = Tree.rbegin(); Y_ASSERT(last->first != t); if (last->first < t && t < last->second) { return (++last).base(); } } return Tree.end(); } };