Unverified commit 12542320, authored by Huihuang Zheng, committed by GitHub

Replace TemporaryAllocator by CUDADeviceContextAllocator (#18989)

TemporaryAllocator is a singleton used to allocate memory for cuDNN. Since it is a singleton, removing it gives better memory performance.

We replace TemporaryAllocator with CUDADeviceContextAllocator and CUDADeviceContextAllocation, which use a stream callback to free the memory allocated for a stream, so no singleton is needed.

Also added data_feed_proto as a dependency of operator to fix CI for CPU-only compilation.
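
As a usage illustration (not part of this patch), here is a minimal sketch of how an operator kernel obtains temporary device memory after this change, next to the old singleton path it replaces; the kernel context `ctx`, the element count `n`, and the buffer size are hypothetical, while `memory::Alloc(dev_ctx, size)` and the old `DeviceTemporaryAllocator` calls are taken from the diff below.

  // Old path (removed in this PR): go through the singleton.
  //   auto& allocator =
  //       platform::DeviceTemporaryAllocator::Instance().Get(dev_ctx);
  //   auto tmp_ptr = allocator.Allocate(bytes);

  // New path: allocate through the device context. If dev_ctx runs on a
  // non-default stream, the returned AllocationPtr wraps a
  // CUDADeviceContextAllocation whose destructor registers a stream callback,
  // so the memory is freed only after the stream has finished using it.
  auto& dev_ctx = ctx.cuda_device_context();     // hypothetical kernel context
  int bytes = n * sizeof(float);                 // hypothetical buffer size
  auto tmp_ptr = memory::Alloc(dev_ctx, bytes);
  float* tmp = reinterpret_cast<float*>(tmp_ptr->ptr());
  // ... launch kernels on dev_ctx.stream() that use `tmp` ...
  // `tmp_ptr` going out of scope releases the allocation via the stream
  // callback instead of the old singleton's temporary-allocation queue.
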
Parent 0daa5c97
......@@ -389,7 +389,6 @@ function(cc_test_run TARGET_NAME)
WORKING_DIRECTORY ${CMAKE_CURRENT_BINARY_DIR})
set_property(TEST ${TARGET_NAME} PROPERTY ENVIRONMENT FLAGS_cpu_deterministic=true)
set_property(TEST ${TARGET_NAME} PROPERTY ENVIRONMENT FLAGS_init_allocated_mem=true)
set_property(TEST ${TARGET_NAME} PROPERTY ENVIRONMENT FLAGS_limit_of_tmp_allocation=4294967296) # 4G
set_property(TEST ${TARGET_NAME} PROPERTY ENVIRONMENT FLAGS_cudnn_deterministic=true)
# No unit test should exceed 10 minutes.
set_tests_properties(${TARGET_NAME} PROPERTIES TIMEOUT 600)
......@@ -472,7 +471,6 @@ function(nv_test TARGET_NAME)
add_test(${TARGET_NAME} ${TARGET_NAME})
set_property(TEST ${TARGET_NAME} PROPERTY ENVIRONMENT FLAGS_cpu_deterministic=true)
set_property(TEST ${TARGET_NAME} PROPERTY ENVIRONMENT FLAGS_init_allocated_mem=true)
set_property(TEST ${TARGET_NAME} PROPERTY ENVIRONMENT FLAGS_limit_of_tmp_allocation=4294967296) # 4G
set_property(TEST ${TARGET_NAME} PROPERTY ENVIRONMENT FLAGS_cudnn_deterministic=true)
endif()
endfunction(nv_test)
......@@ -725,7 +723,7 @@ function(py_test TARGET_NAME)
if(WITH_COVERAGE)
add_test(NAME ${TARGET_NAME}
COMMAND ${CMAKE_COMMAND} -E env FLAGS_init_allocated_mem=true FLAGS_cudnn_deterministic=true
FLAGS_cpu_deterministic=true FLAGS_limit_of_tmp_allocation=4294967296 # 4G
FLAGS_cpu_deterministic=true
PYTHONPATH=${PADDLE_BINARY_DIR}/python ${py_test_ENVS}
COVERAGE_FILE=${PADDLE_BINARY_DIR}/python-coverage.data
${PYTHON_EXECUTABLE} -m coverage run --branch -p ${py_test_SRCS} ${py_test_ARGS}
......@@ -733,7 +731,7 @@ function(py_test TARGET_NAME)
else()
add_test(NAME ${TARGET_NAME}
COMMAND ${CMAKE_COMMAND} -E env FLAGS_init_allocated_mem=true FLAGS_cudnn_deterministic=true
FLAGS_cpu_deterministic=true FLAGS_limit_of_tmp_allocation=4294967296 # 4G
FLAGS_cpu_deterministic=true
PYTHONPATH=${PADDLE_BINARY_DIR}/python ${py_test_ENVS}
${PYTHON_EXECUTABLE} -u ${py_test_SRCS} ${py_test_ARGS}
WORKING_DIRECTORY ${CMAKE_CURRENT_BINARY_DIR})
......
......@@ -123,8 +123,8 @@ cc_library(shape_inference SRCS shape_inference.cc DEPS ddim attribute device_co
cc_library(transfer_scope_cache SRCS transfer_scope_cache.cc DEPS scope framework_proto device_context)
cc_library(op_kernel_type SRCS op_kernel_type.cc DEPS device_context place)
cc_library(operator SRCS operator.cc DEPS op_info device_context tensor scope
glog shape_inference data_transform lod_tensor profiler transfer_scope_cache op_kernel_type op_call_stack)
cc_library(operator SRCS operator.cc DEPS op_info device_context tensor scope glog data_feed_proto
shape_inference data_transform lod_tensor profiler transfer_scope_cache op_kernel_type op_call_stack)
cc_test(operator_test SRCS operator_test.cc DEPS operator op_registry device_context)
......
......@@ -18,6 +18,7 @@
#include "paddle/fluid/framework/details/reduce_and_gather.h"
#include "paddle/fluid/framework/details/variable_visitor.h"
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/memory/malloc.h"
#include "paddle/fluid/platform/gpu_info.h"
#include "paddle/fluid/platform/profiler.h"
......@@ -103,16 +104,15 @@ void SparseAllReduceOpHandle::RunImplEncoded() {
int dev_id = boost::get<platform::CUDAPlace>(place).device;
auto *nccl_ctxs = nccl_ctxs_->GetRunEnvNCCLCtx(run_order_, false);
auto &nccl_ctx = nccl_ctxs->at(dev_id);
auto *dev_ctx = nccl_ctxs->DevCtx(dev_id);
auto stream = nccl_ctx.stream();
auto comm = nccl_ctx.comm_;
auto &allocator =
platform::DeviceTemporaryAllocator::Instance().Get(place, stream);
int encode_size = 2 * k * sizeof(int);
// dgc use ncclAllGather to get all the encoded data
// so the buffer need nranks.
int buf_size = nranks_ * encode_size;
auto tmp_ious_data = allocator.Allocate(buf_size);
auto tmp_ious_data = memory::Alloc(*dev_ctx, buf_size);
void *gather_buff = reinterpret_cast<void *>(tmp_ious_data->ptr());
VLOG(10) << "in_numel:" << in_numel << ", out_numel:" << out_numel
......
......@@ -35,6 +35,7 @@ limitations under the License. */
#include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/framework/selected_rows.h"
#include "paddle/fluid/framework/tensor.h"
#include "paddle/fluid/memory/malloc.h"
#include "paddle/fluid/platform/device_context.h"
#include "paddle/fluid/platform/variant.h"
......@@ -360,9 +361,7 @@ class ExecutionContext {
template <typename T, typename DevContext>
Tensor AllocateTmpTensor(const framework::DDim& dim,
const DevContext& dev_ctx) const {
auto tmp_allocation_ptr = platform::DeviceTemporaryAllocator::Instance()
.Get<DevContext>(dev_ctx)
.Allocate(product(dim) * sizeof(T));
auto tmp_allocation_ptr = memory::Alloc(dev_ctx, product(dim) * sizeof(T));
auto& deleter = tmp_allocation_ptr.get_deleter();
auto* allocation_ptr = tmp_allocation_ptr.release();
auto shared_allocation = std::shared_ptr<memory::allocation::Allocation>(
......
......@@ -19,7 +19,6 @@ limitations under the License. */
#include "paddle/fluid/framework/framework.pb.h"
#include "paddle/fluid/framework/tensor.h"
#include "paddle/fluid/platform/device_context.h"
#include "paddle/fluid/platform/temporary_allocator.h"
namespace paddle {
namespace framework {
......
add_subdirectory(detail)
add_subdirectory(allocation)
cc_library(malloc SRCS malloc.cc DEPS place enforce allocator_facade profiler)
if (WITH_MKLDNN)
set(MKLDNN_CTX_DEPS mkldnn)
else ()
set(MKLDNN_CTX_DEPS)
endif()
cc_library(malloc SRCS malloc.cc DEPS
place enforce allocator_facade profiler ${MKLDNN_CTX_DEPS})
cc_library(memcpy SRCS memcpy.cc DEPS place)
cc_library(memory
DEPS
malloc
memcpy)
if (WITH_GPU)
add_dependencies(malloc cuda_device_context_allocator_pool)
target_link_libraries(malloc cuda_device_context_allocator_pool)
nv_test(malloc_test
SRCS malloc_test.cu
DEPS device_context malloc)
endif()
#if (WITH_GPU)
# nv_test(pinned_memory_test SRCS pinned_memory_test.cu DEPS place memory)
#endif()
......@@ -6,8 +6,20 @@ cc_library(best_fit_allocator SRCS best_fit_allocator.cc DEPS allocator)
cc_library(naive_best_fit_allocator SRCS naive_best_fit_allocator.cc DEPS allocator buddy_allocator profiler)
cc_test(buffered_allocator_test SRCS buffered_allocator_test.cc DEPS locked_allocator buffered_allocator cpu_allocator best_fit_allocator)
if (WITH_MKLDNN)
set(MKLDNN_CTX_DEPS mkldnn)
else ()
set(MKLDNN_CTX_DEPS)
endif()
if (WITH_GPU)
nv_library(cuda_allocator SRCS cuda_allocator.cc DEPS allocator cuda_device_guard)
nv_library(cuda_device_context_allocation SRCS cuda_device_context_allocation.cc
DEPS allocator enforce place ${MKLDNN_CTX_DEPS})
nv_library(cuda_device_context_allocator SRCS cuda_device_context_allocator.cc
DEPS allocator enforce place cuda_device_context_allocation ${MKLDNN_CTX_DEPS})
nv_library(cuda_device_context_allocator_pool SRCS cuda_device_context_allocator_pool.cc
DEPS allocator enforce place cuda_device_context_allocation cuda_device_context_allocator ${MKLDNN_CTX_DEPS})
endif()
cc_library(retry_allocator SRCS retry_allocator.cc DEPS allocator)
......
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/memory/allocation/cuda_device_context_allocation.h"
#include <utility>
#include "paddle/fluid/platform/enforce.h"
namespace paddle {
namespace memory {
namespace allocation {
CUDADeviceContextAllocation::CUDADeviceContextAllocation(
AllocationPtr allocation)
: Allocation(allocation->ptr(), allocation->size(), allocation->place()),
underlying_allocation_(std::move(allocation)) {}
CUDADeviceContextAllocation::~CUDADeviceContextAllocation() {
PADDLE_ENFORCE_NOT_NULL(
dev_ctx_, "Didn't set device context for CUDADeviceContextAllocation");
auto *p_allocation = underlying_allocation_.release();
VLOG(4) << "Adding callback to delete CUDADeviceContextAllocation at "
<< p_allocation;
dev_ctx_->AddStreamCallback([p_allocation] {
VLOG(4) << "Delete CUDADeviceContextAllocation at " << p_allocation;
AllocationDeleter()(p_allocation);
});
}
void CUDADeviceContextAllocation::SetCUDADeviceContext(
const platform::CUDADeviceContext *dev_ctx) {
dev_ctx_ = dev_ctx;
}
} // namespace allocation
} // namespace memory
} // namespace paddle
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/fluid/memory/allocation/allocator.h"
#include "paddle/fluid/platform/device_context.h"
namespace paddle {
namespace memory {
namespace allocation {
class CUDADeviceContextAllocation : public Allocation {
public:
explicit CUDADeviceContextAllocation(AllocationPtr allocation);
~CUDADeviceContextAllocation();
void SetCUDADeviceContext(const platform::CUDADeviceContext *dev_ctx);
private:
AllocationPtr underlying_allocation_;
const platform::CUDADeviceContext *dev_ctx_{nullptr};
};
} // namespace allocation
} // namespace memory
} // namespace paddle
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/memory/allocation/cuda_device_context_allocator.h"
#include "paddle/fluid/memory/allocation/cuda_device_context_allocation.h"
#include "paddle/fluid/platform/cuda_device_guard.h"
#include "paddle/fluid/platform/enforce.h"
namespace paddle {
namespace memory {
namespace allocation {
CUDADeviceContextAllocator::CUDADeviceContextAllocator(
const platform::CUDAPlace place, cudaStream_t default_stream)
: place_(place), default_stream_(default_stream) {
platform::CUDADeviceGuard guard(place_.device);
PADDLE_ENFORCE_CUDA_SUCCESS(
cudaEventCreate(&event_, cudaEventDisableTiming),
"Create event failed in CUDADeviceContextAllocator");
}
CUDADeviceContextAllocator::~CUDADeviceContextAllocator() {
if (event_) {
platform::CUDADeviceGuard guard(place_.device);
PADDLE_ENFORCE_CUDA_SUCCESS(
cudaEventDestroy(event_),
"Destory event failed in CUDADeviceContextAllocator destroctor");
}
}
Allocation *CUDADeviceContextAllocator::AllocateImpl(size_t size) {
PADDLE_ENFORCE_NOT_NULL(
default_stream_,
"Didn't set default stream for CUDADeviceContextAllocator");
platform::CUDADeviceGuard guard(place_.device);
auto allocation =
new CUDADeviceContextAllocation(memory::Alloc(place_, size));
// Wait for the event on stream
PADDLE_ENFORCE_CUDA_SUCCESS(
cudaEventRecord(event_, default_stream_),
"Failed to record event in CUDADeviceContextAllocator");
PADDLE_ENFORCE_CUDA_SUCCESS(
cudaStreamWaitEvent(default_stream_, event_, 0),
"Failed to wait event in CUDADeviceContextAllocator");
return allocation;
}
void CUDADeviceContextAllocator::FreeImpl(Allocation *allocation) {
delete allocation;
}
} // namespace allocation
} // namespace memory
} // namespace paddle
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <cuda_runtime.h>
#include "paddle/fluid/memory/allocation/allocator.h"
#include "paddle/fluid/platform/device_context.h"
#include "paddle/fluid/platform/place.h"
namespace paddle {
namespace memory {
namespace allocation {
class CUDADeviceContextAllocator : public Allocator {
public:
explicit CUDADeviceContextAllocator(platform::CUDAPlace place,
cudaStream_t default_stream);
~CUDADeviceContextAllocator();
protected:
Allocation *AllocateImpl(size_t size) override;
void FreeImpl(Allocation *allocation) override;
private:
platform::CUDAPlace place_;
cudaEvent_t event_{nullptr};
cudaStream_t default_stream_{nullptr};
};
} // namespace allocation
} // namespace memory
} // namespace paddle
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/memory/allocation/cuda_device_context_allocator_pool.h"
#include <utility>
#include <vector>
#include "paddle/fluid/memory/allocation/cuda_device_context_allocation.h"
#include "paddle/fluid/memory/allocation/cuda_device_context_allocator.h"
#include "paddle/fluid/platform/enforce.h"
namespace paddle {
namespace memory {
namespace allocation {
CUDADeviceContextAllocatorPool &CUDADeviceContextAllocatorPool::Instance() {
static CUDADeviceContextAllocatorPool pool;
return pool;
}
AllocationPtr CUDADeviceContextAllocatorPool::Alloc(
const platform::CUDADeviceContext &dev_ctx, size_t size) {
auto iter =
allocators_.find(boost::get<platform::CUDAPlace>(dev_ctx.GetPlace()));
PADDLE_ENFORCE_EQ(iter != allocators_.end(), true,
"CUDADeviceContextAllocatorPool initialization error");
auto &allocator = iter->second;
AllocationPtr allocation = allocator->Allocate(size);
static_cast<CUDADeviceContextAllocation *>(allocation.get())
->SetCUDADeviceContext(&dev_ctx);
return allocation;
}
CUDADeviceContextAllocatorPool::CUDADeviceContextAllocatorPool() {
std::vector<int> devices = platform::GetSelectedDevices();
for (int i : devices) {
auto place = platform::CUDAPlace(i);
auto compute_stream =
platform::DeviceContextPool::Instance().GetByPlace(place)->stream();
auto allocator = std::shared_ptr<CUDADeviceContextAllocator>(
new CUDADeviceContextAllocator(place, compute_stream));
allocators_.insert(make_pair(place, allocator));
}
}
} // namespace allocation
} // namespace memory
} // namespace paddle
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
......@@ -13,56 +13,29 @@
// limitations under the License.
#pragma once
#include <condition_variable> // NOLINT
#include <deque>
#include <map>
#include <memory>
#include <mutex> // NOLINT
#include "paddle/fluid/memory/allocation/allocator.h"
#include "paddle/fluid/platform/lock_guard_ptr.h"
#include "paddle/fluid/memory/allocation/cuda_device_context_allocator.h"
#include "paddle/fluid/platform/device_context.h"
#include "paddle/fluid/platform/place.h"
namespace paddle {
namespace platform {
namespace memory {
namespace allocation {
/*! \brief the TemporaryAllocator is used to alloc the temporary allocation
* which used by CUDA's async operation.
*
* The TemporaryAllocator contains a temp_allocation_queue which
* is used to store the temporary allocations. The allocation, which is
* allocated by TemporaryAllocator, is a unique_ptr, and when it is not held
* by any variable, it will be pushed into the temp_allocation_queue.
*
* There is one opportunity to free the allocations of temp_allocation_queue:
* - when the allocation size of opportunities exceeds a certain threshold
* (defined by FLAGS_limit_of_tmp_allocation).
*
* */
class TemporaryAllocator : public memory::allocation::Allocator {
class CUDADeviceContextAllocatorPool {
public:
explicit TemporaryAllocator(platform::Place place);
void Release(const std::function<void()> &callback);
size_t TemporaryAllocationQueueSize();
bool IsAllocThreadSafe() const override;
void SetCallback(const std::function<void()> &callback);
protected:
void FreeImpl(memory::allocation::Allocation *allocation) override;
static CUDADeviceContextAllocatorPool &Instance();
memory::allocation::Allocation *AllocateImpl(size_t size) override;
AllocationPtr Alloc(const platform::CUDADeviceContext &dev_ctx, size_t size);
private:
platform::Place place_;
// When the allocation is not held by any variable, it should be placed
// to temp_mem_map immediately.
std::unique_ptr<std::multimap<size_t, memory::allocation::Allocation *>>
temp_mem_map_{nullptr};
std::mutex mtx_;
size_t wait_delete_mem_{0};
std::function<void()> callback_;
CUDADeviceContextAllocatorPool();
std::map<platform::CUDAPlace, std::shared_ptr<CUDADeviceContextAllocator>>
allocators_;
};
} // namespace platform
} // namespace allocation
} // namespace memory
} // namespace paddle
......@@ -17,17 +17,44 @@ limitations under the License. */
#include <vector>
#include "paddle/fluid/memory/allocation/allocator_facade.h"
#include "paddle/fluid/memory/allocation/allocator_strategy.h"
#ifdef PADDLE_WITH_CUDA
#include "paddle/fluid/memory/allocation/cuda_device_context_allocator_pool.h"
#endif
#include "paddle/fluid/platform/device_context.h"
#include "paddle/fluid/platform/place.h"
namespace paddle {
namespace memory {
std::shared_ptr<Allocation> AllocShared(const platform::Place& place,
std::shared_ptr<Allocation> AllocShared(const platform::Place &place,
size_t size) {
return allocation::AllocatorFacade::Instance().AllocShared(place, size);
}
AllocationPtr Alloc(const platform::Place& place, size_t size) {
AllocationPtr Alloc(const platform::Place &place, size_t size) {
return allocation::AllocatorFacade::Instance().Alloc(place, size);
}
AllocationPtr Alloc(const platform::DeviceContext &dev_ctx, size_t size) {
auto place = dev_ctx.GetPlace();
#ifdef PADDLE_WITH_CUDA
if (size == 0 || !platform::is_gpu_place(place)) {
return Alloc(place, size);
}
auto *default_dev_ctx = static_cast<platform::CUDADeviceContext *>(
platform::DeviceContextPool::Instance().Get(place));
auto &desired_dev_ctx =
static_cast<const platform::CUDADeviceContext &>(dev_ctx);
if (default_dev_ctx->stream() == desired_dev_ctx.stream()) {
return Alloc(place, size);
} else {
return allocation::CUDADeviceContextAllocatorPool::Instance().Alloc(
desired_dev_ctx, size);
}
#else
return Alloc(place, size);
#endif
}
} // namespace memory
} // namespace paddle
......@@ -18,7 +18,13 @@ limitations under the License. */
#include "paddle/fluid/memory/allocation/allocator.h"
#include "paddle/fluid/platform/place.h"
namespace paddle {
namespace platform {
class DeviceContext;
} // platform
namespace memory {
using allocation::Allocation;
using allocation::Allocator;
using allocation::AllocationPtr;
......@@ -28,5 +34,7 @@ extern std::shared_ptr<Allocation> AllocShared(const platform::Place& place,
extern AllocationPtr Alloc(const platform::Place& place, size_t size);
extern AllocationPtr Alloc(const platform::DeviceContext& dev_ctx, size_t size);
} // namespace memory
} // namespace paddle
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <cuda.h>
#include <cuda_runtime.h>
#include <thread> // NOLINT
#include <vector>
#include "gtest/gtest.h"
#include "paddle/fluid/memory/malloc.h"
#include "paddle/fluid/platform/device_context.h"
namespace paddle {
namespace memory {
const int NUM_STREAMS = 8;
const int N = 2;
const float DELTA = 1e-1;
using CudaDevCtxVec = std::vector<std::unique_ptr<platform::CUDADeviceContext>>;
__global__ void kernel(float *x, int n) {
int tid = threadIdx.x + blockIdx.x * blockDim.x;
for (int i = tid; i < n; i += blockDim.x * gridDim.x) {
x[i] = 3.14159 * i;
}
}
void CheckKernelOutput(float *x, int n) {
auto host_x = std::unique_ptr<float[]>(new float[n]);
for (int i = 0; i < n; ++i) {
EXPECT_TRUE(cudaSuccess == cudaMemcpy(host_x.get(), x, n * sizeof(float),
cudaMemcpyDeviceToHost));
EXPECT_GE(host_x[i] + DELTA, 3.14159f * i);
EXPECT_LE(host_x[i] - DELTA, 3.14159f * i);
}
}
void MultiStreamCompute(float **data, float **second_data,
const platform::CUDADeviceContext &ctx) {
// multi-streams
AllocationPtr allocation_ptr = Alloc(ctx, N * sizeof(float));
EXPECT_GE(allocation_ptr->size(), N * sizeof(float));
*data = reinterpret_cast<float *>(allocation_ptr->ptr());
kernel<<<1, 64, 0, ctx.stream()>>>(*data, N);
// allocate and compute on same stream again
allocation_ptr = Alloc(ctx, N * sizeof(float));
EXPECT_GE(allocation_ptr->size(), N * sizeof(float));
*second_data = reinterpret_cast<float *>(allocation_ptr->ptr());
kernel<<<1, 64, 0, ctx.stream()>>>(*second_data, N);
}
TEST(Malloc, CUDADeviceContextMultiStream) {
auto place = platform::CUDAPlace(0);
EXPECT_TRUE(cudaSuccess == cudaSetDevice(0));
AllocationPtr main_stream_alloc_ptr = Alloc(place, N * sizeof(float));
EXPECT_GE(main_stream_alloc_ptr->size(), N * sizeof(float));
float *main_stream_data =
reinterpret_cast<float *>(main_stream_alloc_ptr->ptr());
float *data[NUM_STREAMS];
float *second_data[NUM_STREAMS];
CudaDevCtxVec dev_ctx;
// default stream
kernel<<<1, 64>>>(main_stream_data, N);
main_stream_alloc_ptr.reset();
for (int i = 0; i < NUM_STREAMS; ++i) {
dev_ctx.push_back(std::unique_ptr<platform::CUDADeviceContext>(
new platform::CUDADeviceContext(place)));
MultiStreamCompute(&data[i], &second_data[i], *dev_ctx[i]);
}
EXPECT_TRUE(cudaSuccess == cudaDeviceSynchronize());
for (int i = 0; i < NUM_STREAMS; ++i) {
CheckKernelOutput(data[i], N);
CheckKernelOutput(second_data[i], N);
}
}
TEST(Malloc, CUDADeviceContextMultiThreadMultiStream) {
auto place = platform::CUDAPlace(0);
EXPECT_TRUE(cudaSuccess == cudaSetDevice(0));
AllocationPtr main_stream_alloc_ptr = Alloc(place, N * sizeof(float));
EXPECT_GE(main_stream_alloc_ptr->size(), N * sizeof(float));
float *main_stream_data =
reinterpret_cast<float *>(main_stream_alloc_ptr->ptr());
float *data[NUM_STREAMS];
float *second_data[NUM_STREAMS];
CudaDevCtxVec dev_ctx;
std::vector<std::thread> threads;
// default stream
kernel<<<1, 64>>>(main_stream_data, N);
main_stream_alloc_ptr.reset();
for (int i = 0; i < NUM_STREAMS; ++i) {
dev_ctx.push_back(std::unique_ptr<platform::CUDADeviceContext>(
new platform::CUDADeviceContext(place)));
threads.push_back(std::thread(MultiStreamCompute, &data[i], &second_data[i],
std::cref(*dev_ctx[i])));
}
for (int i = 0; i < NUM_STREAMS; ++i) {
threads[i].join();
}
EXPECT_TRUE(cudaSuccess == cudaDeviceSynchronize());
for (int i = 0; i < NUM_STREAMS; ++i) {
CheckKernelOutput(data[i], N);
CheckKernelOutput(second_data[i], N);
}
}
TEST(Malloc, AllocZero) {
auto place = platform::CUDAPlace(0);
AllocationPtr allocation_ptr = Alloc(place, 0);
EXPECT_GE(allocation_ptr->size(), 0);
}
} // namespace memory
} // namespace paddle
......@@ -28,6 +28,7 @@
#include <limits>
#include <vector>
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/memory/malloc.h"
#include "paddle/fluid/operators/deformable_psroi_pooling_op.h"
#include "paddle/fluid/operators/math/blas.h"
#include "paddle/fluid/operators/math/math_function.h"
......@@ -231,10 +232,8 @@ class DeformablePSROIPoolCUDAKernel : public framework::OpKernel<T> {
}
auto& dev_ctx = ctx.cuda_device_context();
auto& allocator =
platform::DeviceTemporaryAllocator::Instance().Get(dev_ctx);
int bytes = roi_batch_id_list.numel() * sizeof(int);
auto roi_ptr = allocator.Allocate(bytes);
auto roi_ptr = memory::Alloc(dev_ctx, bytes);
int* roi_id_data = reinterpret_cast<int*>(roi_ptr->ptr());
const auto gplace = boost::get<platform::CUDAPlace>(ctx.GetPlace());
memory::Copy(gplace, roi_id_data, cplace, roi_batch_id_data, bytes,
......@@ -499,10 +498,8 @@ class DeformablePSROIPoolGradCUDAKernel : public framework::OpKernel<T> {
}
}
auto& allocator =
platform::DeviceTemporaryAllocator::Instance().Get(dev_ctx);
int bytes = roi_batch_id_list.numel() * sizeof(int);
auto roi_ptr = allocator.Allocate(bytes);
auto roi_ptr = memory::Alloc(dev_ctx, bytes);
int* roi_id_data = reinterpret_cast<int*>(roi_ptr->ptr());
const auto gplace = boost::get<platform::CUDAPlace>(ctx.GetPlace());
memory::Copy(gplace, roi_id_data, cplace, roi_batch_id_data, bytes,
......
......@@ -11,7 +11,7 @@ limitations under the License. */
#include <thrust/device_vector.h>
#include <thrust/host_vector.h>
#include "paddle/fluid/memory/memcpy.h"
#include "paddle/fluid/memory/memory.h"
#include "paddle/fluid/operators/detection/box_coder_op.h"
#include "paddle/fluid/platform/cuda_primitives.h"
......@@ -174,10 +174,8 @@ class BoxCoderCUDAKernel : public framework::OpKernel<T> {
int grid = (row * col + block - 1) / block;
auto& device_ctx = context.cuda_device_context();
auto& allocator =
platform::DeviceTemporaryAllocator::Instance().Get(device_ctx);
int bytes = var_size * sizeof(float);
auto dev_var = allocator.Allocate(bytes);
auto dev_var = memory::Alloc(device_ctx, bytes);
float* dev_var_data = reinterpret_cast<float*>(dev_var->ptr());
auto cplace = platform::CPUPlace();
const auto gplace = boost::get<platform::CUDAPlace>(context.GetPlace());
......
......@@ -12,6 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/memory/malloc.h"
#include "paddle/fluid/operators/detection/yolo_box_op.h"
#include "paddle/fluid/operators/math/math_function.h"
......@@ -84,10 +85,8 @@ class YoloBoxOpCUDAKernel : public framework::OpKernel<T> {
int input_size = downsample_ratio * h;
auto& dev_ctx = ctx.cuda_device_context();
auto& allocator =
platform::DeviceTemporaryAllocator::Instance().Get(dev_ctx);
int bytes = sizeof(int) * anchors.size();
auto anchors_ptr = allocator.Allocate(sizeof(int) * anchors.size());
auto anchors_ptr = memory::Alloc(dev_ctx, sizeof(int) * anchors.size());
int* anchors_data = reinterpret_cast<int*>(anchors_ptr->ptr());
const auto gplace = boost::get<platform::CUDAPlace>(ctx.GetPlace());
const auto cplace = platform::CPUPlace();
......
......@@ -16,6 +16,7 @@ limitations under the License. */
#include <vector>
#include "dgc/dgc.h"
#include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/memory/malloc.h"
#include "paddle/fluid/operators/elementwise/elementwise_add_op.h"
namespace paddle {
......@@ -112,9 +113,7 @@ class DGCOpKernel : public framework::OpKernel<T> {
framework::DDim{2 * k}, ctx.GetPlace());
int buf_size = paddle::communication::dgc::get_buffer_size(k);
auto& allocator = platform::DeviceTemporaryAllocator::Instance().Get(
ctx.GetPlace(), dev_ctx.stream());
auto tmp_ious_data = allocator.Allocate(buf_size);
auto tmp_ious_data = memory::Alloc(dev_ctx, buf_size);
void* buf = reinterpret_cast<void*>(tmp_ious_data->ptr());
if (!paddle::communication::dgc::k_select(
......
......@@ -18,6 +18,7 @@ limitations under the License. */
#include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/tensor_util.h"
#include "paddle/fluid/memory/malloc.h"
#include "paddle/fluid/operators/math/blas.h"
namespace paddle {
......@@ -184,9 +185,7 @@ class FakeMovingAverageAbsMaxKernelBase : public framework::OpKernel<T> {
// training
auto* in_accum = context.Input<framework::Tensor>("InAccum");
auto* in_state = context.Input<framework::Tensor>("InState");
auto& allocator =
platform::DeviceTemporaryAllocator::Instance().Get(dev_ctx);
auto cur_scale = allocator.Allocate(1 * sizeof(T));
auto cur_scale = memory::Alloc(dev_ctx, sizeof(T));
T* cur_scale_data = static_cast<T*>(cur_scale->ptr());
FindAbsMaxFunctor<DeviceContext, T>()(dev_ctx, in->data<T>(), in->numel(),
......@@ -251,9 +250,7 @@ class MovingAverageAbsMaxScaleKernel : public framework::OpKernel<T> {
// training
auto* in_accum = context.Input<framework::Tensor>("InAccum");
auto* in_state = context.Input<framework::Tensor>("InState");
auto& allocator =
platform::DeviceTemporaryAllocator::Instance().Get(dev_ctx);
auto cur_scale = allocator.Allocate(1 * sizeof(T));
auto cur_scale = memory::Alloc(dev_ctx, sizeof(T));
T* cur_scale_data = static_cast<T*>(cur_scale->ptr());
FindAbsMaxFunctor<DeviceContext, T>()(dev_ctx, in->data<T>(), in->numel(),
......
......@@ -17,6 +17,7 @@ limitations under the License. */
#include "paddle/fluid/framework/dim.h"
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/framework/tensor.h"
#include "paddle/fluid/memory/malloc.h"
#include "paddle/fluid/platform/cuda_primitives.h"
#include "paddle/fluid/platform/place.h"
......@@ -142,9 +143,8 @@ void GPUGatherNd(const framework::ExecutionContext& context,
}
auto& dev_ctx = context.cuda_device_context();
auto& allocator = platform::DeviceTemporaryAllocator::Instance().Get(dev_ctx);
int bytes = input_dims_size * sizeof(int);
auto p_input_dims = allocator.Allocate(bytes);
auto p_input_dims = memory::Alloc(dev_ctx, bytes);
int* g_input_dims = reinterpret_cast<int*>(p_input_dims->ptr());
memory::Copy(gplace, g_input_dims, cplace, v_input_dims.data(), bytes,
ctx.stream());
......
......@@ -15,6 +15,7 @@ limitations under the License. */
#include <algorithm>
#include <vector>
#include "paddle/fluid/framework/mixed_vector.h"
#include "paddle/fluid/memory/malloc.h"
#include "paddle/fluid/operators/math/concat_and_split.h"
#include "paddle/fluid/platform/cuda_primitives.h"
#include "paddle/fluid/platform/float16.h"
......@@ -264,8 +265,7 @@ class ConcatFunctor<platform::CUDADeviceContext, T> {
const T** dev_ins_data = nullptr;
if (!has_same_shape || in_num < 2 || in_num > 4) {
tmp_dev_ins_data =
platform::DeviceTemporaryAllocator::Instance().Get(context).Allocate(
inputs_data.size() * sizeof(T*));
memory::Alloc(context, inputs_data.size() * sizeof(T*));
memory::Copy(boost::get<platform::CUDAPlace>(context.GetPlace()),
tmp_dev_ins_data->ptr(), platform::CPUPlace(),
static_cast<void*>(inputs_data.data()),
......@@ -292,8 +292,7 @@ class ConcatFunctor<platform::CUDADeviceContext, T> {
}
} else {
auto tmp_dev_ins_col_data =
platform::DeviceTemporaryAllocator::Instance().Get(context).Allocate(
inputs_col.size() * sizeof(int));
memory::Alloc(context, inputs_col.size() * sizeof(int));
memory::Copy(boost::get<platform::CUDAPlace>(context.GetPlace()),
tmp_dev_ins_col_data->ptr(), platform::CPUPlace(),
static_cast<void*>(inputs_col.data()),
......@@ -356,8 +355,7 @@ class SplitFunctor<platform::CUDADeviceContext, T> {
T** dev_out_gpu_data = nullptr;
if (!has_same_shape || o_num < 2 || o_num > 4) {
tmp_dev_outs_data =
platform::DeviceTemporaryAllocator::Instance().Get(context).Allocate(
outputs_data.size() * sizeof(T*));
memory::Alloc(context, outputs_data.size() * sizeof(T*));
memory::Copy(boost::get<platform::CUDAPlace>(context.GetPlace()),
tmp_dev_outs_data->ptr(), platform::CPUPlace(),
reinterpret_cast<void*>(outputs_data.data()),
......@@ -384,8 +382,9 @@ class SplitFunctor<platform::CUDADeviceContext, T> {
}
} else {
auto tmp_dev_ins_col_data =
platform::DeviceTemporaryAllocator::Instance().Get(context).Allocate(
outputs_cols.size() * sizeof(int));
memory::Alloc(context,
outputs_cols.size() * sizeof(int));
memory::Copy(boost::get<platform::CUDAPlace>(context.GetPlace()),
tmp_dev_ins_col_data->ptr(), platform::CPUPlace(),
reinterpret_cast<void*>(outputs_cols.data()),
......
......@@ -12,6 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/memory/malloc.h"
#include "paddle/fluid/operators/math/math_function.h"
#include "paddle/fluid/operators/mean_iou_op.h"
#include "paddle/fluid/platform/cuda_primitives.h"
......@@ -116,9 +117,7 @@ class MeanIoUCUDAOpKernel : public framework::OpKernel<T> {
auto out_correct_t = EigenTensor<int, 1>::From(*out_correct);
// Temporary memory
auto& allocator =
platform::DeviceTemporaryAllocator::Instance().Get(dev_ctx);
auto tmp_ious_data = allocator.Allocate(num_classes * sizeof(float));
auto tmp_ious_data = memory::Alloc(dev_ctx, num_classes * sizeof(float));
float* ious_data = static_cast<float*>(tmp_ious_data->ptr());
// Init out_wrong, out_correct and out_mean_iou
......
......@@ -12,7 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/memory/memcpy.h"
#include "paddle/fluid/memory/memory.h"
#include "paddle/fluid/operators/roi_align_op.h"
#include "paddle/fluid/platform/cuda_primitives.h"
......@@ -272,10 +272,8 @@ class GPUROIAlignOpKernel : public framework::OpKernel<T> {
}
}
auto& dev_ctx = ctx.cuda_device_context();
auto& allocator =
platform::DeviceTemporaryAllocator::Instance().Get(dev_ctx);
int bytes = roi_batch_id_list.numel() * sizeof(int);
auto roi_ptr = allocator.Allocate(bytes);
auto roi_ptr = memory::Alloc(dev_ctx, bytes);
int* roi_id_data = reinterpret_cast<int*>(roi_ptr->ptr());
const auto gplace = boost::get<platform::CUDAPlace>(ctx.GetPlace());
memory::Copy(gplace, roi_id_data, cplace, roi_batch_id_data, bytes,
......@@ -322,9 +320,8 @@ class GPUROIAlignGradOpKernel : public framework::OpKernel<T> {
}
}
auto& dev_ctx = ctx.cuda_device_context();
auto& allocator =
platform::DeviceTemporaryAllocator::Instance().Get(dev_ctx);
auto roi_ptr = allocator.Allocate(roi_batch_id_list.numel() * sizeof(int));
auto roi_ptr =
memory::Alloc(dev_ctx, roi_batch_id_list.numel() * sizeof(int));
int* roi_id_data = reinterpret_cast<int*>(roi_ptr->ptr());
int bytes = roi_batch_id_list.numel() * sizeof(int);
const auto gplace = boost::get<platform::CUDAPlace>(ctx.GetPlace());
......
......@@ -12,7 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/memory/memcpy.h"
#include "paddle/fluid/memory/memory.h"
#include "paddle/fluid/operators/roi_pool_op.h"
#include "paddle/fluid/platform/cuda_primitives.h"
......@@ -170,10 +170,8 @@ class GPUROIPoolOpKernel : public framework::OpKernel<T> {
}
auto& dev_ctx = ctx.cuda_device_context();
auto& allocator =
platform::DeviceTemporaryAllocator::Instance().Get(dev_ctx);
int bytes = roi_batch_id_list.numel() * sizeof(int);
auto roi_ptr = allocator.Allocate(bytes);
auto roi_ptr = memory::Alloc(dev_ctx, bytes);
int* roi_id_data = reinterpret_cast<int*>(roi_ptr->ptr());
const auto gplace = boost::get<platform::CUDAPlace>(ctx.GetPlace());
memory::Copy(gplace, roi_id_data, cplace, roi_batch_id_data, bytes,
......@@ -221,10 +219,8 @@ class GPUROIPoolGradOpKernel : public framework::OpKernel<T> {
}
auto& dev_ctx = ctx.cuda_device_context();
auto& allocator =
platform::DeviceTemporaryAllocator::Instance().Get(dev_ctx);
int bytes = roi_batch_id_list.numel() * sizeof(int);
auto roi_ptr = allocator.Allocate(bytes);
auto roi_ptr = memory::Alloc(dev_ctx, bytes);
int* roi_id_data = reinterpret_cast<int*>(roi_ptr->ptr());
const auto gplace = boost::get<platform::CUDAPlace>(ctx.GetPlace());
memory::Copy(gplace, roi_id_data, cplace, roi_batch_id_data, bytes,
......
......@@ -17,6 +17,7 @@ limitations under the License. */
#include <vector>
#include "math/math_function.h"
#include "paddle/fluid/framework/tensor.h"
#include "paddle/fluid/memory/malloc.h"
#include "paddle/fluid/platform/cuda_primitives.h"
#include "paddle/fluid/platform/place.h"
......@@ -170,9 +171,8 @@ void GPUScatterNdAdd(const framework::ExecutionContext& context,
v_output_dims[i] = static_cast<int>(output_dims[i]);
}
auto& dev_ctx = context.cuda_device_context();
auto& allocator = platform::DeviceTemporaryAllocator::Instance().Get(dev_ctx);
int bytes = output_dims_size * sizeof(int);
auto output_dims_ptr = allocator.Allocate(bytes);
auto output_dims_ptr = memory::Alloc(dev_ctx, bytes);
int* g_output_dims = reinterpret_cast<int*>(output_dims_ptr->ptr());
memory::Copy(gplace, g_output_dims, cplace, v_output_dims.data(), bytes,
ctx.stream());
......
......@@ -12,6 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "cub/cub.cuh"
#include "paddle/fluid/memory/malloc.h"
#include "paddle/fluid/operators/math.h"
#include "paddle/fluid/operators/sigmoid_cross_entropy_with_logits_op.h"
#include "paddle/fluid/platform/cuda_primitives.h"
......@@ -116,9 +117,7 @@ class GPUSigmoidCrossEntropyWithLogitsKernel : public framework::OpKernel<T> {
bool normalize = context.Attr<bool>("normalize");
// Temporary memory
auto &allocator =
platform::DeviceTemporaryAllocator::Instance().Get(dev_ctx);
auto cnt_ptr = allocator.Allocate(Labels->numel() * sizeof(T));
auto cnt_ptr = memory::Alloc(dev_ctx, Labels->numel() * sizeof(T));
T *counts = reinterpret_cast<T *>(cnt_ptr->ptr());
int limit = Out->numel();
......@@ -127,7 +126,7 @@ class GPUSigmoidCrossEntropyWithLogitsKernel : public framework::OpKernel<T> {
GPUSigmoidForward<T><<<blocks, threads, 0, dev_ctx.stream()>>>(
X->data<T>(), Labels->data<T>(), ignore_index, limit, out_data, counts);
if (normalize) {
auto norm_ptr = allocator.Allocate(sizeof(T));
auto norm_ptr = memory::Alloc(dev_ctx, sizeof(T));
T *norm = reinterpret_cast<T *>(norm_ptr->ptr());
Sum<T, kNumCUDAThreads><<<1, kNumCUDAThreads, 0, dev_ctx.stream()>>>(
counts, limit, static_cast<T>(1e-5), norm);
......@@ -152,9 +151,7 @@ class GPUSigmoidCrossEntropyWithLogitsGradKernel
auto &dev_ctx = context.cuda_device_context();
// Temporary memory
auto &allocator =
platform::DeviceTemporaryAllocator::Instance().Get(dev_ctx);
auto cnt_ptr = allocator.Allocate(X->numel() * sizeof(T));
auto cnt_ptr = memory::Alloc(dev_ctx, X->numel() * sizeof(T));
T *counts = reinterpret_cast<T *>(cnt_ptr->ptr());
int limit = dX->numel();
......@@ -165,7 +162,7 @@ class GPUSigmoidCrossEntropyWithLogitsGradKernel
dx_data, counts);
bool normalize = context.Attr<bool>("normalize");
if (normalize) {
auto norm_ptr = allocator.Allocate(sizeof(T));
auto norm_ptr = memory::Alloc(dev_ctx, sizeof(T));
T *norm = reinterpret_cast<T *>(norm_ptr->ptr());
Sum<T, kNumCUDAThreads><<<1, kNumCUDAThreads, 0, dev_ctx.stream()>>>(
counts, limit, static_cast<T>(1e-5), norm);
......
......@@ -11,6 +11,7 @@ limitations under the License. */
#include <paddle/fluid/platform/device_context.h>
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/memory/malloc.h"
#include "paddle/fluid/operators/sum_op.h"
#include "paddle/fluid/platform/float16.h"
......@@ -197,8 +198,7 @@ void SumToLoDTensor(const framework::ExecutionContext &context) {
}
if (!sr_in_out_data.empty()) {
auto tmp_sr_in_out_array =
platform::DeviceTemporaryAllocator::Instance().Get(dev_ctx).Allocate(
sr_in_out_data.size() * sizeof(T *));
memory::Alloc(dev_ctx, sr_in_out_data.size() * sizeof(T *));
memory::Copy(boost::get<platform::CUDAPlace>(dev_ctx.GetPlace()),
tmp_sr_in_out_array->ptr(), platform::CPUPlace(),
......@@ -216,9 +216,7 @@ void SumToLoDTensor(const framework::ExecutionContext &context) {
}
// if indata not null, merge into one kernel call.
if (!in_data.empty()) {
auto tmp_in_array =
platform::DeviceTemporaryAllocator::Instance().Get(dev_ctx).Allocate(
in_data.size() * sizeof(T *));
auto tmp_in_array = memory::Alloc(dev_ctx, in_data.size() * sizeof(T *));
memory::Copy(boost::get<platform::CUDAPlace>(dev_ctx.GetPlace()),
tmp_in_array->ptr(), platform::CPUPlace(),
......
......@@ -18,6 +18,7 @@ limitations under the License. */
#include <vector>
#include "cub/cub.cuh"
#include "paddle/fluid/framework/data_layout.h"
#include "paddle/fluid/memory/malloc.h"
#include "paddle/fluid/operators/batch_norm_op.h"
#include "paddle/fluid/platform/cudnn_helper.h"
#include "paddle/fluid/platform/float16.h"
......@@ -149,12 +150,10 @@ class SyncBatchNormKernel : public framework::OpKernel<T> {
mean_data = est_mean->data<T>();
var_data = est_var->data<T>();
} else {
auto &allocator =
platform::DeviceTemporaryAllocator::Instance().Get(dev_ctx);
// x, x^2, 1, here 1 is used to calc device num
// device num also can be got from platform::DeviceContextPool
const int bytes = (C * 2 + 1) * sizeof(T);
alloc_ptr = allocator.Allocate(bytes);
alloc_ptr = memory::Alloc(dev_ctx, bytes);
T *stats = reinterpret_cast<T *>(alloc_ptr->ptr());
const int threads = 256;
......@@ -373,10 +372,8 @@ class SyncBatchNormGradKernel : public framework::OpKernel<T> {
const T *saved_mean = ctx.Input<Tensor>("SavedMean")->data<T>();
const T *saved_inv_var = ctx.Input<Tensor>("SavedVariance")->data<T>();
auto &allocator =
platform::DeviceTemporaryAllocator::Instance().Get(dev_ctx);
const int bytes = (C * 2 + 1) * sizeof(T);
auto alloc_ptr = allocator.Allocate(bytes);
auto alloc_ptr = memory::Alloc(dev_ctx, bytes);
T *stats = reinterpret_cast<T *>(alloc_ptr->ptr());
const int threads = 256;
......
......@@ -61,8 +61,6 @@ ELSE()
set(MKLDNN_CTX_DEPS)
ENDIF()
cc_library(temp_allocator SRCS temporary_allocator.cc DEPS allocator_facade)
nv_library(stream_callback_manager SRCS stream_callback_manager.cc DEPS simple_threadpool enforce)
IF(WITH_GPU)
set(STREAM_CALLBACK_DEPS stream_callback_manager)
......@@ -74,7 +72,7 @@ ENDIF()
# avoiding cycle dependencies
cc_library(device_context SRCS device_context.cc init.cc DEPS simple_threadpool malloc ${STREAM_CALLBACK_DEPS}
place eigen3 stringpiece cpu_helper cpu_info framework_proto ${GPU_CTX_DEPS} ${MKLDNN_CTX_DEPS}
temp_allocator ${dgc_deps})
${dgc_deps})
if (WITH_DISTRIBUTE)
cc_library(collective_helper SRCS collective_helper.cc DEPS framework_proto device_context enforce)
......@@ -117,12 +115,6 @@ cc_test(float16_test SRCS float16_test.cc DEPS lod_tensor)
nv_library(cuda_device_guard SRCS cuda_device_guard.cc DEPS gpu_info)
if(WITH_GPU)
nv_test(temporal_allocator_test SRCS temporary_allocator_test.cc DEPS temp_allocator tensor operator)
else()
cc_test(temporal_allocator_test SRCS temporary_allocator_test.cc DEPS temp_allocator tensor operator)
endif()
if(NOT APPLE AND NOT WIN32)
cc_library(device_code SRCS device_code.cc DEPS device_context)
if(WITH_GPU)
......
......@@ -89,47 +89,6 @@ DeviceContextPool::DeviceContextPool(
}
}
DeviceTemporaryAllocator* DeviceTemporaryAllocator::allocators = nullptr;
#ifdef PADDLE_WITH_CUDA
platform::TemporaryAllocator& DeviceTemporaryAllocator::Get(
const platform::Place& place, const cudaStream_t& stream) {
PADDLE_ENFORCE(platform::is_gpu_place(place));
auto place_stream = std::make_pair(place, stream);
std::unique_lock<std::mutex> lock(mtx_);
auto it = device_allocator_.find(place_stream);
if (it == device_allocator_.end()) {
auto tmp_allocator = new TemporaryAllocator(place);
tmp_allocator->SetCallback([stream]() {
PADDLE_ENFORCE(cudaStreamSynchronize(stream));
PADDLE_ENFORCE(cudaGetLastError());
});
device_allocator_[place_stream].reset(tmp_allocator);
return *tmp_allocator;
} else {
return *it->second;
}
}
template <>
platform::TemporaryAllocator& DeviceTemporaryAllocator::Get(
const platform::CUDADeviceContext& dev_ctx) {
return Get(dev_ctx.GetPlace(), dev_ctx.stream());
}
#endif
template <>
platform::TemporaryAllocator& DeviceTemporaryAllocator::Get(
const platform::CPUDeviceContext& dev_ctx) {
return cpu_allocator_;
}
platform::TemporaryAllocator& DeviceTemporaryAllocator::Get(
const platform::Place& place) {
PADDLE_ENFORCE(platform::is_cpu_place(place), "You should pass CPUPlace");
return cpu_allocator_;
}
CPUDeviceContext::CPUDeviceContext() {
eigen_device_.reset(new Eigen::DefaultDevice());
}
......@@ -169,7 +128,9 @@ class EigenCudaStreamDevice : public Eigen::StreamInterface {
if (UNLIKELY(num_bytes == 0)) {
return nullptr;
}
auto buf = paddle::memory::Alloc(place_, num_bytes);
auto buf = memory::Alloc(place_, num_bytes);
VLOG(4) << "Eigen allocated at " << buf->ptr() << ", size" << buf->size()
<< " requested " << num_bytes;
void* retv = buf->ptr();
{
std::lock_guard<std::mutex> lock(mtx_);
......@@ -197,7 +158,7 @@ class EigenCudaStreamDevice : public Eigen::StreamInterface {
char* scratch =
static_cast<char*>(scratchpad()) + Eigen::kCudaScratchSize;
semaphore_ = reinterpret_cast<unsigned int*>(scratch);
PADDLE_ENFORCE(
PADDLE_ENFORCE_CUDA_SUCCESS(
cudaMemsetAsync(semaphore_, 0, sizeof(unsigned int), *stream_));
}
return semaphore_;
......@@ -213,36 +174,12 @@ class EigenCudaStreamDevice : public Eigen::StreamInterface {
mutable std::unordered_map<void*, memory::AllocationPtr> allocations_;
};
CudnnHolder::CudnnHolder(const cudaStream_t* stream, const CUDAPlace& place)
: workspace_(nullptr), stream_(stream), place_(place) {
PADDLE_ENFORCE(cudaSetDevice(place_.device));
PADDLE_ENFORCE(dynload::cudnnCreate(&cudnn_handle_));
PADDLE_ENFORCE(dynload::cudnnSetStream(cudnn_handle_, *stream_));
}
CudnnHolder::~CudnnHolder() {
PADDLE_ENFORCE(dynload::cudnnDestroy(cudnn_handle_));
}
void CudnnHolder::ReallocateWorkspace(size_t required_workspace_len) {
if (required_workspace_len <= WorkspaceSize()) {
return;
}
if (workspace_ != nullptr) {
// Maybe someone is using the current workspace
PADDLE_ENFORCE(cudaStreamSynchronize(*stream_));
workspace_.reset();
}
workspace_ = paddle::memory::Alloc(place_, required_workspace_len);
}
CUDADeviceContext::CUDADeviceContext(CUDAPlace place)
: place_(place), cudnn_holder_(nullptr) {
CUDADeviceContext::CUDADeviceContext(CUDAPlace place) : place_(place) {
CUDADeviceGuard guard(place_.device);
compute_capability_ = GetCUDAComputeCapability(place_.device);
multi_process_ = GetCUDAMultiProcessors(place_.device);
max_threads_per_mp_ = GetCUDAMaxThreadsPerMultiProcessor(place_.device);
PADDLE_ENFORCE(cudaStreamCreate(&stream_));
PADDLE_ENFORCE_CUDA_SUCCESS(cudaStreamCreate(&stream_));
eigen_stream_.reset(new EigenCudaStreamDevice());
eigen_stream_->Reinitialize(&stream_, place);
eigen_device_.reset(new Eigen::GpuDevice(eigen_stream_.get()));
......@@ -302,6 +239,14 @@ CUDADeviceContext::CUDADeviceContext(CUDAPlace place)
<< "Please recompile or reinstall Paddle with compatible CUDNN "
"version.";
}
PADDLE_ENFORCE_CUDA_SUCCESS(
dynload::cudnnCreate(&cudnn_handle_),
"Failed to create Cudnn handle in DeviceContext");
PADDLE_ENFORCE_CUDA_SUCCESS(
dynload::cudnnSetStream(cudnn_handle_, stream_),
"Failed to set stream for Cudnn handle in DeviceContext");
} else {
cudnn_handle_ = nullptr;
}
}
......@@ -316,10 +261,14 @@ CUDADeviceContext::~CUDADeviceContext() {
cublas_tensor_core_handle_.reset();
eigen_stream_.reset();
eigen_device_.reset();
PADDLE_ENFORCE(cudaStreamDestroy(stream_));
PADDLE_ENFORCE_CUDA_SUCCESS(cudaStreamDestroy(stream_));
if (cudnn_handle_) {
PADDLE_ENFORCE_CUDA_SUCCESS(dynload::cudnnDestroy(cudnn_handle_),
"Failed to destory Cudnn handle");
}
#if !defined(_WIN32)
if (nccl_comm_) {
PADDLE_ENFORCE(dynload::ncclCommDestroy(nccl_comm_));
PADDLE_ENFORCE_CUDA_SUCCESS(dynload::ncclCommDestroy(nccl_comm_));
}
#endif
}
......@@ -327,21 +276,17 @@ CUDADeviceContext::~CUDADeviceContext() {
Place CUDADeviceContext::GetPlace() const { return place_; }
void CUDADeviceContext::Wait() const {
auto& allocator =
DeviceTemporaryAllocator::Instance().Get<CUDADeviceContext>(*this);
allocator.Release([this]() {
cudaError_t e_sync = cudaStreamSynchronize(stream_);
if (e_sync != 0) {
LOG(FATAL) << "cudaStreamSynchronize " << cudaGetErrorString(e_sync)
<< " errno:" << e_sync;
}
cudaError_t e_sync = cudaStreamSynchronize(stream_);
if (e_sync != 0) {
LOG(FATAL) << "cudaStreamSynchronize " << cudaGetErrorString(e_sync)
<< " errno: " << e_sync;
}
cudaError_t e_get = cudaGetLastError();
if (e_get != 0) {
LOG(FATAL) << "cudaGetLastError " << cudaGetErrorString(e_get)
<< " errno:" << e_get;
}
});
cudaError_t e_get = cudaGetLastError();
if (e_get != 0) {
LOG(FATAL) << "cudaGetLastError " << cudaGetErrorString(e_get)
<< " errno: " << e_get;
}
}
int CUDADeviceContext::GetComputeCapability() const {
......@@ -360,21 +305,10 @@ bool CUDADeviceContext::tensor_core_available() const {
return cublas_tensor_core_handle_ != nullptr;
}
CudnnHolder* CUDADeviceContext::cudnn_holder() const {
std::call_once(init_cudnn_, [&]() {
if (dynload::HasCUDNN()) {
cudnn_holder_.reset(new CudnnHolder(&stream_, place_));
}
});
return cudnn_holder_.get();
}
cudnnHandle_t CUDADeviceContext::cudnn_handle() const {
return cudnn_holder()->cudnn_handle();
}
cudnnHandle_t CUDADeviceContext::cudnn_handle() const { return cudnn_handle_; }
CudnnWorkspaceHandle CUDADeviceContext::cudnn_workspace_handle() const {
return CudnnWorkspaceHandle(cudnn_holder());
return CudnnWorkspaceHandle(*this);
}
cudaStream_t CUDADeviceContext::stream() const { return stream_; }
......
......@@ -18,7 +18,6 @@ limitations under the License. */
#include <utility>
#include <vector>
#include "paddle/fluid/memory/malloc.h"
#include "paddle/fluid/platform/temporary_allocator.h"
#ifdef PADDLE_WITH_CUDA
#include "paddle/fluid/platform/cuda_helper.h"
#include "paddle/fluid/platform/dynload/cublas.h"
......@@ -45,71 +44,6 @@ limitations under the License. */
namespace paddle {
namespace platform {
/*! \brief device temporary allocator singleton.
*
* Some operator needs temporary memory during computation, for example,
* conv_gemm, which needs use col to store the result of im2col. If we
* create a stack memory which is used by CUDA Kernel, before the
* Computation(...) returns, we should add ctx->Wait(), because the
* execution of CUDA is async, if there doesn't have ctx->Wait(),
* the temporary memory will be released before the CUDA Kernel uses
* it.
*
* DeviceTemporaryAllocator is a singleton, which contains a
* `TemporaryAllocator` for each <Place, Stream>. And the TemporaryAllocator
* contains a temp_allocation_queue which is used to store the temporary
* allocations. The allocation, which is allocated by TemporaryAllocator,
* is a unique_ptr, and when it is not held by any variable, it will be
* pushed into the temp_allocation_queue. There are two opportunities to free
* the allocations of temp_allocation_queue:
* - when the Stream calls cudaStreamSynchronize;
* - when the allocation size of opportunities exceeds a certain threshold
* (defined by FLAGS_limit_of_tmp_allocation).
*
* */
class DeviceTemporaryAllocator {
public:
static DeviceTemporaryAllocator& Instance() {
PADDLE_ENFORCE_NOT_NULL(allocators,
"Need to Create DeviceTemporaryAllocator first!");
return *allocators;
}
static DeviceTemporaryAllocator& Init() {
if (allocators == nullptr) {
allocators = new DeviceTemporaryAllocator();
}
return *allocators;
}
/*! \brief Return handle of single temporary allocator. */
#ifdef PADDLE_WITH_CUDA
platform::TemporaryAllocator& Get(const platform::Place& place,
const cudaStream_t& stream);
#endif
template <typename DeviceContext>
platform::TemporaryAllocator& Get(const DeviceContext& dev_ctx);
platform::TemporaryAllocator& Get(const platform::Place& place);
private:
DeviceTemporaryAllocator() : cpu_allocator_(platform::CPUPlace()) {}
static DeviceTemporaryAllocator* allocators;
platform::TemporaryAllocator cpu_allocator_;
#ifdef PADDLE_WITH_CUDA
std::map<std::pair<platform::Place, cudaStream_t>,
std::unique_ptr<platform::TemporaryAllocator>>
device_allocator_;
#endif
std::mutex mtx_;
DISABLE_COPY_AND_ASSIGN(DeviceTemporaryAllocator);
};
class DeviceContext {
public:
virtual ~DeviceContext() {}
......@@ -143,102 +77,7 @@ struct DefaultDeviceContextType<platform::CPUPlace> {
#ifdef PADDLE_WITH_CUDA
class EigenCudaStreamDevice;
class CudnnHolder {
public:
CudnnHolder(const cudaStream_t* stream, const CUDAPlace& place);
~CudnnHolder();
cudnnHandle_t cudnn_handle() const { return cudnn_handle_; }
private:
friend class CudnnWorkspaceHandle;
void ReallocateWorkspace(size_t required_workspace_len);
template <typename Callback>
void RunFuncImpl(Callback&& cudnn_func, size_t required_workspace_len) {
if (required_workspace_len > WorkspaceSize()) {
ReallocateWorkspace(required_workspace_len);
}
VLOG(2) << "Cudnn workspace size: "
<< static_cast<double>(WorkspaceSize()) / (1 << 20) << " MB";
cudnn_func(WorkspacePtr());
}
/*! \brief Reset workspace thus release the memory */
inline void ResetWorkspace() {
if (workspace_) {
// Maybe someone is using the current workspace
PADDLE_ENFORCE_CUDA_SUCCESS(cudaStreamSynchronize(*stream_));
workspace_ = nullptr;
}
}
inline void* WorkspacePtr() {
if (workspace_) {
return workspace_->ptr();
} else {
return nullptr;
}
}
inline size_t WorkspaceSize() {
if (workspace_) {
return workspace_->size();
} else {
return 0;
}
}
std::mutex& Mutex() { return mtx_; }
cudnnHandle_t cudnn_handle_;
memory::AllocationPtr workspace_;
const cudaStream_t* stream_; // not owned;
const CUDAPlace place_;
std::mutex mtx_;
};
class CudnnWorkspaceHandle {
public:
/*! \brief The lock would not be acquired when constructor calls.
* The lock would be acquired when RunFunc() is called first time. */
inline explicit CudnnWorkspaceHandle(CudnnHolder* holder) : holder_(holder) {}
/*! \brief Thread which call RunFunc() would acquire the lock first
* before invoking cudnn functions. */
template <typename Callback>
inline void RunFunc(Callback&& cudnn_func, size_t required_workspace_len) {
if (!guard_) {
guard_.reset(new std::lock_guard<std::mutex>(holder_->Mutex()));
}
holder_->RunFuncImpl(std::forward<Callback>(cudnn_func),
required_workspace_len);
}
/*! \brief Thread which call RunFuncSync() would acquire the lock first
* before invoking cudnn function and release gpu memory after running
* the function. Currently this function is only used when cudnn
* exhaustive searching and callers have to guarantee that the input function
* is host blocking */
template <typename Callback>
inline void RunFuncSync(Callback&& cudnn_func,
size_t required_workspace_len) {
if (!guard_) {
guard_.reset(new std::lock_guard<std::mutex>(holder_->Mutex()));
}
holder_->RunFuncImpl(std::forward<Callback>(cudnn_func),
required_workspace_len);
holder_->ResetWorkspace();
}
CudnnWorkspaceHandle(CudnnWorkspaceHandle&&) = default;
CudnnWorkspaceHandle& operator=(CudnnWorkspaceHandle&&) = delete;
private:
CudnnHolder* holder_; // not own
std::unique_ptr<std::lock_guard<std::mutex>> guard_;
};
class CudnnWorkspaceHandle;
class CUDADeviceContext : public DeviceContext {
public:
......@@ -323,9 +162,8 @@ class CUDADeviceContext : public DeviceContext {
std::unique_ptr<Eigen::GpuDevice> eigen_device_;
std::unique_ptr<EigenCudaStreamDevice> eigen_stream_;
mutable std::unique_ptr<CudnnHolder> cudnn_holder_;
cudaStream_t stream_;
cudnnHandle_t cudnn_handle_;
std::unique_ptr<CublasHandleHolder> cublas_handle_;
std::unique_ptr<CublasHandleHolder> cublas_tensor_core_handle_;
......@@ -346,11 +184,60 @@ class CUDADeviceContext : public DeviceContext {
// StreamCallbackManager is thread-safe
std::unique_ptr<StreamCallbackManager> callback_manager_;
CudnnHolder* cudnn_holder() const;
DISABLE_COPY_AND_ASSIGN(CUDADeviceContext);
};
class CudnnWorkspaceHandle {
public:
inline explicit CudnnWorkspaceHandle(const CUDADeviceContext& dev_ctx)
: device_context_(dev_ctx) {}
template <typename Callback>
inline void RunFunc(Callback&& cudnn_func, size_t required_workspace_bytes) {
if (required_workspace_bytes > WorkspaceSize()) {
ReallocWorkspace(required_workspace_bytes);
}
VLOG(2) << "Cudnn workspace size at RunFunc: "
<< static_cast<double>(WorkspaceSize()) / (1 << 20) << " MB";
cudnn_func(allocation_ ? allocation_->ptr() : nullptr);
}
/*! \brief A thread that calls RunFuncSync() releases the GPU memory after
* running the function. Currently this is only used for cudnn exhaustive
* search, and callers have to guarantee that the input function is host
* blocking. */
template <typename Callback>
inline void RunFuncSync(Callback&& cudnn_func,
size_t required_workspace_bytes) {
RunFunc(cudnn_func, required_workspace_bytes);
ResetWorkspace();
}
inline void ReallocWorkspace(size_t required_workspace_bytes) {
if (required_workspace_bytes <= WorkspaceSize()) {
return;
}
allocation_ = memory::Alloc(device_context_, required_workspace_bytes);
}
inline void ResetWorkspace() { allocation_ = nullptr; }
inline size_t WorkspaceSize() {
if (allocation_ == nullptr) {
return 0;
}
return allocation_->size();
}
CudnnWorkspaceHandle(CudnnWorkspaceHandle&&) = default;
CudnnWorkspaceHandle& operator=(CudnnWorkspaceHandle&&) = delete;
private:
memory::allocation::AllocationPtr allocation_;
const CUDADeviceContext& device_context_;
};
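Below is a minimal usage sketch (not part of this diff) of the new CudnnWorkspaceHandle. It assumes a translation unit that already includes paddle/fluid/platform/device_context.h; the lambda body and the 64 MB size are illustrative stand-ins for a real cuDNN call.

// Usage sketch only: requests a workspace through the handle defined above.
inline void RunKernelWithWorkspace(
    const paddle::platform::CUDADeviceContext& dev_ctx) {
  paddle::platform::CudnnWorkspaceHandle workspace_handle(dev_ctx);
  const size_t required_bytes = static_cast<size_t>(64) << 20;  // 64 MB
  // RunFunc grows the workspace on demand and passes its pointer to the
  // callback; the allocation is kept alive for later calls on this handle.
  workspace_handle.RunFunc(
      [&](void* workspace_ptr) {
        // A real operator would hand workspace_ptr / required_bytes to cuDNN.
        (void)workspace_ptr;
      },
      required_bytes);
  // RunFuncSync(...) behaves the same but calls ResetWorkspace() afterwards,
  // releasing the memory once the host-blocking callback returns.
}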
template <>
struct DefaultDeviceContextType<platform::CUDAPlace> {
using TYPE = CUDADeviceContext;
......
......@@ -153,7 +153,6 @@ void InitDevices(bool init_p2p, const std::vector<int> devices) {
}
places.emplace_back(platform::CPUPlace());
platform::DeviceContextPool::Init(places);
platform::DeviceTemporaryAllocator::Init();
#ifndef PADDLE_WITH_MKLDNN
platform::SetNumThreads(FLAGS_paddle_num_threads);
......
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/platform/temporary_allocator.h"
#include <memory>
#include "paddle/fluid/memory/allocation/allocator_facade.h"
DEFINE_int64(limit_of_tmp_allocation, -1,
"The up limit of temporary_allocation size.");
DEFINE_double(times_excess_than_required_tmp_allocation, 2,
"times_excess_than_required_tmp_allocation indicates the "
"max size the TemporaryAllocator can return. For example, "
"if the required memory size is N, and "
"times_excess_than_required_tmp_allocation is 2.0, "
"the TemporaryAllocator will return the available allocation "
"that the range of size is N ~ 2*N.");
namespace paddle {
namespace platform {
namespace alloc = memory::allocation;
TemporaryAllocator::TemporaryAllocator(platform::Place place) : place_(place) {
temp_mem_map_.reset(new std::multimap<size_t, alloc::Allocation *>());
}
bool TemporaryAllocator::IsAllocThreadSafe() const { return true; }
void TemporaryAllocator::Release(const std::function<void()> &callback) {
std::unique_ptr<std::multimap<size_t, alloc::Allocation *>> t_allocations;
{
std::unique_lock<std::mutex> lock(mtx_);
callback();
t_allocations.swap(temp_mem_map_);
temp_mem_map_.reset(new std::multimap<size_t, alloc::Allocation *>());
wait_delete_mem_ = 0;
}
alloc::AllocationDeleter deleter;
for (auto tmp : *t_allocations) {
VLOG(10) << "Delete temporary allocation " << tmp.second->ptr()
<< " size: " << tmp.second->size();
deleter(tmp.second);
}
}
void TemporaryAllocator::FreeImpl(alloc::Allocation *temp_allocation) {
if (platform::is_gpu_place(temp_allocation->place())) {
PADDLE_ENFORCE(platform::is_same_place(temp_allocation->place(), place_),
"The place should be the same.");
size_t wait_delete_mem = 0;
{
std::unique_lock<std::mutex> lock(mtx_);
temp_mem_map_->emplace(temp_allocation->size(), temp_allocation);
wait_delete_mem_ += temp_allocation->size();
wait_delete_mem = wait_delete_mem_;
VLOG(10) << "Move temporary allocation: " << temp_allocation->ptr()
<< " to delete queue: " << temp_allocation->size() << "; "
<< "wait_delete_mem: " << wait_delete_mem;
}
if (FLAGS_limit_of_tmp_allocation >= 0 &&
wait_delete_mem >= static_cast<size_t>(FLAGS_limit_of_tmp_allocation)) {
PADDLE_ENFORCE(callback_ != nullptr, "The callback is not initialized.");
Release(callback_);
}
return;
}
VLOG(10) << "Delete temporary allocation " << temp_allocation->ptr()
<< " size: " << temp_allocation->size();
alloc::AllocationDeleter()(temp_allocation);
}
size_t TemporaryAllocator::TemporaryAllocationQueueSize() {
std::unique_lock<std::mutex> lock(mtx_);
return temp_mem_map_ ? temp_mem_map_->size() : 0;
}
void TemporaryAllocator::SetCallback(const std::function<void()> &callback) {
callback_ = callback;
}
alloc::Allocation *TemporaryAllocator::AllocateImpl(size_t size) {
{
// Find available allocation in temp_mem_map.
std::unique_lock<std::mutex> lock(mtx_);
if (temp_mem_map_->size()) {
auto it = temp_mem_map_->lower_bound(size);
// FIXME(zcd): Not sure about the best value for the excess fraction.
if (it != temp_mem_map_->end() &&
it->first <
static_cast<size_t>(
size * FLAGS_times_excess_than_required_tmp_allocation)) {
auto tmp_ptr = it->second;
temp_mem_map_->erase(it);
wait_delete_mem_ -= tmp_ptr->size();
VLOG(10) << "Reuse temporary allocation: " << tmp_ptr->ptr() << ": "
<< tmp_ptr->size();
return tmp_ptr;
}
}
}
// If no available allocation is found, get a new one from
// AllocatorFacade::Instance().
auto temp_mem = alloc::AllocatorFacade::Instance().Alloc(place_, size);
VLOG(10) << "Alloc temporary allocation: " << temp_mem->ptr() << ": " << size;
return temp_mem.release();
}
} // namespace platform
} // namespace paddle
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/platform/temporary_allocator.h"
#include <gtest/gtest.h>
#include <string>
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/framework/tensor_util.h"
DECLARE_int64(limit_of_tmp_allocation);
DECLARE_double(times_excess_than_required_tmp_allocation);
namespace paddle {
namespace platform {
class DummyOp : public framework::OperatorBase {
public:
DummyOp(const std::string& type, const framework::VariableNameMap& inputs,
const framework::VariableNameMap& outputs,
const framework::AttributeMap& attrs)
: OperatorBase(type, inputs, outputs, attrs) {}
protected:
void RunImpl(const framework::Scope& scope,
const platform::Place& place) const override {}
};
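// Checks the basic lifecycle: a GPU allocation released inside a scope is
// queued for deferred deletion, and Release() drains the queue.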
TEST(temporary_allocator, test_base_function) {
platform::CPUPlace cpu_place;
TemporaryAllocator alloc(cpu_place);
alloc.Allocate(100);
#ifdef PADDLE_WITH_CUDA
platform::CUDAPlace gpu_place(0);
TemporaryAllocator gpu_alloc(gpu_place);
auto allocation = gpu_alloc.Allocate(101);
PADDLE_ENFORCE_EQ(gpu_alloc.TemporaryAllocationQueueSize(), 0);
gpu_alloc.Release([]() {});
PADDLE_ENFORCE_EQ(gpu_alloc.TemporaryAllocationQueueSize(), 0);
{
auto allocation = gpu_alloc.Allocate(102);
PADDLE_ENFORCE_EQ(gpu_alloc.TemporaryAllocationQueueSize(), 0);
}
PADDLE_ENFORCE_EQ(gpu_alloc.TemporaryAllocationQueueSize(), 1);
gpu_alloc.Release([]() {});
PADDLE_ENFORCE_EQ(gpu_alloc.TemporaryAllocationQueueSize(), 0);
#endif
}
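// Checks that FLAGS_limit_of_tmp_allocation triggers the user-supplied
// release callback once the queued size exceeds the limit.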
TEST(temporary_allocator, test_flags_function) {
#ifdef PADDLE_WITH_CUDA
const int64_t limit = FLAGS_limit_of_tmp_allocation;
FLAGS_limit_of_tmp_allocation = 10;
platform::CUDAPlace gpu_place(0);
TemporaryAllocator gpu_alloc(gpu_place);
platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance();
auto* dev_ctx =
static_cast<platform::CUDADeviceContext*>(pool.Get(gpu_place));
auto stream = dev_ctx->stream();
bool deleted = false;
gpu_alloc.SetCallback([stream, &deleted]() {
PADDLE_ENFORCE(cudaStreamSynchronize(stream));
PADDLE_ENFORCE(cudaGetLastError());
deleted = true;
});
{ gpu_alloc.Allocate(100); }
PADDLE_ENFORCE(deleted);
FLAGS_limit_of_tmp_allocation = limit;
#endif
}
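// Checks that a queued temporary allocation is reused by a subsequent
// request of the same size instead of allocating new memory.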
TEST(temporary_allocator, test_reuse_tmp_allocation) {
#ifdef PADDLE_WITH_CUDA
platform::CUDAPlace gpu_place(0);
TemporaryAllocator gpu_alloc(gpu_place);
gpu_alloc.SetCallback([]() {});
void* tmp_allocation_ptr1 = nullptr;
{
PADDLE_ENFORCE_EQ(gpu_alloc.TemporaryAllocationQueueSize(), 0);
auto tmp_allocation1 = gpu_alloc.Allocate(200);
tmp_allocation_ptr1 = tmp_allocation1->ptr();
}
PADDLE_ENFORCE_EQ(gpu_alloc.TemporaryAllocationQueueSize(), 1);
auto tmp_allocation2 = gpu_alloc.Allocate(200);
void* tmp_allocation_ptr2 = tmp_allocation2->ptr();
PADDLE_ENFORCE_EQ(gpu_alloc.TemporaryAllocationQueueSize(), 0);
PADDLE_ENFORCE_EQ(tmp_allocation_ptr1, tmp_allocation_ptr2);
auto tmp_allocation3 = gpu_alloc.Allocate(200);
void* tmp_allocation_ptr3 = tmp_allocation2->ptr();
PADDLE_ENFORCE_EQ(tmp_allocation_ptr1, tmp_allocation_ptr3);
#endif
}
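// Checks that a queued chunk whose size falls within the excess window
// (request <= size < request * FLAGS_times_excess_than_required_tmp_allocation)
// is reused.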
TEST(temporary_allocator, test_times_excess_than_required_tmp_allocation) {
#ifdef PADDLE_WITH_CUDA
platform::CUDAPlace gpu_place(0);
TemporaryAllocator gpu_alloc(gpu_place);
gpu_alloc.SetCallback([]() {});
double excess_fraction = FLAGS_times_excess_than_required_tmp_allocation;
void* tmp_allocation_ptr1 = nullptr;
{
PADDLE_ENFORCE_EQ(gpu_alloc.TemporaryAllocationQueueSize(), 0);
auto tmp_allocation1 =
gpu_alloc.Allocate(static_cast<size_t>(200 * excess_fraction - 1));
tmp_allocation_ptr1 = tmp_allocation1->ptr();
}
PADDLE_ENFORCE_EQ(gpu_alloc.TemporaryAllocationQueueSize(), 1);
auto tmp_allocation2 = gpu_alloc.Allocate(200 * excess_fraction - 10);
void* tmp_allocation_ptr2 = tmp_allocation2->ptr();
PADDLE_ENFORCE_EQ(gpu_alloc.TemporaryAllocationQueueSize(), 0);
PADDLE_ENFORCE_EQ(tmp_allocation_ptr1, tmp_allocation_ptr2);
#endif
}
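// Checks that ExecutionContext::AllocateTmpTensor returns a tensor with the
// requested number of elements on both CPU and CUDA places.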
TEST(temporary_allocator, create_tensor_with_allocationptr) {
framework::VariableNameMap dummy_vars;
framework::AttributeMap dummy_attrs;
DummyOp op("dummy", dummy_vars, dummy_vars, dummy_attrs);
framework::Scope scope;
framework::VariableValueMap vars;
framework::RuntimeContext run_ctx(vars, vars);
size_t memory_size = 300;
{
platform::CPUPlace cpu_place;
platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance();
auto* dev_ctx =
static_cast<platform::CPUDeviceContext*>(pool.Get(cpu_place));
framework::ExecutionContext ctx(op, scope, *dev_ctx, run_ctx, nullptr);
int numel = memory_size / sizeof(float);
framework::Tensor tensor =
ctx.AllocateTmpTensor<float, platform::CPUDeviceContext>(
framework::make_ddim({numel}), *dev_ctx);
PADDLE_ENFORCE_EQ(tensor.numel(), numel);
}
#ifdef PADDLE_WITH_CUDA
{
platform::CUDAPlace gpu_place(0);
platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance();
auto* dev_ctx =
static_cast<platform::CUDADeviceContext*>(pool.Get(gpu_place));
framework::ExecutionContext ctx(op, scope, *dev_ctx, run_ctx, nullptr);
int numel = memory_size / sizeof(float);
framework::Tensor tensor =
ctx.AllocateTmpTensor<float, platform::CUDADeviceContext>(
framework::make_ddim({numel}), *dev_ctx);
PADDLE_ENFORCE_EQ(tensor.numel(), numel);
}
#endif
}
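// Checks that a temporary tensor shared via ShareDataWith still reports the
// expected shape after the original tensor goes out of scope.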
TEST(temporary_allocator, create_tensor_with_allocationptr2) {
framework::VariableNameMap dummy_vars;
framework::AttributeMap dummy_attrs;
DummyOp op("dummy", dummy_vars, dummy_vars, dummy_attrs);
framework::Scope scope;
framework::VariableValueMap vars;
framework::RuntimeContext run_ctx(vars, vars);
size_t memory_size = 400;
{
platform::CPUPlace cpu_place;
platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance();
auto* dev_ctx =
static_cast<platform::CPUDeviceContext*>(pool.Get(cpu_place));
framework::ExecutionContext ctx(op, scope, *dev_ctx, run_ctx, nullptr);
int numel = memory_size / sizeof(float);
framework::Tensor out_side_tensor;
{
framework::Tensor tensor =
ctx.AllocateTmpTensor<float, platform::CPUDeviceContext>(
framework::make_ddim({numel}), *dev_ctx);
PADDLE_ENFORCE_EQ(tensor.numel(), numel);
out_side_tensor.ShareDataWith(tensor);
}
PADDLE_ENFORCE_EQ(out_side_tensor.numel(), numel);
}
#ifdef PADDLE_WITH_CUDA
{
platform::CUDAPlace gpu_place(0);
platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance();
auto* dev_ctx =
static_cast<platform::CUDADeviceContext*>(pool.Get(gpu_place));
framework::ExecutionContext ctx(op, scope, *dev_ctx, run_ctx, nullptr);
size_t memory_size = 500;
int numel = memory_size / sizeof(float);
framework::Tensor out_side_tensor;
{
framework::Tensor tensor =
ctx.AllocateTmpTensor<float, platform::CUDADeviceContext>(
framework::make_ddim({numel}), *dev_ctx);
PADDLE_ENFORCE_EQ(tensor.numel(), numel);
out_side_tensor.ShareDataWith(tensor);
}
PADDLE_ENFORCE_EQ(out_side_tensor.numel(), numel);
}
#endif
}
} // namespace platform
} // namespace paddle
......@@ -57,6 +57,7 @@ int main(int argc, char** argv) {
envs.push_back("initial_cpu_memory_in_mb");
envs.push_back("allocator_strategy");
undefok.push_back("use_pinned_memory");
undefok.push_back("use_mkldnn");
undefok.push_back("initial_cpu_memory_in_mb");
#endif
......
......@@ -202,8 +202,6 @@ def __bootstrap__():
'reallocate_gpu_memory_in_mb', 'cudnn_deterministic',
'enable_cublas_tensor_op_math', 'conv_workspace_size_limit',
'cudnn_exhaustive_search', 'selected_gpus', 'sync_nccl_allreduce',
'limit_of_tmp_allocation',
'times_excess_than_required_tmp_allocation',
'cudnn_batchnorm_spatial_persistent', 'gpu_allocator_retry_time'
]
core.init_gflags([sys.argv[0]] +
......