From f74e32e9b83b1f47c419ce6aefcd2fd15a904d50 Mon Sep 17 00:00:00 2001
From: ronnywang
Date: Fri, 11 Aug 2023 18:10:58 +0800
Subject: [PATCH] [CustomDevice] fix amp op list (#56176)

---
 paddle/fluid/imperative/amp_auto_cast.cc | 16 +---------------
 python/paddle/static/amp/fp16_lists.py   |  5 +++++
 2 files changed, 6 insertions(+), 15 deletions(-)

diff --git a/paddle/fluid/imperative/amp_auto_cast.cc b/paddle/fluid/imperative/amp_auto_cast.cc
index bf6c32be2f3..cb8f369daf8 100644
--- a/paddle/fluid/imperative/amp_auto_cast.cc
+++ b/paddle/fluid/imperative/amp_auto_cast.cc
@@ -51,6 +51,7 @@ OpSupportedInfos(const std::string& place,
       {"GPU", &platform::is_gpu_place},
       {"CPU", &platform::is_cpu_place},
       {"XPU", &platform::is_xpu_place},
+      {"CUSTOM_DEVICE", &platform::is_custom_place},
   };
   PADDLE_ENFORCE_NE(
       is_target_place.count(query_place),
@@ -76,12 +77,6 @@ OpSupportedInfos(const std::string& place,
     }
   }
 
-#ifdef PADDLE_WITH_CUSTOM_DEVICE
-  auto is_custom_place = [&](const std::string& place) {
-    return is_target_place.count(place) && place != "CPU" && place != "GPU" &&
-           place != "XPU";
-  };
-#endif
   auto phi_kernels = phi::KernelFactory::Instance().kernels();
   for (auto& kernel_pair : phi_kernels) {
     auto op_type = phi::TransToFluidOpName(kernel_pair.first);
@@ -90,15 +85,6 @@ OpSupportedInfos(const std::string& place,
         all_ops.count(op_type) == 0) {
       continue;
     }
-#ifdef PADDLE_WITH_CUSTOM_DEVICE
-    if (info_pair.first.backend() == phi::Backend::CUSTOM) {
-      if (is_custom_place(query_place)) {
-        VLOG(4) << op_type << " " << supported_ops.size();
-        supported_ops.emplace(op_type);
-      }
-      continue;
-    }
-#endif
     if (is_target_place[query_place](
             phi::TransToPhiPlace(info_pair.first.backend(), false))) {
       VLOG(8) << op_type << " " << supported_ops.size();
diff --git a/python/paddle/static/amp/fp16_lists.py b/python/paddle/static/amp/fp16_lists.py
index c3d8f20b04d..0860ec9bcbc 100644
--- a/python/paddle/static/amp/fp16_lists.py
+++ b/python/paddle/static/amp/fp16_lists.py
@@ -15,6 +15,7 @@
 import copy
 import logging
 
+import paddle
 from paddle.amp.amp_lists import (
     EXTRA_BLACK_LIST,
     FP16_BLACK_LIST,
@@ -94,6 +95,10 @@ def _get_sys_unsupported_list(dtype):
     device = None
     if core.is_compiled_with_xpu():
         device = 'XPU'
+    elif isinstance(
+        paddle.framework._current_expected_place(), paddle.CustomPlace
+    ):
+        device = 'CUSTOM_DEVICE'
     else:
         device = 'GPU'
     all_ops, _, sys_unsupported_list = core.op_supported_infos(device, var_type)
-- 
GitLab