未验证 提交 9d40da31 编写于 作者: iSerendipity's avatar iSerendipity 提交者: GitHub

[AMP][Custom Device] add fp16 op detection for custom device (#56053)

* [AMP] add fp16 op detection for custom device

* resolve conflicts
上级 066097e8
......@@ -52,13 +52,16 @@ OpSupportedInfos(const std::string& place,
{"CPU", &platform::is_cpu_place},
{"XPU", &platform::is_xpu_place},
{"CUSTOM_DEVICE", &platform::is_custom_place},
#ifdef PADDLE_WITH_CUSTOM_DEVICE
{query_place, &platform::is_custom_place},
#endif
};
PADDLE_ENFORCE_NE(
is_target_place.count(query_place),
0,
platform::errors::InvalidArgument(
"The argument `place` should be 'GPU', 'CPU', 'XPU', but got '%s'.",
place));
PADDLE_ENFORCE_NE(is_target_place.count(query_place),
0,
platform::errors::InvalidArgument(
"The argument `place` should be 'GPU', 'CPU', 'XPU' or "
"other Custom Device, but got '%s'.",
place));
std::unordered_set<std::string> all_ops;
const auto& op_info = framework::OpInfoMap::Instance().map();
......
......@@ -98,7 +98,7 @@ def _get_sys_unsupported_list(dtype):
elif isinstance(
paddle.framework._current_expected_place(), paddle.CustomPlace
):
device = 'CUSTOM_DEVICE'
device = paddle.framework._current_expected_place().get_device_type()
else:
device = 'GPU'
all_ops, _, sys_unsupported_list = core.op_supported_infos(device, var_type)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册