From 78200976e33428e8da03e29289873cf577cf51f8 Mon Sep 17 00:00:00 2001 From: Chen Weihang Date: Sat, 2 Apr 2022 20:59:13 +0800 Subject: [PATCH] [Phi] Fix no pinned transform (#41300) * fix no pinned trans * fix cond error --- paddle/phi/api/lib/data_transform.cc | 7 ++++--- paddle/phi/core/compat/convert_utils.cc | 6 +++++- 2 files changed, 9 insertions(+), 4 deletions(-) diff --git a/paddle/phi/api/lib/data_transform.cc b/paddle/phi/api/lib/data_transform.cc index c1fc0fd907b..90d47977cdf 100644 --- a/paddle/phi/api/lib/data_transform.cc +++ b/paddle/phi/api/lib/data_transform.cc @@ -37,9 +37,10 @@ inline bool NeedTransformDataType(const DataType& input, inline bool NeedTransformPlace(const paddle::platform::Place& input, const Backend& target, const TransformFlag& transform_flag) { - bool ret = transform_flag.need_trans_backend() && - target != Backend::ALL_BACKEND && - phi::TransToPhiBackend(input) != target; + bool ret = + input.GetType() == AllocationType::GPUPINNED || + (transform_flag.need_trans_backend() && target != Backend::ALL_BACKEND && + phi::TransToPhiBackend(input) != target); return ret; } diff --git a/paddle/phi/core/compat/convert_utils.cc b/paddle/phi/core/compat/convert_utils.cc index cc9c2caa889..c08dfa64c7f 100644 --- a/paddle/phi/core/compat/convert_utils.cc +++ b/paddle/phi/core/compat/convert_utils.cc @@ -18,6 +18,7 @@ limitations under the License. */ #include "paddle/phi/backends/xpu/xpu_info.h" #include "paddle/phi/common/place.h" #include "paddle/phi/core/compat/op_utils.h" +#include "paddle/phi/core/enforce.h" #ifdef PADDLE_WITH_CUSTOM_DEVICE #include "paddle/phi/backends/device_manager.h" @@ -31,6 +32,8 @@ Backend TransToPhiBackend(const phi::Place& place) { return Backend::CPU; } else if (allocation_type == phi::AllocationType::GPU) { return Backend::GPU; + } else if (allocation_type == phi::AllocationType::GPUPINNED) { + return Backend::GPU; } else if (allocation_type == phi::AllocationType::XPU) { return Backend::XPU; } else if (allocation_type == phi::AllocationType::NPU) { @@ -40,7 +43,8 @@ Backend TransToPhiBackend(const phi::Place& place) { static_cast(Backend::NUM_BACKENDS) + GetOrRegisterGlobalDeviceTypeId(place.GetDeviceType())); } else { - return Backend::UNDEFINED; + PADDLE_THROW(phi::errors::InvalidArgument( + "Unsupported transform %s to phi Backend.", place)); } } -- GitLab