未验证 提交 2fac8abb 编写于 作者: L Leo Chen 提交者: GitHub

set device id before op run (#45994)

上级 925e84bf
...@@ -29,6 +29,7 @@ ...@@ -29,6 +29,7 @@
#ifdef PADDLE_WITH_MKLDNN #ifdef PADDLE_WITH_MKLDNN
#include "paddle/fluid/platform/mkldnn_helper.h" #include "paddle/fluid/platform/mkldnn_helper.h"
#endif #endif
#include "paddle/phi/backends/device_manager.h"
PADDLE_DEFINE_EXPORTED_bool(new_executor_use_inplace, PADDLE_DEFINE_EXPORTED_bool(new_executor_use_inplace,
false, false,
...@@ -595,16 +596,61 @@ void InterpreterCore::BuildSkipShareLoDInfo() { ...@@ -595,16 +596,61 @@ void InterpreterCore::BuildSkipShareLoDInfo() {
} }
} }
// Makes the device referenced by `place` the current one by dispatching to
// the setter of the matching backend (GPU, XPU, NPU or custom device).
//
// When `place` names a backend that this binary was built without, an
// Unavailable error is raised through PADDLE_THROW. A place that matches none
// of the four device kinds checked here is left untouched.
inline void SetDeviceId(const platform::Place& place) {
  // TODO(zhiqiu): reduce the cost
  if (platform::is_gpu_place(place)) {
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
    platform::SetDeviceId(place.device);
#else
    PADDLE_THROW(platform::errors::Unavailable(
        "Cannot run operator on place %s, please recompile paddle or "
        "reinstall Paddle with CUDA support.",
        place));
#endif
    return;
  }

  if (platform::is_xpu_place(place)) {
#ifdef PADDLE_WITH_XPU
    platform::SetXPUDeviceId(place.device);
#else
    PADDLE_THROW(platform::errors::Unavailable(
        "Cannot run operator on place %s, please recompile paddle or "
        "reinstall Paddle with XPU support.",
        place));
#endif
    return;
  }

  if (platform::is_npu_place(place)) {
#ifdef PADDLE_WITH_ASCEND_CL
    platform::SetNPUDeviceId(place.device);
#else
    PADDLE_THROW(platform::errors::Unavailable(
        "Cannot run operator on place %s, please recompile paddle or "
        "reinstall Paddle with NPU support.",
        place));
#endif
    return;
  }

  if (platform::is_custom_place(place)) {
#ifdef PADDLE_WITH_CUSTOM_DEVICE
    // Custom devices are selected through the device manager, which takes
    // the whole place rather than a bare device index.
    phi::DeviceManager::SetDevice(place);
#else
    PADDLE_THROW(platform::errors::Unavailable(
        "Cannot run operator on place %s, please recompile paddle or "
        "reinstall Paddle with CustomDevice support.",
        place));
#endif
    return;
  }
}
void InterpreterCore::RunInstruction(const Instruction& instr_node) { void InterpreterCore::RunInstruction(const Instruction& instr_node) {
auto* op = instr_node.OpBase(); auto* op = instr_node.OpBase();
auto place = instr_node.DeviceContext().GetPlace(); auto place = instr_node.DeviceContext().GetPlace();
Scope* local_scope = create_local_scope_ ? var_scope_.GetMutableLocalScope() Scope* local_scope = create_local_scope_ ? var_scope_.GetMutableLocalScope()
: var_scope_.GetMutableScope(); : var_scope_.GetMutableScope();
VLOG(4) << "Start run " << place << " " << op->DebugStringEx(local_scope_);
SetDeviceId(place);
#ifdef PADDLE_WITH_ASCEND_CL #ifdef PADDLE_WITH_ASCEND_CL
if (platform::is_npu_place(place)) { if (platform::is_npu_place(place)) {
auto dev_id = place.device;
platform::SetNPUDeviceId(dev_id);
// NOTE(wangxi): nan/inf cannot be detected on NPU by checking the variable // NOTE(wangxi): nan/inf cannot be detected on NPU by checking the variable
// values, but only through special `float_status` to checks whether // values, but only through special `float_status` to checks whether
// the operation is overflow. More about `float_status`, see: // the operation is overflow. More about `float_status`, see:
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册