diff --git a/paddle/fluid/framework/new_executor/interpretercore.cc b/paddle/fluid/framework/new_executor/interpretercore.cc index 59285dd6e61ea1a1d91671a1ef5c856c68897e6a..0a0170110de2aed4bcefa90cab6ceb464f4c1ee6 100644 --- a/paddle/fluid/framework/new_executor/interpretercore.cc +++ b/paddle/fluid/framework/new_executor/interpretercore.cc @@ -29,6 +29,7 @@ #ifdef PADDLE_WITH_MKLDNN #include "paddle/fluid/platform/mkldnn_helper.h" #endif +#include "paddle/phi/backends/device_manager.h" PADDLE_DEFINE_EXPORTED_bool(new_executor_use_inplace, false, @@ -596,16 +597,61 @@ void InterpreterCore::BuildSkipShareLoDInfo() { } } +inline void SetDeviceId(const platform::Place& place) { + // TODO(zhiqiu): reduce the cost + if (platform::is_gpu_place(place)) { +#if !defined(PADDLE_WITH_CUDA) && !defined(PADDLE_WITH_HIP) + PADDLE_THROW(platform::errors::Unavailable( + "Cannot run operator on place %s, please recompile paddle or " + "reinstall Paddle with CUDA support.", + place)); +#else + auto dev_id = place.device; + platform::SetDeviceId(dev_id); +#endif + } else if (platform::is_xpu_place(place)) { +#ifndef PADDLE_WITH_XPU + PADDLE_THROW(platform::errors::Unavailable( + "Cannot run operator on place %s, please recompile paddle or " + "reinstall Paddle with XPU support.", + place)); +#else + auto dev_id = place.device; + platform::SetXPUDeviceId(dev_id); +#endif + } else if (platform::is_npu_place(place)) { +#ifndef PADDLE_WITH_ASCEND_CL + PADDLE_THROW(platform::errors::Unavailable( + "Cannot run operator on place %s, please recompile paddle or " + "reinstall Paddle with NPU support.", + place)); +#else + auto dev_id = place.device; + platform::SetNPUDeviceId(dev_id); +#endif + } else if (platform::is_custom_place(place)) { +#ifndef PADDLE_WITH_CUSTOM_DEVICE + PADDLE_THROW(platform::errors::Unavailable( + "Cannot run operator on place %s, please recompile paddle or " + "reinstall Paddle with CustomDevice support.", + place)); +#else + phi::DeviceManager::SetDevice(place); +#endif + } +} + void InterpreterCore::RunInstruction(const Instruction& instr_node) { auto* op = instr_node.OpBase(); auto place = instr_node.DeviceContext().GetPlace(); Scope* local_scope = create_local_scope_ ? var_scope_.GetMutableLocalScope() : var_scope_.GetMutableScope(); VLOG(4) << "Start run " << place << " " << op->DebugStringEx(local_scope_); + + SetDeviceId(place); + #ifdef PADDLE_WITH_ASCEND_CL if (platform::is_npu_place(place)) { - auto dev_id = place.device; - platform::SetNPUDeviceId(dev_id); // NOTE(wangxi): nan/inf cannot be detected on NPU by checking the variable // values, but only through special `float_status` to checks whether // the operation is overflow. More about `float_status`, see: