Commit 524f6e9b authored by Yu Yang

Refine code

Parent 5cf395be
@@ -2,7 +2,7 @@ cc_library(allocator SRCS allocator.cc DEPS place)
cc_library(cpu_allocator SRCS cpu_allocator.cc DEPS allocator)
cc_library(best_fit_allocator SRCS best_fit_allocator.cc DEPS allocator)
cc_library(locked_allocator SRCS locked_allocator.cc DEPS allocator)
nv_library(cuda_allocator SRCS cuda_allocator.cc DEPS allocator gpu_info)
nv_library(cuda_allocator SRCS cuda_allocator.cc DEPS allocator cuda_device_guard)
if (WITH_GPU)
nv_test(best_fit_allocator_test
@@ -40,4 +40,5 @@ cc_library(allocator_facade SRCS allocator_facade.cc DEPS
locked_allocator
best_fit_allocator
naive_managed_allocator
aligned_allocator)
aligned_allocator
cuda_device_guard)
@@ -21,6 +21,7 @@
#include "paddle/fluid/memory/allocation/cpu_allocator.h"
#include "paddle/fluid/memory/allocation/locked_allocator.h"
#include "paddle/fluid/memory/allocation/naive_managed_allocator.h"
#include "paddle/fluid/platform/cuda_device_guard.h"
#include "paddle/fluid/platform/gpu_info.h"
#include "paddle/fluid/platform/place.h"
#ifdef PADDLE_WITH_CUDA
@@ -45,6 +46,7 @@ class AllocatorFacadePrivate {
}
AllocatorFacadePrivate() {
std::cout << "Init Allocator Facade" << std::endl;
InitCPUAllocator();
InitCUDAAllocator();
}
@@ -60,10 +62,10 @@ class AllocatorFacadePrivate {
void InitCUDAAllocator() {
#ifdef PADDLE_WITH_CUDA
for (int dev_id = 0; dev_id < platform::GetCUDADeviceCount(); ++dev_id) {
platform::CUDADeviceGuard guard(dev_id);
auto cuda_allocator =
NaiveManagedAllocator::Create(std::unique_ptr<Allocator>(
new CUDAAllocator(platform::CUDAPlace(dev_id))));
auto allocation = cuda_allocator->Allocate(platform::GpuMaxChunkSize());
auto allocator = NaiveManagedAllocator::Create(std::unique_ptr<Allocator>(
new LockedAllocator(std::unique_ptr<Allocator>(
......
@@ -16,34 +16,14 @@
#include <cuda.h>
#include <cuda_runtime.h>
#include <string>
#include "paddle/fluid/platform/cuda_device_guard.h"
#include "paddle/fluid/platform/gpu_info.h"
namespace paddle {
namespace memory {
namespace allocation {
class CUDADeviceGuard {
public:
explicit CUDADeviceGuard(int dev_id) {
int prev_id = platform::GetCurrentDeviceId();
if (prev_id != dev_id) {
prev_id_ = prev_id;
platform::SetDeviceId(dev_id);
}
}
~CUDADeviceGuard() {
if (prev_id_ != -1) {
platform::SetDeviceId(prev_id_);
}
}
private:
int prev_id_{-1};
};
std::unique_ptr<Allocation> CUDAAllocator::Allocate(size_t size, Attr attr) {
CUDADeviceGuard guard(place_.device);
platform::CUDADeviceGuard guard(place_.device);
void* ptr;
auto status = cudaMalloc(&ptr, size);
if (UNLIKELY(status != cudaSuccess)) {
@@ -57,6 +37,7 @@ std::unique_ptr<Allocation> CUDAAllocator::Allocate(size_t size, Attr attr) {
}
void CUDAAllocator::Free(Allocation* allocation) {
platform::CUDADeviceGuard guard(place_.device);
auto* cuda_allocation = dynamic_cast<CUDAAllocation*>(allocation);
PADDLE_ENFORCE_NOT_NULL(cuda_allocation);
PADDLE_ENFORCE_EQ(boost::get<platform::CUDAPlace>(cuda_allocation->place()),
......
@@ -12,10 +12,10 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/math/selected_rows_functor.h"
#include <vector>
#include "gtest/gtest.h"
#include "paddle/fluid/operators/math/math_function.h"
#include "paddle/fluid/operators/math/selected_rows_functor.h"
TEST(selected_rows_functor, gpu_add) {
paddle::platform::CUDAPlace gpu_place(0);
@@ -38,6 +38,7 @@ TEST(selected_rows_functor, gpu_add) {
{static_cast<int64_t>(rows1.size()), row_numel}),
gpu_place);
functor(ctx, in1_value, 1.0);
PADDLE_ENFORCE(cudaDeviceSynchronize());
std::vector<int64_t> rows2{0, 5, 7, 9};
std::unique_ptr<paddle::framework::SelectedRows> selected_rows2{
......
@@ -73,3 +73,4 @@ cc_test(float16_test SRCS float16_test.cc DEPS lod_tensor)
IF(WITH_GPU)
nv_test(cuda_helper_test SRCS cuda_helper_test.cu)
ENDIF()
nv_library(cuda_device_guard SRCS cuda_device_guard.cc DEPS gpu_info)
paddle/fluid/platform/cuda_device_guard.cc (new file)
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/platform/cuda_device_guard.h"
namespace paddle {
namespace platform {
// Even though this source file does not contain any code, it is better to keep
// it around so that cmake has a source file for this dependency.
} // namespace platform
} // namespace paddle
paddle/fluid/platform/cuda_device_guard.h (new file)
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/fluid/platform/gpu_info.h"
namespace paddle {
namespace platform {
class CUDADeviceGuard {
public:
explicit inline CUDADeviceGuard(int dev_id) {
int prev_id = platform::GetCurrentDeviceId();
if (prev_id != dev_id) {
prev_id_ = prev_id;
platform::SetDeviceId(dev_id);
}
}
inline ~CUDADeviceGuard() {
if (prev_id_ != -1) {
platform::SetDeviceId(prev_id_);
}
}
CUDADeviceGuard(const CUDADeviceGuard& o) = delete;
CUDADeviceGuard& operator=(const CUDADeviceGuard& o) = delete;
private:
int prev_id_{-1};
};
} // namespace platform
} // namespace paddle
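The header above is self-contained, so a minimal usage sketch reads as follows (not part of this commit; VisitAllDevices is a hypothetical function used only for illustration):

#include "paddle/fluid/platform/cuda_device_guard.h"
#include "paddle/fluid/platform/gpu_info.h"

// Iterate over every visible CUDA device. The guard switches the current
// device on construction and restores the previous device id on destruction.
void VisitAllDevices() {
  for (int dev_id = 0; dev_id < paddle::platform::GetCUDADeviceCount(); ++dev_id) {
    paddle::platform::CUDADeviceGuard guard(dev_id);
    // ... per-device work goes here, e.g. the cudaMalloc call in
    // CUDAAllocator::Allocate above ...
  }  // guard destroyed: the previous device id is restored if it was changed
}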