From a896b32b4d926760058caa1826c74ceb296c1c6c Mon Sep 17 00:00:00 2001 From: Leo Chen Date: Tue, 13 Sep 2022 21:02:44 +0800 Subject: [PATCH] set device id before op run (#45993) --- .../framework/new_executor/interpretercore.cc | 50 ++++++++++++++++++- 1 file changed, 48 insertions(+), 2 deletions(-) diff --git a/paddle/fluid/framework/new_executor/interpretercore.cc b/paddle/fluid/framework/new_executor/interpretercore.cc index 59285dd6e61..0a0170110de 100644 --- a/paddle/fluid/framework/new_executor/interpretercore.cc +++ b/paddle/fluid/framework/new_executor/interpretercore.cc @@ -29,6 +29,7 @@ #ifdef PADDLE_WITH_MKLDNN #include "paddle/fluid/platform/mkldnn_helper.h" #endif +#include "paddle/phi/backends/device_manager.h" PADDLE_DEFINE_EXPORTED_bool(new_executor_use_inplace, false, @@ -596,16 +597,61 @@ void InterpreterCore::BuildSkipShareLoDInfo() { } } +inline void SetDeviceId(const platform::Place& place) { + // TODO(zhiqiu): reduce the cost + if (platform::is_gpu_place(place)) { +#if !defined(PADDLE_WITH_CUDA) && !defined(PADDLE_WITH_HIP) + PADDLE_THROW(platform::errors::Unavailable( + "Cannot run operator on place %s, please recompile paddle or " + "reinstall Paddle with CUDA support.", + place)); +#else + auto dev_id = place.device; + platform::SetDeviceId(dev_id); +#endif + } else if (platform::is_xpu_place(place)) { +#ifndef PADDLE_WITH_XPU + PADDLE_THROW(platform::errors::Unavailable( + "Cannot run operator on place %s, please recompile paddle or " + "reinstall Paddle with XPU support.", + place)); +#else + auto dev_id = place.device; + platform::SetXPUDeviceId(dev_id); +#endif + } else if (platform::is_npu_place(place)) { +#ifndef PADDLE_WITH_ASCEND_CL + PADDLE_THROW(platform::errors::Unavailable( + "Cannot run operator on place %s, please recompile paddle or " + "reinstall Paddle with NPU support.", + place)); +#else + auto dev_id = place.device; + platform::SetNPUDeviceId(dev_id); +#endif + } else if (platform::is_custom_place(place)) { +#ifndef PADDLE_WITH_CUSTOM_DEVICE + PADDLE_THROW(platform::errors::Unavailable( + "Cannot run operator on place %s, please recompile paddle or " + "reinstall Paddle with CustomDevice support.", + place)); +#else + phi::DeviceManager::SetDevice(place); +#endif + } +} + void InterpreterCore::RunInstruction(const Instruction& instr_node) { auto* op = instr_node.OpBase(); auto place = instr_node.DeviceContext().GetPlace(); Scope* local_scope = create_local_scope_ ? var_scope_.GetMutableLocalScope() : var_scope_.GetMutableScope(); VLOG(4) << "Start run " << place << " " << op->DebugStringEx(local_scope_); + + SetDeviceId(place); + #ifdef PADDLE_WITH_ASCEND_CL if (platform::is_npu_place(place)) { - auto dev_id = place.device; - platform::SetNPUDeviceId(dev_id); // NOTE(wangxi): nan/inf cannot be detected on NPU by checking the variable // values, but only through special `float_status` to checks whether // the operation is overflow. More about `float_status`, see: -- GitLab