/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #pragma once #include #include "paddle/pten/core/allocator.h" namespace pten { namespace tests { class HostAllocatorSample : public pten::RawAllocator { public: using Place = paddle::platform::Place; void* Allocate(size_t bytes_size) override { return ::operator new(bytes_size); } void Deallocate(void* ptr, size_t bytes_size) override { return ::operator delete(ptr); } const Place& place() const override { return place_; } private: Place place_{paddle::platform::CPUPlace()}; }; class FancyAllocator : public pten::Allocator { public: static void Delete(void* data) { ::operator delete(data); } Allocation Allocate(size_t bytes_size) override { void* data = ::operator new(bytes_size); return Allocation(data, data, &Delete, paddle::platform::CPUPlace()); } }; template struct CustomAllocator { using value_type = T; using Allocator = pten::RawAllocator; explicit CustomAllocator(const std::shared_ptr& a) noexcept : alloc_(a) {} CustomAllocator(const CustomAllocator&) noexcept = default; T* allocate(std::size_t n) { return static_cast(alloc_->Allocate(n * sizeof(T))); } void deallocate(T* p, std::size_t n) { return alloc_->Deallocate(p, sizeof(T) * n); } template friend bool operator==(const CustomAllocator&, const CustomAllocator&) noexcept; template friend bool operator!=(const CustomAllocator&, const CustomAllocator&) noexcept; private: std::shared_ptr alloc_; }; template inline bool operator==(const CustomAllocator& lhs, const CustomAllocator& rhs) noexcept { return &lhs.alloc_ == &rhs.alloc_; } template inline bool operator!=(const CustomAllocator& lhs, const CustomAllocator& rhs) noexcept { return &lhs.alloc_ != &rhs.alloc_; } } // namespace tests } // namespace pten