未验证 提交 f74e32e9 编写于 作者: R ronnywang 提交者: GitHub

[CustomDevice] fix amp op list (#56176)

上级 a6a49855
......@@ -51,6 +51,7 @@ OpSupportedInfos(const std::string& place,
{"GPU", &platform::is_gpu_place},
{"CPU", &platform::is_cpu_place},
{"XPU", &platform::is_xpu_place},
{"CUSTOM_DEVICE", &platform::is_custom_place},
};
PADDLE_ENFORCE_NE(
is_target_place.count(query_place),
......@@ -76,12 +77,6 @@ OpSupportedInfos(const std::string& place,
}
}
#ifdef PADDLE_WITH_CUSTOM_DEVICE
auto is_custom_place = [&](const std::string& place) {
return is_target_place.count(place) && place != "CPU" && place != "GPU" &&
place != "XPU";
};
#endif
auto phi_kernels = phi::KernelFactory::Instance().kernels();
for (auto& kernel_pair : phi_kernels) {
auto op_type = phi::TransToFluidOpName(kernel_pair.first);
......@@ -90,15 +85,6 @@ OpSupportedInfos(const std::string& place,
all_ops.count(op_type) == 0) {
continue;
}
#ifdef PADDLE_WITH_CUSTOM_DEVICE
if (info_pair.first.backend() == phi::Backend::CUSTOM) {
if (is_custom_place(query_place)) {
VLOG(4) << op_type << " " << supported_ops.size();
supported_ops.emplace(op_type);
}
continue;
}
#endif
if (is_target_place[query_place](
phi::TransToPhiPlace(info_pair.first.backend(), false))) {
VLOG(8) << op_type << " " << supported_ops.size();
......
......@@ -15,6 +15,7 @@
import copy
import logging
import paddle
from paddle.amp.amp_lists import (
EXTRA_BLACK_LIST,
FP16_BLACK_LIST,
......@@ -94,6 +95,10 @@ def _get_sys_unsupported_list(dtype):
device = None
if core.is_compiled_with_xpu():
device = 'XPU'
elif isinstance(
paddle.framework._current_expected_place(), paddle.CustomPlace
):
device = 'CUSTOM_DEVICE'
else:
device = 'GPU'
all_ops, _, sys_unsupported_list = core.op_supported_infos(device, var_type)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册