提交 fe10e86d 编写于 作者: Q qijun

fix gpu build error

上级 cb198fa7
......@@ -27,12 +27,12 @@ Executor::Executor(const std::vector<platform::Place>& places) {
device_contexts_.resize(places.size());
for (size_t i = 0; i < places.size(); i++) {
if (platform::is_cpu_place(places[i])) {
device_contexts_[i].reset(new platform::CPUDeviceContext(
boost::get<platform::CPUPlace>(places[i])));
} else {
device_contexts_[i] = new platform::CPUDeviceContext(
boost::get<platform::CPUPlace>(places[i]));
} else if (platform::is_gpu_place(places[i])) {
#ifndef PADDLE_ONLY_CPU
device_contexts_[i].reset(new platform::CUDADeviceContext(
boost::get<platform::CPUPlace>(places[i])));
device_contexts_[i] = new platform::CUDADeviceContext(
boost::get<platform::GPUPlace>(places[i]));
#else
PADDLE_THROW("'GPUPlace' is not supported in CPU only device.");
#endif
......@@ -40,6 +40,14 @@ Executor::Executor(const std::vector<platform::Place>& places) {
}
}
Executor::~Executor() {
for (auto& device_context : device_contexts_) {
if (device_context) {
delete device_context;
}
}
}
void Executor::Run(const ProgramDesc& pdesc, Scope* scope,
std::vector<Tensor>* outputs) {
// TODO(tonyyang-svail):
......@@ -59,6 +67,7 @@ void Executor::Run(const ProgramDesc& pdesc, Scope* scope,
for (auto& op_desc : block.ops()) {
auto op = paddle::framework::OpRegistry::CreateOp(op_desc);
std::cout << op->DebugString() << std::endl;
op->Run(*scope, *device);
}
......
......@@ -25,11 +25,11 @@ namespace framework {
class Executor {
public:
explicit Executor(const std::vector<platform::Place>& places);
~Executor() {}
~Executor();
void Run(const ProgramDesc&, Scope*, std::vector<Tensor>*);
private:
std::vector<std::unique_ptr<platform::DeviceContext>> device_contexts_;
std::vector<platform::DeviceContext*> device_contexts_;
};
} // namespace framework
......
......@@ -43,7 +43,7 @@ int GetCurrentDeviceId() {
}
void SetDeviceId(int id) {
PADDLE_ENFORCE(id < GetDeviceCount(), "id must less than GPU count")
PADDLE_ENFORCE(id < GetDeviceCount(), "id must less than GPU count");
PADDLE_ENFORCE(cudaSetDevice(id),
"cudaSetDevice failed in paddle::platform::SetDeviceId");
}
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册