allocator.h 2.6 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once

#include <memory>

#include "paddle/pten/core/allocator.h"

namespace pten {
namespace tests {

class HostAllocatorSample : public pten::RawAllocator {
 public:
  using Place = paddle::platform::Place;
  void* Allocate(size_t bytes_size) override {
    return ::operator new(bytes_size);
  }
  void Deallocate(void* ptr, size_t bytes_size) override {
    return ::operator delete(ptr);
  }
  const Place& place() const override { return place_; }

 private:
  Place place_{paddle::platform::CPUPlace()};
};

class FancyAllocator : public pten::Allocator {
 public:
41 42 43
  static void Delete(Allocation* allocation) {
    ::operator delete(allocation->ptr());
  }
44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91

  Allocation Allocate(size_t bytes_size) override {
    void* data = ::operator new(bytes_size);
    return Allocation(data, data, &Delete, paddle::platform::CPUPlace());
  }
};

template <typename T>
struct CustomAllocator {
  using value_type = T;
  using Allocator = pten::RawAllocator;

  explicit CustomAllocator(const std::shared_ptr<Allocator>& a) noexcept
      : alloc_(a) {}

  CustomAllocator(const CustomAllocator&) noexcept = default;
  T* allocate(std::size_t n) {
    return static_cast<T*>(alloc_->Allocate(n * sizeof(T)));
  }
  void deallocate(T* p, std::size_t n) {
    return alloc_->Deallocate(p, sizeof(T) * n);
  }

  template <typename R, typename U>
  friend bool operator==(const CustomAllocator<R>&,
                         const CustomAllocator<U>&) noexcept;
  template <typename R, typename U>
  friend bool operator!=(const CustomAllocator<R>&,
                         const CustomAllocator<U>&) noexcept;

 private:
  std::shared_ptr<Allocator> alloc_;
};

template <typename T, typename U>
inline bool operator==(const CustomAllocator<T>& lhs,
                       const CustomAllocator<U>& rhs) noexcept {
  return &lhs.alloc_ == &rhs.alloc_;
}

template <typename T, typename U>
inline bool operator!=(const CustomAllocator<T>& lhs,
                       const CustomAllocator<U>& rhs) noexcept {
  return &lhs.alloc_ != &rhs.alloc_;
}

}  // namespace tests
}  // namespace pten