提交 3175317f 编写于 作者: Y Yu Yang

Add ZeroSize Allocator

上级 29f66c24
......@@ -34,6 +34,7 @@ endif()
cc_library(aligned_allocator SRCS aligned_allocator.cc DEPS allocator)
cc_library(auto_increment_allocator SRCS auto_increment_allocator.cc DEPS allocator)
cc_library(zero_size_allocator SRCS zero_size_allocator.cc DEPS allocator)
cc_library(allocator_facade SRCS allocator_facade.cc DEPS
${AllocatorFacadeDeps}
cpu_allocator
......@@ -42,6 +43,7 @@ cc_library(allocator_facade SRCS allocator_facade.cc DEPS
naive_managed_allocator
aligned_allocator
auto_increment_allocator
zero_size_allocator
cuda_device_guard)
nv_test(allocation_and_eigen_test SRCS allocation_and_eigen_test.cu DEPS allocator_facade)
......@@ -23,6 +23,7 @@
#include "paddle/fluid/memory/allocation/locked_allocator.h"
#include "paddle/fluid/memory/allocation/naive_managed_allocator.h"
#include "paddle/fluid/memory/allocation/pinned_allocator.h"
#include "paddle/fluid/memory/allocation/zero_size_allocator.h"
#include "paddle/fluid/platform/cuda_device_guard.h"
#include "paddle/fluid/platform/gpu_info.h"
#include "paddle/fluid/platform/place.h"
......@@ -118,6 +119,7 @@ class AllocatorFacadePrivate {
AllocatorFacadePrivate() {
InitCPUAllocator();
InitCUDAAllocator();
WrapZeroSizeAllocator();
}
private:
......@@ -133,6 +135,13 @@ class AllocatorFacadePrivate {
}
#endif
}
void WrapZeroSizeAllocator() {
for (auto& pair : allocators_) {
pair.second =
std::make_shared<ZeroSizeAllocator>(pair.second, pair.first);
}
}
};
// Pimpl. Make interface clean.
......
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/memory/allocation/zero_size_allocator.h"
namespace paddle {
namespace memory {
namespace allocation {
std::unique_ptr<Allocation> ZeroSizeAllocator::Allocate(size_t size,
Allocator::Attr attr) {
if (size == 0) {
return std::unique_ptr<Allocation>(new ZeroSizeAllocation(place_));
} else {
return underlying_allocator_->Allocate(size, attr);
}
}
std::shared_ptr<Allocation> ZeroSizeAllocator::AllocateShared(
size_t size, Allocator::Attr attr) {
if (size == 0) {
return std::shared_ptr<Allocation>(new ZeroSizeAllocation(place_));
} else {
return underlying_allocator_->AllocateShared(size, attr);
}
}
bool ZeroSizeAllocator::IsAllocThreadSafe() const { return true; }
} // namespace allocation
} // namespace memory
} // namespace paddle
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <utility>
#pragma once
#include "paddle/fluid/memory/allocation/allocator.h"
namespace paddle {
namespace memory {
namespace allocation {
class ZeroSizeAllocation : public Allocation {
public:
explicit ZeroSizeAllocation(const platform::Place& p)
: Allocation(nullptr, 0, p) {}
};
class ZeroSizeAllocator : public ManagedAllocator {
public:
ZeroSizeAllocator(
const std::shared_ptr<ManagedAllocator>& underlying_allocator,
const platform::Place& p)
: underlying_allocator_(underlying_allocator), place_(p) {}
std::unique_ptr<Allocation> Allocate(size_t size, Attr attr) override;
std::shared_ptr<Allocation> AllocateShared(size_t size, Attr attr) override;
bool IsAllocThreadSafe() const override;
private:
std::shared_ptr<ManagedAllocator> underlying_allocator_;
const platform::Place& place_;
};
} // namespace allocation
} // namespace memory
} // namespace paddle
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册