weak_ptr.h 8.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333
  1. #pragma once
  2. #include "ref_counted.h"
  3. #include <util/generic/hash.h>
  4. namespace NYT {
  5. ////////////////////////////////////////////////////////////////////////////////
  6. template <class T>
  7. class TWeakPtr
  8. {
  9. public:
  10. using TUnderlying = T;
  11. //! Empty constructor.
  12. TWeakPtr() = default;
  13. TWeakPtr(std::nullptr_t)
  14. { }
  15. //! Constructor from an unqualified reference.
  16. /*!
  17. * Note that this constructor could be racy due to unsynchronized operations
  18. * on the object and on the counter.
  19. */
  20. explicit TWeakPtr(T* p) noexcept
  21. : T_(p)
  22. {
  23. #if defined(_tsan_enabled_)
  24. if (T_) {
  25. RefCounter_ = GetRefCounter(T_);
  26. }
  27. #endif
  28. AcquireRef();
  29. }
  30. //! Constructor from a strong reference.
  31. TWeakPtr(const TIntrusivePtr<T>& ptr) noexcept
  32. : TWeakPtr(ptr.Get())
  33. { }
  34. //! Constructor from a strong reference with an upcast.
  35. template <class U, class = typename std::enable_if_t<std::is_convertible_v<U*, T*>>>
  36. TWeakPtr(const TIntrusivePtr<U>& ptr) noexcept
  37. : TWeakPtr(ptr.Get())
  38. {
  39. static_assert(
  40. std::derived_from<T, TRefCountedBase>,
  41. "Cast allowed only for types derived from TRefCountedBase");
  42. }
  43. //! Copy constructor.
  44. TWeakPtr(const TWeakPtr& other) noexcept
  45. : TWeakPtr(other.T_)
  46. { }
  47. //! Copy constructor with an upcast.
  48. template <class U, class = typename std::enable_if_t<std::is_convertible_v<U*, T*>>>
  49. TWeakPtr(const TWeakPtr<U>& other) noexcept
  50. : TWeakPtr(other.Lock())
  51. {
  52. static_assert(
  53. std::derived_from<T, TRefCountedBase>,
  54. "Cast allowed only for types derived from TRefCountedBase");
  55. }
  56. //! Move constructor.
  57. TWeakPtr(TWeakPtr&& other) noexcept
  58. {
  59. other.Swap(*this);
  60. }
  61. //! Move constructor with an upcast.
  62. template <class U, class = typename std::enable_if_t<std::is_convertible_v<U*, T*>>>
  63. TWeakPtr(TWeakPtr<U>&& other) noexcept
  64. {
  65. static_assert(
  66. std::derived_from<T, TRefCountedBase>,
  67. "Cast allowed only for types derived from TRefCountedBase");
  68. TIntrusivePtr<U> strongOther = other.Lock();
  69. if (strongOther) {
  70. T_ = other.T_;
  71. other.T_ = nullptr;
  72. #if defined(_tsan_enabled_)
  73. RefCounter_ = other.RefCounter_;
  74. other.RefCounter_ = nullptr;
  75. #endif
  76. }
  77. }
  78. //! Destructor.
  79. ~TWeakPtr()
  80. {
  81. ReleaseRef();
  82. }
  83. //! Assignment operator from a strong reference.
  84. template <class U>
  85. TWeakPtr& operator=(const TIntrusivePtr<U>& ptr) noexcept
  86. {
  87. static_assert(
  88. std::is_convertible_v<U*, T*>,
  89. "U* must be convertible to T*");
  90. TWeakPtr(ptr).Swap(*this);
  91. return *this;
  92. }
  93. //! Copy assignment operator.
  94. TWeakPtr& operator=(const TWeakPtr& other) noexcept
  95. {
  96. TWeakPtr(other).Swap(*this);
  97. return *this;
  98. }
  99. //! Copy assignment operator with an upcast.
  100. template <class U>
  101. TWeakPtr& operator=(const TWeakPtr<U>& other) noexcept
  102. {
  103. static_assert(
  104. std::is_convertible_v<U*, T*>,
  105. "U* must be convertible to T*");
  106. TWeakPtr(other).Swap(*this);
  107. return *this;
  108. }
  109. //! Move assignment operator.
  110. TWeakPtr& operator=(TWeakPtr&& other) noexcept
  111. {
  112. other.Swap(*this);
  113. return *this;
  114. }
  115. //! Move assignment operator with an upcast.
  116. template <class U>
  117. TWeakPtr& operator=(TWeakPtr<U>&& other) noexcept
  118. {
  119. static_assert(
  120. std::is_convertible_v<U*, T*>,
  121. "U* must be convertible to T*");
  122. TWeakPtr(std::move(other)).Swap(*this);
  123. return *this;
  124. }
  125. //! Drop the pointer.
  126. void Reset() // noexcept
  127. {
  128. TWeakPtr().Swap(*this);
  129. }
  130. //! Replace the pointer with a specified one.
  131. void Reset(T* p) // noexcept
  132. {
  133. TWeakPtr(p).Swap(*this);
  134. }
  135. //! Replace the pointer with a specified one.
  136. template <class U>
  137. void Reset(const TIntrusivePtr<U>& ptr) // noexcept
  138. {
  139. static_assert(
  140. std::is_convertible_v<U*, T*>,
  141. "U* must be convertible to T*");
  142. TWeakPtr(ptr).Swap(*this);
  143. }
  144. //! Swap the pointer with the other one.
  145. void Swap(TWeakPtr& other) noexcept
  146. {
  147. DoSwap(T_, other.T_);
  148. #if defined(_tsan_enabled_)
  149. DoSwap(RefCounter_, other.RefCounter_);
  150. #endif
  151. }
  152. //! Acquire a strong reference to the pointee and return a strong pointer.
  153. TIntrusivePtr<T> Lock() const noexcept
  154. {
  155. return T_ && RefCounter()->TryRef()
  156. ? TIntrusivePtr<T>(T_, false)
  157. : TIntrusivePtr<T>();
  158. }
  159. bool IsExpired() const noexcept
  160. {
  161. return !T_ || (RefCounter()->GetRefCount() == 0);
  162. }
  163. const TRefCounter* TryGetRefCounter() const
  164. {
  165. return T_
  166. ? RefCounter()
  167. : nullptr;
  168. }
  169. private:
  170. void AcquireRef()
  171. {
  172. if (T_) {
  173. RefCounter()->WeakRef();
  174. }
  175. }
  176. void ReleaseRef()
  177. {
  178. if (T_) {
  179. // Support incomplete type.
  180. if (RefCounter()->WeakUnref()) {
  181. DeallocateRefCounted(T_);
  182. }
  183. }
  184. }
  185. template <class U>
  186. friend class TWeakPtr;
  187. template <class U>
  188. friend struct ::THash;
  189. T* T_ = nullptr;
  190. #if defined(_tsan_enabled_)
  191. const TRefCounter* RefCounter_ = nullptr;
  192. const TRefCounter* RefCounter() const
  193. {
  194. return RefCounter_;
  195. }
  196. #else
  197. const TRefCounter* RefCounter() const
  198. {
  199. return GetRefCounter(T_);
  200. }
  201. #endif
  202. };
  203. ////////////////////////////////////////////////////////////////////////////////
  204. //! Creates a weak pointer wrapper for a given raw pointer.
  205. //! Compared to |TWeakPtr<T>::ctor|, type inference enables omitting |T|.
  206. template <class T>
  207. TWeakPtr<T> MakeWeak(T* p)
  208. {
  209. return TWeakPtr<T>(p);
  210. }
  211. //! Creates a weak pointer wrapper for a given intrusive pointer.
  212. //! Compared to |TWeakPtr<T>::ctor|, type inference enables omitting |T|.
  213. template <class T>
  214. TWeakPtr<T> MakeWeak(const TIntrusivePtr<T>& p)
  215. {
  216. return TWeakPtr<T>(p);
  217. }
  218. //! A helper for acquiring weak pointer for pointee, resetting intrusive pointer and then
  219. //! returning the pointee reference count using the acquired weak pointer.
  220. //! This helper is designed for best effort in checking that the object is not leaked after
  221. //! destructing (what seems to be) the last pointer to it.
  222. //! NB: it is possible to rewrite this helper making it working event with intrinsic refcounted objects,
  223. //! but it requires much nastier integration with the intrusive pointer destruction routines.
  224. template <typename T>
  225. int ResetAndGetResidualRefCount(TIntrusivePtr<T>& pointer)
  226. {
  227. auto weakPointer = MakeWeak(pointer);
  228. pointer.Reset();
  229. pointer = weakPointer.Lock();
  230. if (pointer) {
  231. // This _may_ return 0 if we are again the only holder of the pointee.
  232. return pointer->GetRefCount() - 1;
  233. } else {
  234. return 0;
  235. }
  236. }
  237. ////////////////////////////////////////////////////////////////////////////////
  238. template <class T, class U>
  239. bool operator==(const TWeakPtr<T>& lhs, const TWeakPtr<U>& rhs)
  240. {
  241. static_assert(
  242. std::is_convertible_v<U*, T*>,
  243. "U* must be convertible to T*");
  244. return lhs.TryGetRefCounter() == rhs.TryGetRefCounter();
  245. }
  246. template <class T, class U>
  247. bool operator!=(const TWeakPtr<T>& lhs, const TWeakPtr<U>& rhs)
  248. {
  249. static_assert(
  250. std::is_convertible_v<U*, T*>,
  251. "U* must be convertible to T*");
  252. return lhs.TryGetRefCounter() != rhs.TryGetRefCounter();
  253. }
  254. template <class T>
  255. bool operator==(std::nullptr_t, const TWeakPtr<T>& rhs)
  256. {
  257. return nullptr == rhs.TryGetRefCounter();
  258. }
  259. template <class T>
  260. bool operator!=(std::nullptr_t, const TWeakPtr<T>& rhs)
  261. {
  262. return nullptr != rhs.TryGetRefCounter();
  263. }
  264. template <class T>
  265. bool operator==(const TWeakPtr<T>& lhs, std::nullptr_t)
  266. {
  267. return nullptr == lhs.TryGetRefCounter();
  268. }
  269. template <class T>
  270. bool operator!=(const TWeakPtr<T>& lhs, std::nullptr_t)
  271. {
  272. return nullptr != lhs.TryGetRefCounter();
  273. }
  274. ////////////////////////////////////////////////////////////////////////////////
  275. } // namespace NYT
  276. //! A hasher for TWeakPtr.
  277. template <class T>
  278. struct THash<NYT::TWeakPtr<T>>
  279. {
  280. size_t operator () (const NYT::TWeakPtr<T>& ptr) const
  281. {
  282. return THash<const NYT::TRefCountedBase*>()(ptr.T_);
  283. }
  284. };