未验证 提交 9e045eeb 编写于 作者: D duanyanhui 提交者: GitHub

[CustomDevice] suport device_guard for custom device (#53808)

* suport device_guard for npu

* fix comment

* fix typo
上级 734dc448
......@@ -36,6 +36,9 @@
#include "paddle/fluid/platform/mkldnn_helper.h"
#endif
#ifdef PADDLE_WITH_CUSTOM_DEVICE
#include "paddle/phi/backends/device_manager.h"
#endif
PADDLE_DEFINE_EXPORTED_bool(
new_executor_log_memory_stats,
false,
......@@ -400,6 +403,40 @@ void ApplyDeviceGuard(const OperatorBase* op_base,
}
VLOG(3) << "Switch into " << expected_kernel_key->place_
<< " by device_guard.";
} else if (platform::is_custom_place(place)) {
#ifdef PADDLE_WITH_CUSTOM_DEVICE
auto device_types = phi::DeviceManager::GetAllCustomDeviceTypes();
bool is_custom_device_op = false;
for (auto dev_type : device_types) {
if (op_device.find(dev_type) != std::string::npos) {
is_custom_device_op = true;
break;
}
}
PADDLE_ENFORCE_EQ(
is_custom_device_op,
true,
phi::errors::Unimplemented(
"Unsupported current device %s with Paddle CustomDevice ",
op_device));
#else
VLOG(1) << string::Sprintf(
"Cannot use get_all_custom_device_type because you have installed"
"CPU/GPU version PaddlePaddle.\n"
"If you want to use get_all_custom_device_type, please try to "
"install CustomDevice version "
"PaddlePaddle by: pip install paddlepaddle\n");
#endif
if (op_base->SupportCustomDevice()) {
expected_kernel_key->place_ = place;
} else {
expected_kernel_key->place_ = platform::CPUPlace();
LOG_FIRST_N(WARNING, 1) << "Op(" << op_base->Type()
<< ") has no Custom Place implementation. It "
"will be assigned to CPUPlace.";
}
VLOG(3) << "Switch into " << expected_kernel_key->place_
<< " by device_guard.";
} else {
PADDLE_THROW(
platform::errors::Fatal("Unsupported current place %s", op_device));
......
......@@ -1365,6 +1365,42 @@ bool OperatorWithKernel::SupportXPU() const {
#endif
}
bool OperatorWithKernel::SupportCustomDevice() const {
#ifdef PADDLE_WITH_CUSTOM_DEVICE
auto phi_kernels = phi::KernelFactory::Instance().SelectKernelMap(
phi::TransToPhiKernelName(type_));
auto has_phi_kernel =
std::any_of(phi_kernels.begin(),
phi_kernels.end(),
[](phi::KernelKeyMap::const_reference kern_pair) {
return platform::is_custom_place(
phi::TransToPhiPlace(kern_pair.first.backend()));
});
if (has_phi_kernel) {
return true;
} else {
auto kernel_iter = OperatorWithKernel::AllOpKernels().find(type_);
if (kernel_iter == OperatorWithKernel::AllOpKernels().end()) {
return false;
} else {
auto& op_kernels = kernel_iter->second;
return std::any_of(
op_kernels.begin(),
op_kernels.end(),
[this](OpKernelMap::const_reference kern_pair) {
return platform::is_custom_place(kern_pair.first.place_);
});
}
}
#else
PADDLE_THROW(platform::errors::PreconditionNotMet(
"should not call OperatorWithKernel::SupportCustomDevice() when not "
"compiled with "
"CustomDevice support."));
return false;
#endif
}
bool OperatorWithKernel::SupportsMKLDNN(const phi::DataType data_type) const {
auto phi_kernels = phi::KernelFactory::Instance().SelectKernelMap(
phi::TransToPhiKernelName(type_));
......@@ -2088,7 +2124,12 @@ OpKernelType OperatorWithKernel::InnerGetExpectedKernelType(
// CPUKernel will be executed and a warning will be given at the same
// time.
expected_kernel_key.place_ = platform::CPUPlace();
#ifdef PADDLE_WITH_CUSTOM_DEVICE
if (SupportCustomDevice()) {
auto& dev_ctx = ctx.device_context();
expected_kernel_key.place_ = dev_ctx.GetPlace();
}
#endif
if (platform::is_cpu_place(expected_kernel_key.place_)) {
LOG_FIRST_N(WARNING, 1)
<< "Op(" << type_
......
......@@ -288,6 +288,7 @@ class OperatorBase {
virtual bool SupportGPU() const { return false; }
virtual bool SupportXPU() const { return false; }
virtual bool SupportCustomDevice() const { return false; }
const std::string& Type() const { return type_; }
......@@ -748,6 +749,8 @@ class OperatorWithKernel : public OperatorBase {
bool SupportXPU() const override;
bool SupportCustomDevice() const override;
bool SupportsMKLDNN(phi::DataType data_type) const;
bool SupportsCUDNN(phi::DataType data_type) const;
......
......@@ -7431,9 +7431,12 @@ def device_guard(device=None):
device, index = device.split(':')
if device == 'cpu':
raise ValueError("Should not set device id for cpu.")
if device not in ['cpu', 'gpu', 'xpu', '', None]:
if (
device not in ['cpu', 'gpu', 'xpu', '', None]
and device not in core.get_all_custom_device_type()
):
raise ValueError(
"The Attr(device) should be 'cpu' or 'gpu', and it can also be empty string or None "
"The Attr(device) should be 'cpu', 'xpu', 'gpu' or custom device, and it can also be empty string or None "
"when there is no need to specify device. But received %s" % device
)
if index:
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册