提交 39505151 编写于 作者: Q qijun

remove device context manager

上级 6c4d1f55
...@@ -44,7 +44,7 @@ add_custom_command(TARGET framework_py_proto POST_BUILD ...@@ -44,7 +44,7 @@ add_custom_command(TARGET framework_py_proto POST_BUILD
cc_library(backward SRCS backward.cc DEPS net_op) cc_library(backward SRCS backward.cc DEPS net_op)
cc_test(backward_test SRCS backward_test.cc DEPS backward recurrent_op device_context) cc_test(backward_test SRCS backward_test.cc DEPS backward recurrent_op device_context)
cc_library(executor SRCS executor.cc DEPS op_registry device_context_manager scope framework_proto ${GLOB_OP_LIB}) cc_library(executor SRCS executor.cc DEPS op_registry device_context scope framework_proto ${GLOB_OP_LIB})
if(WITH_GPU) if(WITH_GPU)
nv_test(executor_test SRCS executor_test.cc DEPS executor) nv_test(executor_test SRCS executor_test.cc DEPS executor)
else() else()
......
...@@ -25,14 +25,12 @@ Executor::Executor(const std::vector<platform::Place>& places) { ...@@ -25,14 +25,12 @@ Executor::Executor(const std::vector<platform::Place>& places) {
device_contexts_.resize(places.size()); device_contexts_.resize(places.size());
for (size_t i = 0; i < places.size(); i++) { for (size_t i = 0; i < places.size(); i++) {
if (platform::is_cpu_place(places[i])) { if (platform::is_cpu_place(places[i])) {
device_contexts_[i] = platform::DeviceContextManager::Get() device_contexts_[i].reset(new platform::CPUDeviceContext(
->GetDeviceContext<platform::CPUPlace>( boost::get<platform::CPUPlace>(places[i])));
boost::get<platform::CPUPlace>(places[i]));
} else { } else {
#ifndef PADDLE_ONLY_CPU #ifndef PADDLE_ONLY_CPU
device_contexts_[i] = platform::DeviceContextManager::Get() device_contexts_[i].reset(new platform::CUDADeviceContext(
->GetDeviceContext<platform::GPUPlace>( boost::get<platform::CPUPlace>(places[i])));
boost::get<platform::GPUPlace>(places[i]));
#else #else
PADDLE_THROW("'GPUPlace' is not supported in CPU only device."); PADDLE_THROW("'GPUPlace' is not supported in CPU only device.");
#endif #endif
...@@ -63,7 +61,7 @@ void Executor::Run(const ProgramDesc& pdesc, Scope* scope, ...@@ -63,7 +61,7 @@ void Executor::Run(const ProgramDesc& pdesc, Scope* scope,
} }
// TODO(tonyyang-svail): need to test gpu device // TODO(tonyyang-svail): need to test gpu device
for (auto device_context : device_contexts_) { for (auto& device_context : device_contexts_) {
device_context->Wait(); device_context->Wait();
} }
} }
......
...@@ -18,7 +18,6 @@ limitations under the License. */ ...@@ -18,7 +18,6 @@ limitations under the License. */
#include "paddle/framework/op_info.h" #include "paddle/framework/op_info.h"
#include "paddle/framework/scope.h" #include "paddle/framework/scope.h"
#include "paddle/framework/tensor.h" #include "paddle/framework/tensor.h"
#include "paddle/platform/device_context_manager.h"
namespace paddle { namespace paddle {
namespace framework { namespace framework {
...@@ -30,7 +29,7 @@ class Executor { ...@@ -30,7 +29,7 @@ class Executor {
void Run(const ProgramDesc&, Scope*, std::vector<Tensor>*); void Run(const ProgramDesc&, Scope*, std::vector<Tensor>*);
private: private:
std::vector<platform::DeviceContext*> device_contexts_; std::vector<std::unique_ptr<platform::DeviceContext>> device_contexts_;
}; };
} // namespace framework } // namespace framework
......
...@@ -23,7 +23,5 @@ cc_library(device_context SRCS device_context.cc DEPS memory buddy_allocator ...@@ -23,7 +23,5 @@ cc_library(device_context SRCS device_context.cc DEPS memory buddy_allocator
system_allocator memory_block meta_data meta_cache place eigen3 ${GPU_CTX_DEPS}) system_allocator memory_block meta_data meta_cache place eigen3 ${GPU_CTX_DEPS})
nv_test(device_context_test SRCS device_context_test.cc DEPS device_context gpu_info) nv_test(device_context_test SRCS device_context_test.cc DEPS device_context gpu_info)
cc_library(device_context_manager SRCS device_context_manager.cc DEPS device_context)
nv_test(cudnn_helper_test SRCS cudnn_helper_test.cc DEPS dynload_cuda) nv_test(cudnn_helper_test SRCS cudnn_helper_test.cc DEPS dynload_cuda)
nv_test(transform_test SRCS transform_test.cu DEPS paddle_memory place device_context) nv_test(transform_test SRCS transform_test.cu DEPS paddle_memory place device_context)
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/platform/device_context_manager.h"
namespace paddle {
namespace platform {
DeviceContextManager::DeviceContextManager() {
#ifndef PADDLE_ONLY_CPU
device_count_ = GetDeviceCount();
cuda_contexts_.reserve(device_count_);
for (int i = 0; i < device_count_; i++) {
cuda_contexts_[i] = nullptr;
}
#endif
}
template <>
CPUDeviceContext* DeviceContextManager::GetDeviceContext<
CPUPlace, CPUDeviceContext>(const CPUPlace& place) {
if (!cpu_context_) {
cpu_context_ = new CPUDeviceContext(place);
}
return cpu_context_;
}
#ifndef PADDLE_ONLY_CPU
template <>
CUDADeviceContext* DeviceContextManager::GetDeviceContext<
GPUPlace, CUDADeviceContext>(const GPUPlace& place) {
int gpu_id = place.device;
PADDLE_ENFORCE(gpu_id < device_count_,
"GPU device id must less than device count");
SetDeviceId(gpu_id);
if (!cuda_contexts_[gpu_id]) {
cuda_contexts_[gpu_id] = new CUDADeviceContext(place);
}
return cuda_contexts_[gpu_id];
}
#endif
DeviceContextManager::~DeviceContextManager() {
if (cpu_context_) {
delete cpu_context_;
}
#ifndef PADDLE_ONLY_CPU
for (int i = 0; i < device_count_; i++) {
if (cuda_contexts_[i]) {
delete cuda_contexts_[i];
}
}
#endif
}
} // namespace platform
} // namespace paddle
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "paddle/platform/device_context.h"
namespace paddle {
namespace platform {
template <typename T>
struct Converter;
template <>
struct Converter<CPUPlace> {
using DeviceContextType = CPUDeviceContext;
};
#ifndef PADDLE_ONLY_CPU
template <>
struct Converter<GPUPlace> {
using DeviceContextType = CUDADeviceContext;
};
#endif
class DeviceContextManager {
public:
DeviceContextManager();
~DeviceContextManager();
template <typename PlaceType, typename DeviceType = typename Converter<
PlaceType>::DeviceContextType>
DeviceType* GetDeviceContext(const PlaceType& place);
static DeviceContextManager* Get() {
static DeviceContextManager inst;
return &inst;
}
private:
CPUDeviceContext* cpu_context_;
#ifndef PADDLE_ONLY_CPU
int device_count_;
std::vector<CUDADeviceContext*> cuda_contexts_;
#endif
};
} // namespace platform
} // namespace paddle
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册