FunctionExtras.h 16 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426
  1. #pragma once
  2. #ifdef __GNUC__
  3. #pragma GCC diagnostic push
  4. #pragma GCC diagnostic ignored "-Wunused-parameter"
  5. #endif
  6. //===- FunctionExtras.h - Function type erasure utilities -------*- C++ -*-===//
  7. //
  8. // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
  9. // See https://llvm.org/LICENSE.txt for license information.
  10. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
  11. //
  12. //===----------------------------------------------------------------------===//
  13. /// \file
  /// This file provides a collection of function (or more generally, callable)
  /// type erasure utilities supplementing those provided by the standard library
  /// in `<functional>`.
  17. ///
  18. /// It provides `unique_function`, which works like `std::function` but supports
  19. /// move-only callable objects and const-qualification.
  20. ///
  21. /// Future plans:
  22. /// - Add a `function` that provides ref-qualified support, which doesn't work
  23. /// with `std::function`.
  24. /// - Provide support for specifying multiple signatures to type erase callable
  25. /// objects with an overload set, such as those produced by generic lambdas.
  26. /// - Expand to include a copyable utility that directly replaces std::function
  27. /// but brings the above improvements.
  28. ///
  29. /// Note that LLVM's utilities are greatly simplified by not supporting
  30. /// allocators.
  31. ///
  32. /// If the standard library ever begins to provide comparable facilities we can
  33. /// consider switching to those.
  34. ///
  35. //===----------------------------------------------------------------------===//
  36. #ifndef LLVM_ADT_FUNCTIONEXTRAS_H
  37. #define LLVM_ADT_FUNCTIONEXTRAS_H
  38. #include "llvm/ADT/PointerIntPair.h"
  39. #include "llvm/ADT/PointerUnion.h"
  40. #include "llvm/ADT/STLForwardCompat.h"
  41. #include "llvm/Support/MemAlloc.h"
  42. #include "llvm/Support/type_traits.h"
  43. #include <cstring>
  44. #include <memory>
  45. #include <type_traits>
  46. namespace llvm {
  /// unique_function is a type-erasing functor in the spirit of std::function.
  ///
  /// Unlike std::function it may hold move-only callables (for example a
  /// lambda that captures a std::unique_ptr). In turn, unique_function itself
  /// can be moved but never copied.
  ///
  /// The call signature may be const-qualified:
  ///   - unique_function<int() const> exposes a const operator(), and so only
  ///     accepts callables whose own operator() is usable on a const object.
  ///   - unique_function<int()> exposes a non-const operator(), and so also
  ///     accepts callables with a non-const operator(), such as mutable
  ///     lambdas.
  ///
  /// The primary template is only declared; the two partial specializations
  /// for `R(P...)` and `R(P...) const` provide the definitions.
  template <class FunctionT> class unique_function;
  58. namespace detail {
  59. template <typename T>
  60. using EnableIfTrivial =
  61. std::enable_if_t<llvm::is_trivially_move_constructible<T>::value &&
  62. std::is_trivially_destructible<T>::value>;
  /// Evaluates to `void` unless `CallableT`, with references and top-level
  /// cv-qualifiers stripped, is exactly `ThisT`.
  ///
  /// It is used to stop the converting constructor of `unique_function`
  /// from being chosen when the argument is a `unique_function` of the same
  /// type.
  template <typename CallableT, typename ThisT>
  using EnableUnlessSameType = std::enable_if_t<
      !std::is_same<std::remove_cv_t<std::remove_reference_t<CallableT>>,
                    ThisT>::value>;
  /// Result type of calling an rvalue `CallableT` with rvalue arguments of
  /// types `ArgTs...`. Substitution fails if that call is ill-formed.
  template <typename CallableT, typename... ArgTs>
  using UniqueFunctionCallResultT =
      decltype(std::declval<CallableT>()(std::declval<ArgTs>()...));

  /// Evaluates to `void` when `CallableT` may be stored in a function object
  /// with signature `Ret(Params...)`; otherwise substitution fails.
  ///
  /// The call must be well-formed in every case, because each alternative
  /// names the call's result type. Given that, the callable is accepted when
  /// any of the following holds:
  ///   - `Ret` is void (the result is discarded);
  ///   - the result type is `Ret`, or `const` result type is `Ret`;
  ///   - the result type is implicitly convertible to `Ret`.
  template <typename CallableT, typename Ret, typename... Params>
  using EnableIfCallable = std::enable_if_t<std::disjunction<
      std::is_void<Ret>,
      std::is_same<UniqueFunctionCallResultT<CallableT, Params...>, Ret>,
      std::is_same<const UniqueFunctionCallResultT<CallableT, Params...>, Ret>,
      std::is_convertible<UniqueFunctionCallResultT<CallableT, Params...>,
                          Ret>>::value>;
  77. template <typename ReturnT, typename... ParamTs> class UniqueFunctionBase {
  78. protected:
  79. static constexpr size_t InlineStorageSize = sizeof(void *) * 3;
  80. template <typename T, class = void>
  81. struct IsSizeLessThanThresholdT : std::false_type {};
  82. template <typename T>
  83. struct IsSizeLessThanThresholdT<
  84. T, std::enable_if_t<sizeof(T) <= 2 * sizeof(void *)>> : std::true_type {};
  85. // Provide a type function to map parameters that won't observe extra copies
  86. // or moves and which are small enough to likely pass in register to values
  87. // and all other types to l-value reference types. We use this to compute the
  88. // types used in our erased call utility to minimize copies and moves unless
  89. // doing so would force things unnecessarily into memory.
  90. //
  91. // The heuristic used is related to common ABI register passing conventions.
  92. // It doesn't have to be exact though, and in one way it is more strict
  93. // because we want to still be able to observe either moves *or* copies.
  94. template <typename T> struct AdjustedParamTBase {
  95. static_assert(!std::is_reference<T>::value,
  96. "references should be handled by template specialization");
  97. using type = std::conditional_t<
  98. llvm::is_trivially_copy_constructible<T>::value &&
  99. llvm::is_trivially_move_constructible<T>::value &&
  100. IsSizeLessThanThresholdT<T>::value,
  101. T, T &>;
  102. };
  103. // This specialization ensures that 'AdjustedParam<V<T>&>' or
  104. // 'AdjustedParam<V<T>&&>' does not trigger a compile-time error when 'T' is
  105. // an incomplete type and V a templated type.
  106. template <typename T> struct AdjustedParamTBase<T &> { using type = T &; };
  107. template <typename T> struct AdjustedParamTBase<T &&> { using type = T &; };
  108. template <typename T>
  109. using AdjustedParamT = typename AdjustedParamTBase<T>::type;
  110. // The type of the erased function pointer we use as a callback to dispatch to
  111. // the stored callable when it is trivial to move and destroy.
  112. using CallPtrT = ReturnT (*)(void *CallableAddr,
  113. AdjustedParamT<ParamTs>... Params);
  114. using MovePtrT = void (*)(void *LHSCallableAddr, void *RHSCallableAddr);
  115. using DestroyPtrT = void (*)(void *CallableAddr);
  116. /// A struct to hold a single trivial callback with sufficient alignment for
  117. /// our bitpacking.
  118. struct alignas(8) TrivialCallback {
  119. CallPtrT CallPtr;
  120. };
  121. /// A struct we use to aggregate three callbacks when we need full set of
  122. /// operations.
  123. struct alignas(8) NonTrivialCallbacks {
  124. CallPtrT CallPtr;
  125. MovePtrT MovePtr;
  126. DestroyPtrT DestroyPtr;
  127. };
  128. // Create a pointer union between either a pointer to a static trivial call
  129. // pointer in a struct or a pointer to a static struct of the call, move, and
  130. // destroy pointers.
  131. using CallbackPointerUnionT =
  132. PointerUnion<TrivialCallback *, NonTrivialCallbacks *>;
  133. // The main storage buffer. This will either have a pointer to out-of-line
  134. // storage or an inline buffer storing the callable.
  135. union StorageUnionT {
  136. // For out-of-line storage we keep a pointer to the underlying storage and
  137. // the size. This is enough to deallocate the memory.
  138. struct OutOfLineStorageT {
  139. void *StoragePtr;
  140. size_t Size;
  141. size_t Alignment;
  142. } OutOfLineStorage;
  143. static_assert(
  144. sizeof(OutOfLineStorageT) <= InlineStorageSize,
  145. "Should always use all of the out-of-line storage for inline storage!");
  146. // For in-line storage, we just provide an aligned character buffer. We
  147. // provide three pointers worth of storage here.
  148. // This is mutable as an inlined `const unique_function<void() const>` may
  149. // still modify its own mutable members.
  150. mutable std::aligned_storage_t<InlineStorageSize, alignof(void *)>
  151. InlineStorage;
  152. } StorageUnion;
  153. // A compressed pointer to either our dispatching callback or our table of
  154. // dispatching callbacks and the flag for whether the callable itself is
  155. // stored inline or not.
  156. PointerIntPair<CallbackPointerUnionT, 1, bool> CallbackAndInlineFlag;
  157. bool isInlineStorage() const { return CallbackAndInlineFlag.getInt(); }
  158. bool isTrivialCallback() const {
  159. return CallbackAndInlineFlag.getPointer().template is<TrivialCallback *>();
  160. }
  161. CallPtrT getTrivialCallback() const {
  162. return CallbackAndInlineFlag.getPointer().template get<TrivialCallback *>()->CallPtr;
  163. }
  164. NonTrivialCallbacks *getNonTrivialCallbacks() const {
  165. return CallbackAndInlineFlag.getPointer()
  166. .template get<NonTrivialCallbacks *>();
  167. }
  168. CallPtrT getCallPtr() const {
  169. return isTrivialCallback() ? getTrivialCallback()
  170. : getNonTrivialCallbacks()->CallPtr;
  171. }
  172. // These three functions are only const in the narrow sense. They return
  173. // mutable pointers to function state.
  174. // This allows unique_function<T const>::operator() to be const, even if the
  175. // underlying functor may be internally mutable.
  176. //
  177. // const callers must ensure they're only used in const-correct ways.
  178. void *getCalleePtr() const {
  179. return isInlineStorage() ? getInlineStorage() : getOutOfLineStorage();
  180. }
  181. void *getInlineStorage() const { return &StorageUnion.InlineStorage; }
  182. void *getOutOfLineStorage() const {
  183. return StorageUnion.OutOfLineStorage.StoragePtr;
  184. }
  185. size_t getOutOfLineStorageSize() const {
  186. return StorageUnion.OutOfLineStorage.Size;
  187. }
  188. size_t getOutOfLineStorageAlignment() const {
  189. return StorageUnion.OutOfLineStorage.Alignment;
  190. }
  191. void setOutOfLineStorage(void *Ptr, size_t Size, size_t Alignment) {
  192. StorageUnion.OutOfLineStorage = {Ptr, Size, Alignment};
  193. }
  194. template <typename CalledAsT>
  195. static ReturnT CallImpl(void *CallableAddr,
  196. AdjustedParamT<ParamTs>... Params) {
  197. auto &Func = *reinterpret_cast<CalledAsT *>(CallableAddr);
  198. return Func(std::forward<ParamTs>(Params)...);
  199. }
  200. template <typename CallableT>
  201. static void MoveImpl(void *LHSCallableAddr, void *RHSCallableAddr) noexcept {
  202. new (LHSCallableAddr)
  203. CallableT(std::move(*reinterpret_cast<CallableT *>(RHSCallableAddr)));
  204. }
  205. template <typename CallableT>
  206. static void DestroyImpl(void *CallableAddr) noexcept {
  207. reinterpret_cast<CallableT *>(CallableAddr)->~CallableT();
  208. }
  209. // The pointers to call/move/destroy functions are determined for each
  210. // callable type (and called-as type, which determines the overload chosen).
  211. // (definitions are out-of-line).
  212. // By default, we need an object that contains all the different
  213. // type erased behaviors needed. Create a static instance of the struct type
  214. // here and each instance will contain a pointer to it.
  215. // Wrap in a struct to avoid https://gcc.gnu.org/PR71954
  216. template <typename CallableT, typename CalledAs, typename Enable = void>
  217. struct CallbacksHolder {
  218. static NonTrivialCallbacks Callbacks;
  219. };
  220. // See if we can create a trivial callback. We need the callable to be
  221. // trivially moved and trivially destroyed so that we don't have to store
  222. // type erased callbacks for those operations.
  223. template <typename CallableT, typename CalledAs>
  224. struct CallbacksHolder<CallableT, CalledAs, EnableIfTrivial<CallableT>> {
  225. static TrivialCallback Callbacks;
  226. };
  227. // A simple tag type so the call-as type to be passed to the constructor.
  228. template <typename T> struct CalledAs {};
  229. // Essentially the "main" unique_function constructor, but subclasses
  230. // provide the qualified type to be used for the call.
  231. // (We always store a T, even if the call will use a pointer to const T).
  232. template <typename CallableT, typename CalledAsT>
  233. UniqueFunctionBase(CallableT Callable, CalledAs<CalledAsT>) {
  234. bool IsInlineStorage = true;
  235. void *CallableAddr = getInlineStorage();
  236. if (sizeof(CallableT) > InlineStorageSize ||
  237. alignof(CallableT) > alignof(decltype(StorageUnion.InlineStorage))) {
  238. IsInlineStorage = false;
  239. // Allocate out-of-line storage. FIXME: Use an explicit alignment
  240. // parameter in C++17 mode.
  241. auto Size = sizeof(CallableT);
  242. auto Alignment = alignof(CallableT);
  243. CallableAddr = allocate_buffer(Size, Alignment);
  244. setOutOfLineStorage(CallableAddr, Size, Alignment);
  245. }
  246. // Now move into the storage.
  247. new (CallableAddr) CallableT(std::move(Callable));
  248. CallbackAndInlineFlag.setPointerAndInt(
  249. &CallbacksHolder<CallableT, CalledAsT>::Callbacks, IsInlineStorage);
  250. }
  251. ~UniqueFunctionBase() {
  252. if (!CallbackAndInlineFlag.getPointer())
  253. return;
  254. // Cache this value so we don't re-check it after type-erased operations.
  255. bool IsInlineStorage = isInlineStorage();
  256. if (!isTrivialCallback())
  257. getNonTrivialCallbacks()->DestroyPtr(
  258. IsInlineStorage ? getInlineStorage() : getOutOfLineStorage());
  259. if (!IsInlineStorage)
  260. deallocate_buffer(getOutOfLineStorage(), getOutOfLineStorageSize(),
  261. getOutOfLineStorageAlignment());
  262. }
  263. UniqueFunctionBase(UniqueFunctionBase &&RHS) noexcept {
  264. // Copy the callback and inline flag.
  265. CallbackAndInlineFlag = RHS.CallbackAndInlineFlag;
  266. // If the RHS is empty, just copying the above is sufficient.
  267. if (!RHS)
  268. return;
  269. if (!isInlineStorage()) {
  270. // The out-of-line case is easiest to move.
  271. StorageUnion.OutOfLineStorage = RHS.StorageUnion.OutOfLineStorage;
  272. } else if (isTrivialCallback()) {
  273. // Move is trivial, just memcpy the bytes across.
  274. memcpy(getInlineStorage(), RHS.getInlineStorage(), InlineStorageSize);
  275. } else {
  276. // Non-trivial move, so dispatch to a type-erased implementation.
  277. getNonTrivialCallbacks()->MovePtr(getInlineStorage(),
  278. RHS.getInlineStorage());
  279. }
  280. // Clear the old callback and inline flag to get back to as-if-null.
  281. RHS.CallbackAndInlineFlag = {};
  282. #ifndef NDEBUG
  283. // In debug builds, we also scribble across the rest of the storage.
  284. memset(RHS.getInlineStorage(), 0xAD, InlineStorageSize);
  285. #endif
  286. }
  287. UniqueFunctionBase &operator=(UniqueFunctionBase &&RHS) noexcept {
  288. if (this == &RHS)
  289. return *this;
  290. // Because we don't try to provide any exception safety guarantees we can
  291. // implement move assignment very simply by first destroying the current
  292. // object and then move-constructing over top of it.
  293. this->~UniqueFunctionBase();
  294. new (this) UniqueFunctionBase(std::move(RHS));
  295. return *this;
  296. }
  297. UniqueFunctionBase() = default;
  298. public:
  299. explicit operator bool() const {
  300. return (bool)CallbackAndInlineFlag.getPointer();
  301. }
  302. };
  303. template <typename R, typename... P>
  304. template <typename CallableT, typename CalledAsT, typename Enable>
  305. typename UniqueFunctionBase<R, P...>::NonTrivialCallbacks UniqueFunctionBase<
  306. R, P...>::CallbacksHolder<CallableT, CalledAsT, Enable>::Callbacks = {
  307. &CallImpl<CalledAsT>, &MoveImpl<CallableT>, &DestroyImpl<CallableT>};
  308. template <typename R, typename... P>
  309. template <typename CallableT, typename CalledAsT>
  310. typename UniqueFunctionBase<R, P...>::TrivialCallback
  311. UniqueFunctionBase<R, P...>::CallbacksHolder<
  312. CallableT, CalledAsT, EnableIfTrivial<CallableT>>::Callbacks{
  313. &CallImpl<CalledAsT>};
  314. } // namespace detail
  315. template <typename R, typename... P>
  316. class unique_function<R(P...)> : public detail::UniqueFunctionBase<R, P...> {
  317. using Base = detail::UniqueFunctionBase<R, P...>;
  318. public:
  319. unique_function() = default;
  320. unique_function(std::nullptr_t) {}
  321. unique_function(unique_function &&) = default;
  322. unique_function(const unique_function &) = delete;
  323. unique_function &operator=(unique_function &&) = default;
  324. unique_function &operator=(const unique_function &) = delete;
  325. template <typename CallableT>
  326. unique_function(
  327. CallableT Callable,
  328. detail::EnableUnlessSameType<CallableT, unique_function> * = nullptr,
  329. detail::EnableIfCallable<CallableT, R, P...> * = nullptr)
  330. : Base(std::forward<CallableT>(Callable),
  331. typename Base::template CalledAs<CallableT>{}) {}
  332. R operator()(P... Params) {
  333. return this->getCallPtr()(this->getCalleePtr(), Params...);
  334. }
  335. };
  336. template <typename R, typename... P>
  337. class unique_function<R(P...) const>
  338. : public detail::UniqueFunctionBase<R, P...> {
  339. using Base = detail::UniqueFunctionBase<R, P...>;
  340. public:
  341. unique_function() = default;
  342. unique_function(std::nullptr_t) {}
  343. unique_function(unique_function &&) = default;
  344. unique_function(const unique_function &) = delete;
  345. unique_function &operator=(unique_function &&) = default;
  346. unique_function &operator=(const unique_function &) = delete;
  347. template <typename CallableT>
  348. unique_function(
  349. CallableT Callable,
  350. detail::EnableUnlessSameType<CallableT, unique_function> * = nullptr,
  351. detail::EnableIfCallable<const CallableT, R, P...> * = nullptr)
  352. : Base(std::forward<CallableT>(Callable),
  353. typename Base::template CalledAs<const CallableT>{}) {}
  354. R operator()(P... Params) const {
  355. return this->getCallPtr()(this->getCalleePtr(), Params...);
  356. }
  357. };
  358. } // end namespace llvm
  359. #endif // LLVM_ADT_FUNCTIONEXTRAS_H
  360. #ifdef __GNUC__
  361. #pragma GCC diagnostic pop
  362. #endif