未验证 提交 9d40da31 编写于 作者: iSerendipity's avatar iSerendipity 提交者: GitHub

[AMP][Custom Device] add fp16 op detection for custom device (#56053)

* [AMP] add fp16 op detection for custom device

* resolve conflicts
上级 066097e8
...@@ -52,13 +52,16 @@ OpSupportedInfos(const std::string& place, ...@@ -52,13 +52,16 @@ OpSupportedInfos(const std::string& place,
{"CPU", &platform::is_cpu_place}, {"CPU", &platform::is_cpu_place},
{"XPU", &platform::is_xpu_place}, {"XPU", &platform::is_xpu_place},
{"CUSTOM_DEVICE", &platform::is_custom_place}, {"CUSTOM_DEVICE", &platform::is_custom_place},
#ifdef PADDLE_WITH_CUSTOM_DEVICE
{query_place, &platform::is_custom_place},
#endif
}; };
PADDLE_ENFORCE_NE( PADDLE_ENFORCE_NE(is_target_place.count(query_place),
is_target_place.count(query_place), 0,
0, platform::errors::InvalidArgument(
platform::errors::InvalidArgument( "The argument `place` should be 'GPU', 'CPU', 'XPU' or "
"The argument `place` should be 'GPU', 'CPU', 'XPU', but got '%s'.", "other Custom Device, but got '%s'.",
place)); place));
std::unordered_set<std::string> all_ops; std::unordered_set<std::string> all_ops;
const auto& op_info = framework::OpInfoMap::Instance().map(); const auto& op_info = framework::OpInfoMap::Instance().map();
......
...@@ -98,7 +98,7 @@ def _get_sys_unsupported_list(dtype): ...@@ -98,7 +98,7 @@ def _get_sys_unsupported_list(dtype):
elif isinstance( elif isinstance(
paddle.framework._current_expected_place(), paddle.CustomPlace paddle.framework._current_expected_place(), paddle.CustomPlace
): ):
device = 'CUSTOM_DEVICE' device = paddle.framework._current_expected_place().get_device_type()
else: else:
device = 'GPU' device = 'GPU'
all_ops, _, sys_unsupported_list = core.op_supported_infos(device, var_type) all_ops, _, sys_unsupported_list = core.op_supported_infos(device, var_type)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册