FunctionExtras.h 16 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427
  1. #pragma once
  2. #ifdef __GNUC__
  3. #pragma GCC diagnostic push
  4. #pragma GCC diagnostic ignored "-Wunused-parameter"
  5. #endif
  6. //===- FunctionExtras.h - Function type erasure utilities -------*- C++ -*-===//
  7. //
  8. // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
  9. // See https://llvm.org/LICENSE.txt for license information.
  10. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
  11. //
  12. //===----------------------------------------------------------------------===//
  13. /// \file
/// This file provides a collection of function (or more generally, callable)
/// type erasure utilities supplementing those provided by the standard library
/// in `<functional>`.
  17. ///
  18. /// It provides `unique_function`, which works like `std::function` but supports
  19. /// move-only callable objects and const-qualification.
  20. ///
/// Future plans:
/// - Add a `function` that provides ref-qualified support, which doesn't work
///   with `std::function`.
/// - Provide support for specifying multiple signatures to type erase callable
///   objects with an overload set, such as those produced by generic lambdas.
/// - Expand to include a copyable utility that directly replaces std::function
///   but brings the above improvements.
  28. ///
  29. /// Note that LLVM's utilities are greatly simplified by not supporting
  30. /// allocators.
  31. ///
  32. /// If the standard library ever begins to provide comparable facilities we can
  33. /// consider switching to those.
  34. ///
  35. //===----------------------------------------------------------------------===//
  36. #ifndef LLVM_ADT_FUNCTIONEXTRAS_H
  37. #define LLVM_ADT_FUNCTIONEXTRAS_H
  38. #include "llvm/ADT/PointerIntPair.h"
  39. #include "llvm/ADT/PointerUnion.h"
  40. #include "llvm/ADT/STLForwardCompat.h"
  41. #include "llvm/Support/MemAlloc.h"
  42. #include "llvm/Support/type_traits.h"
  43. #include <cstring>
  44. #include <memory>
  45. #include <type_traits>
  46. namespace llvm {
  47. /// unique_function is a type-erasing functor similar to std::function.
  48. ///
  49. /// It can hold move-only function objects, like lambdas capturing unique_ptrs.
  50. /// Accordingly, it is movable but not copyable.
  51. ///
  52. /// It supports const-qualification:
  53. /// - unique_function<int() const> has a const operator().
  54. /// It can only hold functions which themselves have a const operator().
  55. /// - unique_function<int()> has a non-const operator().
  56. /// It can hold functions with a non-const operator(), like mutable lambdas.
  57. template <typename FunctionT> class unique_function;
  58. namespace detail {
  59. template <typename T>
  60. using EnableIfTrivial =
  61. std::enable_if_t<llvm::is_trivially_move_constructible<T>::value &&
  62. std::is_trivially_destructible<T>::value>;
  63. template <typename CallableT, typename ThisT>
  64. using EnableUnlessSameType =
  65. std::enable_if_t<!std::is_same<remove_cvref_t<CallableT>, ThisT>::value>;
  66. template <typename CallableT, typename Ret, typename... Params>
  67. using EnableIfCallable = std::enable_if_t<llvm::disjunction<
  68. std::is_void<Ret>,
  69. std::is_same<decltype(std::declval<CallableT>()(std::declval<Params>()...)),
  70. Ret>,
  71. std::is_same<const decltype(std::declval<CallableT>()(
  72. std::declval<Params>()...)),
  73. Ret>,
  74. std::is_convertible<decltype(std::declval<CallableT>()(
  75. std::declval<Params>()...)),
  76. Ret>>::value>;
  77. template <typename ReturnT, typename... ParamTs> class UniqueFunctionBase {
  78. protected:
  79. static constexpr size_t InlineStorageSize = sizeof(void *) * 3;
  80. template <typename T, class = void>
  81. struct IsSizeLessThanThresholdT : std::false_type {};
  82. template <typename T>
  83. struct IsSizeLessThanThresholdT<
  84. T, std::enable_if_t<sizeof(T) <= 2 * sizeof(void *)>> : std::true_type {};
  85. // Provide a type function to map parameters that won't observe extra copies
  86. // or moves and which are small enough to likely pass in register to values
  87. // and all other types to l-value reference types. We use this to compute the
  88. // types used in our erased call utility to minimize copies and moves unless
  89. // doing so would force things unnecessarily into memory.
  90. //
  91. // The heuristic used is related to common ABI register passing conventions.
  92. // It doesn't have to be exact though, and in one way it is more strict
  93. // because we want to still be able to observe either moves *or* copies.
  94. template <typename T> struct AdjustedParamTBase {
  95. static_assert(!std::is_reference<T>::value,
  96. "references should be handled by template specialization");
  97. using type = typename std::conditional<
  98. llvm::is_trivially_copy_constructible<T>::value &&
  99. llvm::is_trivially_move_constructible<T>::value &&
  100. IsSizeLessThanThresholdT<T>::value,
  101. T, T &>::type;
  102. };
  103. // This specialization ensures that 'AdjustedParam<V<T>&>' or
  104. // 'AdjustedParam<V<T>&&>' does not trigger a compile-time error when 'T' is
  105. // an incomplete type and V a templated type.
  106. template <typename T> struct AdjustedParamTBase<T &> { using type = T &; };
  107. template <typename T> struct AdjustedParamTBase<T &&> { using type = T &; };
  108. template <typename T>
  109. using AdjustedParamT = typename AdjustedParamTBase<T>::type;
  110. // The type of the erased function pointer we use as a callback to dispatch to
  111. // the stored callable when it is trivial to move and destroy.
  112. using CallPtrT = ReturnT (*)(void *CallableAddr,
  113. AdjustedParamT<ParamTs>... Params);
  114. using MovePtrT = void (*)(void *LHSCallableAddr, void *RHSCallableAddr);
  115. using DestroyPtrT = void (*)(void *CallableAddr);
  116. /// A struct to hold a single trivial callback with sufficient alignment for
  117. /// our bitpacking.
  118. struct alignas(8) TrivialCallback {
  119. CallPtrT CallPtr;
  120. };
  121. /// A struct we use to aggregate three callbacks when we need full set of
  122. /// operations.
  123. struct alignas(8) NonTrivialCallbacks {
  124. CallPtrT CallPtr;
  125. MovePtrT MovePtr;
  126. DestroyPtrT DestroyPtr;
  127. };
  128. // Create a pointer union between either a pointer to a static trivial call
  129. // pointer in a struct or a pointer to a static struct of the call, move, and
  130. // destroy pointers.
  131. using CallbackPointerUnionT =
  132. PointerUnion<TrivialCallback *, NonTrivialCallbacks *>;
  133. // The main storage buffer. This will either have a pointer to out-of-line
  134. // storage or an inline buffer storing the callable.
  135. union StorageUnionT {
  136. // For out-of-line storage we keep a pointer to the underlying storage and
  137. // the size. This is enough to deallocate the memory.
  138. struct OutOfLineStorageT {
  139. void *StoragePtr;
  140. size_t Size;
  141. size_t Alignment;
  142. } OutOfLineStorage;
  143. static_assert(
  144. sizeof(OutOfLineStorageT) <= InlineStorageSize,
  145. "Should always use all of the out-of-line storage for inline storage!");
  146. // For in-line storage, we just provide an aligned character buffer. We
  147. // provide three pointers worth of storage here.
  148. // This is mutable as an inlined `const unique_function<void() const>` may
  149. // still modify its own mutable members.
  150. mutable
  151. typename std::aligned_storage<InlineStorageSize, alignof(void *)>::type
  152. InlineStorage;
  153. } StorageUnion;
  154. // A compressed pointer to either our dispatching callback or our table of
  155. // dispatching callbacks and the flag for whether the callable itself is
  156. // stored inline or not.
  157. PointerIntPair<CallbackPointerUnionT, 1, bool> CallbackAndInlineFlag;
  158. bool isInlineStorage() const { return CallbackAndInlineFlag.getInt(); }
  159. bool isTrivialCallback() const {
  160. return CallbackAndInlineFlag.getPointer().template is<TrivialCallback *>();
  161. }
  162. CallPtrT getTrivialCallback() const {
  163. return CallbackAndInlineFlag.getPointer().template get<TrivialCallback *>()->CallPtr;
  164. }
  165. NonTrivialCallbacks *getNonTrivialCallbacks() const {
  166. return CallbackAndInlineFlag.getPointer()
  167. .template get<NonTrivialCallbacks *>();
  168. }
  169. CallPtrT getCallPtr() const {
  170. return isTrivialCallback() ? getTrivialCallback()
  171. : getNonTrivialCallbacks()->CallPtr;
  172. }
  173. // These three functions are only const in the narrow sense. They return
  174. // mutable pointers to function state.
  175. // This allows unique_function<T const>::operator() to be const, even if the
  176. // underlying functor may be internally mutable.
  177. //
  178. // const callers must ensure they're only used in const-correct ways.
  179. void *getCalleePtr() const {
  180. return isInlineStorage() ? getInlineStorage() : getOutOfLineStorage();
  181. }
  182. void *getInlineStorage() const { return &StorageUnion.InlineStorage; }
  183. void *getOutOfLineStorage() const {
  184. return StorageUnion.OutOfLineStorage.StoragePtr;
  185. }
  186. size_t getOutOfLineStorageSize() const {
  187. return StorageUnion.OutOfLineStorage.Size;
  188. }
  189. size_t getOutOfLineStorageAlignment() const {
  190. return StorageUnion.OutOfLineStorage.Alignment;
  191. }
  192. void setOutOfLineStorage(void *Ptr, size_t Size, size_t Alignment) {
  193. StorageUnion.OutOfLineStorage = {Ptr, Size, Alignment};
  194. }
  195. template <typename CalledAsT>
  196. static ReturnT CallImpl(void *CallableAddr,
  197. AdjustedParamT<ParamTs>... Params) {
  198. auto &Func = *reinterpret_cast<CalledAsT *>(CallableAddr);
  199. return Func(std::forward<ParamTs>(Params)...);
  200. }
  201. template <typename CallableT>
  202. static void MoveImpl(void *LHSCallableAddr, void *RHSCallableAddr) noexcept {
  203. new (LHSCallableAddr)
  204. CallableT(std::move(*reinterpret_cast<CallableT *>(RHSCallableAddr)));
  205. }
  206. template <typename CallableT>
  207. static void DestroyImpl(void *CallableAddr) noexcept {
  208. reinterpret_cast<CallableT *>(CallableAddr)->~CallableT();
  209. }
  210. // The pointers to call/move/destroy functions are determined for each
  211. // callable type (and called-as type, which determines the overload chosen).
  212. // (definitions are out-of-line).
  213. // By default, we need an object that contains all the different
  214. // type erased behaviors needed. Create a static instance of the struct type
  215. // here and each instance will contain a pointer to it.
  216. // Wrap in a struct to avoid https://gcc.gnu.org/PR71954
  217. template <typename CallableT, typename CalledAs, typename Enable = void>
  218. struct CallbacksHolder {
  219. static NonTrivialCallbacks Callbacks;
  220. };
  221. // See if we can create a trivial callback. We need the callable to be
  222. // trivially moved and trivially destroyed so that we don't have to store
  223. // type erased callbacks for those operations.
  224. template <typename CallableT, typename CalledAs>
  225. struct CallbacksHolder<CallableT, CalledAs, EnableIfTrivial<CallableT>> {
  226. static TrivialCallback Callbacks;
  227. };
  228. // A simple tag type so the call-as type to be passed to the constructor.
  229. template <typename T> struct CalledAs {};
  230. // Essentially the "main" unique_function constructor, but subclasses
  231. // provide the qualified type to be used for the call.
  232. // (We always store a T, even if the call will use a pointer to const T).
  233. template <typename CallableT, typename CalledAsT>
  234. UniqueFunctionBase(CallableT Callable, CalledAs<CalledAsT>) {
  235. bool IsInlineStorage = true;
  236. void *CallableAddr = getInlineStorage();
  237. if (sizeof(CallableT) > InlineStorageSize ||
  238. alignof(CallableT) > alignof(decltype(StorageUnion.InlineStorage))) {
  239. IsInlineStorage = false;
  240. // Allocate out-of-line storage. FIXME: Use an explicit alignment
  241. // parameter in C++17 mode.
  242. auto Size = sizeof(CallableT);
  243. auto Alignment = alignof(CallableT);
  244. CallableAddr = allocate_buffer(Size, Alignment);
  245. setOutOfLineStorage(CallableAddr, Size, Alignment);
  246. }
  247. // Now move into the storage.
  248. new (CallableAddr) CallableT(std::move(Callable));
  249. CallbackAndInlineFlag.setPointerAndInt(
  250. &CallbacksHolder<CallableT, CalledAsT>::Callbacks, IsInlineStorage);
  251. }
  252. ~UniqueFunctionBase() {
  253. if (!CallbackAndInlineFlag.getPointer())
  254. return;
  255. // Cache this value so we don't re-check it after type-erased operations.
  256. bool IsInlineStorage = isInlineStorage();
  257. if (!isTrivialCallback())
  258. getNonTrivialCallbacks()->DestroyPtr(
  259. IsInlineStorage ? getInlineStorage() : getOutOfLineStorage());
  260. if (!IsInlineStorage)
  261. deallocate_buffer(getOutOfLineStorage(), getOutOfLineStorageSize(),
  262. getOutOfLineStorageAlignment());
  263. }
  264. UniqueFunctionBase(UniqueFunctionBase &&RHS) noexcept {
  265. // Copy the callback and inline flag.
  266. CallbackAndInlineFlag = RHS.CallbackAndInlineFlag;
  267. // If the RHS is empty, just copying the above is sufficient.
  268. if (!RHS)
  269. return;
  270. if (!isInlineStorage()) {
  271. // The out-of-line case is easiest to move.
  272. StorageUnion.OutOfLineStorage = RHS.StorageUnion.OutOfLineStorage;
  273. } else if (isTrivialCallback()) {
  274. // Move is trivial, just memcpy the bytes across.
  275. memcpy(getInlineStorage(), RHS.getInlineStorage(), InlineStorageSize);
  276. } else {
  277. // Non-trivial move, so dispatch to a type-erased implementation.
  278. getNonTrivialCallbacks()->MovePtr(getInlineStorage(),
  279. RHS.getInlineStorage());
  280. }
  281. // Clear the old callback and inline flag to get back to as-if-null.
  282. RHS.CallbackAndInlineFlag = {};
  283. #ifndef NDEBUG
  284. // In debug builds, we also scribble across the rest of the storage.
  285. memset(RHS.getInlineStorage(), 0xAD, InlineStorageSize);
  286. #endif
  287. }
  288. UniqueFunctionBase &operator=(UniqueFunctionBase &&RHS) noexcept {
  289. if (this == &RHS)
  290. return *this;
  291. // Because we don't try to provide any exception safety guarantees we can
  292. // implement move assignment very simply by first destroying the current
  293. // object and then move-constructing over top of it.
  294. this->~UniqueFunctionBase();
  295. new (this) UniqueFunctionBase(std::move(RHS));
  296. return *this;
  297. }
  298. UniqueFunctionBase() = default;
  299. public:
  300. explicit operator bool() const {
  301. return (bool)CallbackAndInlineFlag.getPointer();
  302. }
  303. };
  304. template <typename R, typename... P>
  305. template <typename CallableT, typename CalledAsT, typename Enable>
  306. typename UniqueFunctionBase<R, P...>::NonTrivialCallbacks UniqueFunctionBase<
  307. R, P...>::CallbacksHolder<CallableT, CalledAsT, Enable>::Callbacks = {
  308. &CallImpl<CalledAsT>, &MoveImpl<CallableT>, &DestroyImpl<CallableT>};
  309. template <typename R, typename... P>
  310. template <typename CallableT, typename CalledAsT>
  311. typename UniqueFunctionBase<R, P...>::TrivialCallback
  312. UniqueFunctionBase<R, P...>::CallbacksHolder<
  313. CallableT, CalledAsT, EnableIfTrivial<CallableT>>::Callbacks{
  314. &CallImpl<CalledAsT>};
  315. } // namespace detail
  316. template <typename R, typename... P>
  317. class unique_function<R(P...)> : public detail::UniqueFunctionBase<R, P...> {
  318. using Base = detail::UniqueFunctionBase<R, P...>;
  319. public:
  320. unique_function() = default;
  321. unique_function(std::nullptr_t) {}
  322. unique_function(unique_function &&) = default;
  323. unique_function(const unique_function &) = delete;
  324. unique_function &operator=(unique_function &&) = default;
  325. unique_function &operator=(const unique_function &) = delete;
  326. template <typename CallableT>
  327. unique_function(
  328. CallableT Callable,
  329. detail::EnableUnlessSameType<CallableT, unique_function> * = nullptr,
  330. detail::EnableIfCallable<CallableT, R, P...> * = nullptr)
  331. : Base(std::forward<CallableT>(Callable),
  332. typename Base::template CalledAs<CallableT>{}) {}
  333. R operator()(P... Params) {
  334. return this->getCallPtr()(this->getCalleePtr(), Params...);
  335. }
  336. };
  337. template <typename R, typename... P>
  338. class unique_function<R(P...) const>
  339. : public detail::UniqueFunctionBase<R, P...> {
  340. using Base = detail::UniqueFunctionBase<R, P...>;
  341. public:
  342. unique_function() = default;
  343. unique_function(std::nullptr_t) {}
  344. unique_function(unique_function &&) = default;
  345. unique_function(const unique_function &) = delete;
  346. unique_function &operator=(unique_function &&) = default;
  347. unique_function &operator=(const unique_function &) = delete;
  348. template <typename CallableT>
  349. unique_function(
  350. CallableT Callable,
  351. detail::EnableUnlessSameType<CallableT, unique_function> * = nullptr,
  352. detail::EnableIfCallable<const CallableT, R, P...> * = nullptr)
  353. : Base(std::forward<CallableT>(Callable),
  354. typename Base::template CalledAs<const CallableT>{}) {}
  355. R operator()(P... Params) const {
  356. return this->getCallPtr()(this->getCalleePtr(), Params...);
  357. }
  358. };
  359. } // end namespace llvm
  360. #endif // LLVM_ADT_FUNCTIONEXTRAS_H
  361. #ifdef __GNUC__
  362. #pragma GCC diagnostic pop
  363. #endif