提交 dad7bdab 编写于 作者: Y Yu Yang

Add setDev

上级 7fd0d24e
......@@ -149,6 +149,7 @@ struct ScaleLossGradOpHandle : public OpHandle {
auto stream =
static_cast<platform::CUDADeviceContext *>(this->dev_ctx_[place_])
->stream();
cudaSetDevice(boost::get<platform::CUDAPlace>(place_).device);
VLOG(3) << "1";
PADDLE_ENFORCE(cudaGetLastError());
VLOG(3) << "2";
......@@ -163,7 +164,7 @@ struct ScaleLossGradOpHandle : public OpHandle {
void Wait(platform::DeviceContext *waited_dev) override {
if (platform::is_cpu_place(waited_dev->GetPlace())) {
this->dev_ctx_.at(place_)->Wait();
dev_ctx_.at(place_)->Wait();
} else {
auto stream =
static_cast<platform::CUDADeviceContext *>(waited_dev)->stream();
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册