/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once

#include <cstddef>
#include <cstdint>
#include <memory>
#include <utility>

#include "paddle/fluid/platform/place.h"
#include "paddle/pten/core/candidate/allocator.h"

namespace pten {
namespace deprecated {
23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59

/// \brief Encapsulates strategies for access/addressing, allocation/
/// deallocation and construction/destruction of objects.
class RawAllocator {
 public:
  using Place = paddle::platform::Place;

  /// \brief Default destructor.
  virtual ~RawAllocator() = default;

  /// \brief Allocates storage suitable for an array object of n bytes
  /// and creates the array, but does not construct array elements.
  /// May throw exceptions.
  /// \param bytes_size The number of bytes to allocate.
  /// \return The first address allocated.
  virtual void* Allocate(size_t bytes_size) = 0;

  /// \brief Deallocates storage pointed to ptr, which must be a value
  /// returned by a previous call to allocate that has not been
  /// invalidated by an intervening call to deallocate. The bytes_size
  /// must match the value previously passed to allocate.
  /// \param ptr The first address to deallocate.
  /// \param bytes_size The number of bytes to deallocate.
  virtual void Deallocate(void* ptr, size_t bytes_size) = 0;

  /// \brief Get the place value of the allocator and the allocation.
  /// \return The place value of the allocator and the allocation.
  virtual const Place& place() const = 0;
};

/// \brief Fancy pointer with context. The use of this data type
/// is to be compatible with allocators from different frameworks
/// without significant performance loss. This class does not
/// support being inherited.
class Allocation final {
 public:
  using Place = paddle::platform::Place;
60
  using DeleterFnPtr = void (*)(Allocation*);
61 62 63

  Allocation() = default;

64
  // Don't own resources, only provide access.
65 66
  Allocation(void* data, const Place& place) : data_(data), place_(place) {}

67 68 69
  // Own resources.
  Allocation(void* data, void* ctx, DeleterFnPtr deleter, const Place& place)
      : data_(data), ctx_(ctx), deleter_(deleter), place_(place) {}
70

71 72 73 74 75 76 77 78 79 80
  Allocation(Allocation&& other) { swap(*this, other); }
  Allocation& operator=(Allocation&& other) {
    // Exchange them explicitly to avoid moving is equivalent
    // to copying.
    swap(*this, other);
    return *this;
  }
  ~Allocation() { Clear(); }

  void* ptr() const noexcept { return data_; }
81
  void* operator->() const noexcept { return data_; }
82
  operator bool() const noexcept { return data_ || ctx_; }
83 84
  const Place& place() const noexcept { return place_; }

石晓伟 已提交
85
  void Clear() {
86 87 88 89 90
    if (deleter_) {
      deleter_(this);
    }
    ctx_ = nullptr;
    deleter_ = nullptr;
石晓伟 已提交
91
    data_ = nullptr;
92 93
  }

94 95 96 97 98 99 100
  DeleterFnPtr deleter() const noexcept { return deleter_; }

  template <typename T>
  T* CastContextWithoutCheck() const noexcept {
    return static_cast<T*>(ctx_);
  }

101 102 103 104 105 106 107
  /// \brief Statically cast the void pointer of the context object to
  /// the primitive type. Conversion of any pointer to void* and back
  /// to pointer to the original cv type preserves its original value.
  /// \param T The primitive type name of the context pointer.
  /// \param expected_deleter The destructor passed in to enhance type
  /// safety checking.
  template <typename T>
108 109 110 111 112 113 114 115
  T* CastContext(DeleterFnPtr expected_deleter) const {
    PADDLE_ENFORCE_EQ(
        deleter_ == expected_deleter,
        true,
        paddle::platform::errors::InvalidArgument(
            "The deleter of the allocation does not match, so the pointer "
            "cannot be safely removed."));
    return CastContextWithoutCheck<T>();
116 117 118
  }

 private:
119
  friend void swap(Allocation& a, Allocation& b) noexcept;
120
  void* data_{nullptr};
121 122
  void* ctx_{nullptr};
  DeleterFnPtr deleter_{nullptr};
123 124 125 126 127
  // TODO(Shixiaowei02): Enum needs to be used instead to reduce
  // the construction overhead by more than 50%.
  Place place_;
};

128 129
inline void swap(Allocation& a, Allocation& b) noexcept {
  ::std::swap(a.data_, b.data_);
130 131
  ::std::swap(a.ctx_, b.ctx_);
  ::std::swap(a.deleter_, b.deleter_);
132
  ::std::swap(a.place_, b.place_);
133 134 135 136 137 138
}

/// \brief Context compatible allocator interface. This allocator is
/// mainly used for general data structures such as Tensor. The raw
/// allocator is more universal and efficient.
class Allocator {
139 140
  using Place = paddle::platform::Place;

141 142 143
 public:
  virtual ~Allocator() = default;
  virtual Allocation Allocate(size_t bytes_size) = 0;
144
  virtual const Place& place() = 0;
145 146 147 148 149 150 151
};

/// \brief Convenience wrapper that requests `n` bytes from allocator `a`.
/// A null `a` fails the CHECK before any allocation is attempted.
inline Allocation Allocate(const std::shared_ptr<Allocator>& a, size_t n) {
  CHECK(a);
  Allocator& allocator = *a;
  return allocator.Allocate(n);
}
}  // namespace deprecated
}  // namespace pten