未验证 提交 f74e32e9 编写于 作者: R ronnywang 提交者: GitHub

[CustomDevice] fix amp op list (#56176)

上级 a6a49855
...@@ -51,6 +51,7 @@ OpSupportedInfos(const std::string& place, ...@@ -51,6 +51,7 @@ OpSupportedInfos(const std::string& place,
{"GPU", &platform::is_gpu_place}, {"GPU", &platform::is_gpu_place},
{"CPU", &platform::is_cpu_place}, {"CPU", &platform::is_cpu_place},
{"XPU", &platform::is_xpu_place}, {"XPU", &platform::is_xpu_place},
{"CUSTOM_DEVICE", &platform::is_custom_place},
}; };
PADDLE_ENFORCE_NE( PADDLE_ENFORCE_NE(
is_target_place.count(query_place), is_target_place.count(query_place),
...@@ -76,12 +77,6 @@ OpSupportedInfos(const std::string& place, ...@@ -76,12 +77,6 @@ OpSupportedInfos(const std::string& place,
} }
} }
#ifdef PADDLE_WITH_CUSTOM_DEVICE
auto is_custom_place = [&](const std::string& place) {
return is_target_place.count(place) && place != "CPU" && place != "GPU" &&
place != "XPU";
};
#endif
auto phi_kernels = phi::KernelFactory::Instance().kernels(); auto phi_kernels = phi::KernelFactory::Instance().kernels();
for (auto& kernel_pair : phi_kernels) { for (auto& kernel_pair : phi_kernels) {
auto op_type = phi::TransToFluidOpName(kernel_pair.first); auto op_type = phi::TransToFluidOpName(kernel_pair.first);
...@@ -90,15 +85,6 @@ OpSupportedInfos(const std::string& place, ...@@ -90,15 +85,6 @@ OpSupportedInfos(const std::string& place,
all_ops.count(op_type) == 0) { all_ops.count(op_type) == 0) {
continue; continue;
} }
#ifdef PADDLE_WITH_CUSTOM_DEVICE
if (info_pair.first.backend() == phi::Backend::CUSTOM) {
if (is_custom_place(query_place)) {
VLOG(4) << op_type << " " << supported_ops.size();
supported_ops.emplace(op_type);
}
continue;
}
#endif
if (is_target_place[query_place]( if (is_target_place[query_place](
phi::TransToPhiPlace(info_pair.first.backend(), false))) { phi::TransToPhiPlace(info_pair.first.backend(), false))) {
VLOG(8) << op_type << " " << supported_ops.size(); VLOG(8) << op_type << " " << supported_ops.size();
......
...@@ -15,6 +15,7 @@ ...@@ -15,6 +15,7 @@
import copy import copy
import logging import logging
import paddle
from paddle.amp.amp_lists import ( from paddle.amp.amp_lists import (
EXTRA_BLACK_LIST, EXTRA_BLACK_LIST,
FP16_BLACK_LIST, FP16_BLACK_LIST,
...@@ -94,6 +95,10 @@ def _get_sys_unsupported_list(dtype): ...@@ -94,6 +95,10 @@ def _get_sys_unsupported_list(dtype):
device = None device = None
if core.is_compiled_with_xpu(): if core.is_compiled_with_xpu():
device = 'XPU' device = 'XPU'
elif isinstance(
paddle.framework._current_expected_place(), paddle.CustomPlace
):
device = 'CUSTOM_DEVICE'
else: else:
device = 'GPU' device = 'GPU'
all_ops, _, sys_unsupported_list = core.op_supported_infos(device, var_type) all_ops, _, sys_unsupported_list = core.op_supported_infos(device, var_type)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册