提交 0ac43217 编写于 作者: J JiayiFeng

check whether scalar condition var is on CPU before using

上级 01c5ca73
...@@ -54,7 +54,18 @@ class ConditionalOp : public framework::OperatorBase { ...@@ -54,7 +54,18 @@ class ConditionalOp : public framework::OperatorBase {
"numel should be 1, actual numel is %d", "numel should be 1, actual numel is %d",
ips[0]->numel()); ips[0]->numel());
} }
return ips[0]->data<bool>()[0]; bool res;
if (platform::is_gpu_place(ips[0]->place())) {
#ifdef PADDLE_WITH_CUDA
framework::LoDTensor cpu_tensor;
framework::TensorCopy(*ips[0], platform::CPUPlace(), &cpu_tensor);
platform::DeviceContextPool::Instance().Get(ips[0]->place())->Wait();
res = cpu_tensor.data<bool>()[0];
#endif
} else {
res = ips[0]->data<bool>()[0];
}
return res;
} }
}; };
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册