naive_managed_allocator.h 2.7 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once
#include <memory>
#include "paddle/fluid/memory/allocation/allocator.h"

namespace paddle {
namespace memory {
namespace allocation {

Y
Yu Yang 已提交
23 24 25 26 27
// An allocator that wraps an UnmanagedAllocator and makes its allocations
// managed by C++ smart pointers.
//
// NOTE: if the NaiveManagedAllocator is destroyed before the
// NaiveManagedAllocations it handed out, their memory will never be released.
28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76
class NaiveManagedAllocator;

// An Allocation handed out by NaiveManagedAllocator.
//
// It reports the same ptr()/size()/place() as the wrapped (unmanaged)
// allocation and takes ownership of that wrapped allocation. It keeps only a
// weak, non-owning reference back to the allocator that produced it, so an
// outstanding allocation never keeps its allocator alive. If the allocator is
// destroyed first, the allocation's memory is never released.
//
// The destructor is defined out of line. Presumably it returns
// underlying_allocation_ to the allocator when allocator_ can still be
// locked (TODO: confirm against the .cc file).
class NaiveManagedAllocation : public Allocation {
 public:
  // @param underlying_allocation  The unmanaged allocation to take ownership
  //        of. It must be non-null: its ptr()/size()/place() are read to
  //        initialize the Allocation base before it is moved into this object.
  // @param allocator  The allocator that produced the allocation. Only a
  //        weak_ptr to it is retained, so it is taken by const reference.
  //        No ownership is transferred here, and passing the shared_ptr by
  //        value would cost an extra reference-count increment/decrement
  //        (typically atomic) for every allocation.
  NaiveManagedAllocation(
      std::unique_ptr<Allocation>&& underlying_allocation,
      const std::shared_ptr<NaiveManagedAllocator>& allocator)
      : Allocation(underlying_allocation->ptr(), underlying_allocation->size(),
                   underlying_allocation->place()),
        underlying_allocation_(std::move(underlying_allocation)),
        allocator_(allocator) {}

  ~NaiveManagedAllocation() final;

 private:
  // The wrapped unmanaged allocation, owned by this object.
  std::unique_ptr<Allocation> underlying_allocation_;
  // Non-owning back-reference. It expires if the allocator is destroyed first.
  std::weak_ptr<NaiveManagedAllocator> allocator_;
};

class NaiveManagedAllocator
    : public ManagedAllocator,
      public std::enable_shared_from_this<NaiveManagedAllocator> {
 public:
  template <typename... ARGS>
  static std::shared_ptr<ManagedAllocator> Create(ARGS... args) {
    return std::static_pointer_cast<ManagedAllocator>(
        std::shared_ptr<NaiveManagedAllocator>(
            new NaiveManagedAllocator(std::move(args)...)));
  }

  inline UnmanagedAllocator& UnderlyingAllocator() {
    return *underlying_allocator_;
  }

  bool IsAllocThreadSafe() const override;
  std::unique_ptr<Allocation> Allocate(size_t size,
                                       Attr attr = kDefault) override;
  std::shared_ptr<Allocation> AllocateShared(size_t size,
                                             Attr attr = kDefault) override;

 private:
  explicit NaiveManagedAllocator(std::unique_ptr<Allocator>&& allocator);
  explicit NaiveManagedAllocator(
      std::unique_ptr<UnmanagedAllocator>&& allocator);
  void Init(std::unique_ptr<UnmanagedAllocator>&& allocator);

  std::unique_ptr<UnmanagedAllocator> underlying_allocator_;
};
}  // namespace allocation
}  // namespace memory
}  // namespace paddle