提交 fe10e86d 编写于 作者: Q qijun

fix gpu build error

上级 cb198fa7
...@@ -27,12 +27,12 @@ Executor::Executor(const std::vector<platform::Place>& places) { ...@@ -27,12 +27,12 @@ Executor::Executor(const std::vector<platform::Place>& places) {
device_contexts_.resize(places.size()); device_contexts_.resize(places.size());
for (size_t i = 0; i < places.size(); i++) { for (size_t i = 0; i < places.size(); i++) {
if (platform::is_cpu_place(places[i])) { if (platform::is_cpu_place(places[i])) {
device_contexts_[i].reset(new platform::CPUDeviceContext( device_contexts_[i] = new platform::CPUDeviceContext(
boost::get<platform::CPUPlace>(places[i]))); boost::get<platform::CPUPlace>(places[i]));
} else { } else if (platform::is_gpu_place(places[i])) {
#ifndef PADDLE_ONLY_CPU #ifndef PADDLE_ONLY_CPU
device_contexts_[i].reset(new platform::CUDADeviceContext( device_contexts_[i] = new platform::CUDADeviceContext(
boost::get<platform::CPUPlace>(places[i]))); boost::get<platform::GPUPlace>(places[i]));
#else #else
PADDLE_THROW("'GPUPlace' is not supported in CPU only device."); PADDLE_THROW("'GPUPlace' is not supported in CPU only device.");
#endif #endif
...@@ -40,6 +40,14 @@ Executor::Executor(const std::vector<platform::Place>& places) { ...@@ -40,6 +40,14 @@ Executor::Executor(const std::vector<platform::Place>& places) {
} }
} }
Executor::~Executor() {
for (auto& device_context : device_contexts_) {
if (device_context) {
delete device_context;
}
}
}
void Executor::Run(const ProgramDesc& pdesc, Scope* scope, void Executor::Run(const ProgramDesc& pdesc, Scope* scope,
std::vector<Tensor>* outputs) { std::vector<Tensor>* outputs) {
// TODO(tonyyang-svail): // TODO(tonyyang-svail):
...@@ -59,6 +67,7 @@ void Executor::Run(const ProgramDesc& pdesc, Scope* scope, ...@@ -59,6 +67,7 @@ void Executor::Run(const ProgramDesc& pdesc, Scope* scope,
for (auto& op_desc : block.ops()) { for (auto& op_desc : block.ops()) {
auto op = paddle::framework::OpRegistry::CreateOp(op_desc); auto op = paddle::framework::OpRegistry::CreateOp(op_desc);
std::cout << op->DebugString() << std::endl;
op->Run(*scope, *device); op->Run(*scope, *device);
} }
......
...@@ -25,11 +25,11 @@ namespace framework { ...@@ -25,11 +25,11 @@ namespace framework {
class Executor { class Executor {
public: public:
explicit Executor(const std::vector<platform::Place>& places); explicit Executor(const std::vector<platform::Place>& places);
~Executor() {} ~Executor();
void Run(const ProgramDesc&, Scope*, std::vector<Tensor>*); void Run(const ProgramDesc&, Scope*, std::vector<Tensor>*);
private: private:
std::vector<std::unique_ptr<platform::DeviceContext>> device_contexts_; std::vector<platform::DeviceContext*> device_contexts_;
}; };
} // namespace framework } // namespace framework
......
...@@ -43,7 +43,7 @@ int GetCurrentDeviceId() { ...@@ -43,7 +43,7 @@ int GetCurrentDeviceId() {
} }
void SetDeviceId(int id) { void SetDeviceId(int id) {
PADDLE_ENFORCE(id < GetDeviceCount(), "id must less than GPU count") PADDLE_ENFORCE(id < GetDeviceCount(), "id must less than GPU count");
PADDLE_ENFORCE(cudaSetDevice(id), PADDLE_ENFORCE(cudaSetDevice(id),
"cudaSetDevice failed in paddle::platform::SetDeviceId"); "cudaSetDevice failed in paddle::platform::SetDeviceId");
} }
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册