未验证 提交 78200976 编写于 作者: C Chen Weihang 提交者: GitHub

[Phi] Fix no pinned transform (#41300)

* fix no pinned trans

* fix cond error
上级 b0398c8e
...@@ -37,9 +37,10 @@ inline bool NeedTransformDataType(const DataType& input, ...@@ -37,9 +37,10 @@ inline bool NeedTransformDataType(const DataType& input,
inline bool NeedTransformPlace(const paddle::platform::Place& input, inline bool NeedTransformPlace(const paddle::platform::Place& input,
const Backend& target, const Backend& target,
const TransformFlag& transform_flag) { const TransformFlag& transform_flag) {
bool ret = transform_flag.need_trans_backend() && bool ret =
target != Backend::ALL_BACKEND && input.GetType() == AllocationType::GPUPINNED ||
phi::TransToPhiBackend(input) != target; (transform_flag.need_trans_backend() && target != Backend::ALL_BACKEND &&
phi::TransToPhiBackend(input) != target);
return ret; return ret;
} }
......
...@@ -18,6 +18,7 @@ limitations under the License. */ ...@@ -18,6 +18,7 @@ limitations under the License. */
#include "paddle/phi/backends/xpu/xpu_info.h" #include "paddle/phi/backends/xpu/xpu_info.h"
#include "paddle/phi/common/place.h" #include "paddle/phi/common/place.h"
#include "paddle/phi/core/compat/op_utils.h" #include "paddle/phi/core/compat/op_utils.h"
#include "paddle/phi/core/enforce.h"
#ifdef PADDLE_WITH_CUSTOM_DEVICE #ifdef PADDLE_WITH_CUSTOM_DEVICE
#include "paddle/phi/backends/device_manager.h" #include "paddle/phi/backends/device_manager.h"
...@@ -31,6 +32,8 @@ Backend TransToPhiBackend(const phi::Place& place) { ...@@ -31,6 +32,8 @@ Backend TransToPhiBackend(const phi::Place& place) {
return Backend::CPU; return Backend::CPU;
} else if (allocation_type == phi::AllocationType::GPU) { } else if (allocation_type == phi::AllocationType::GPU) {
return Backend::GPU; return Backend::GPU;
} else if (allocation_type == phi::AllocationType::GPUPINNED) {
return Backend::GPU;
} else if (allocation_type == phi::AllocationType::XPU) { } else if (allocation_type == phi::AllocationType::XPU) {
return Backend::XPU; return Backend::XPU;
} else if (allocation_type == phi::AllocationType::NPU) { } else if (allocation_type == phi::AllocationType::NPU) {
...@@ -40,7 +43,8 @@ Backend TransToPhiBackend(const phi::Place& place) { ...@@ -40,7 +43,8 @@ Backend TransToPhiBackend(const phi::Place& place) {
static_cast<size_t>(Backend::NUM_BACKENDS) + static_cast<size_t>(Backend::NUM_BACKENDS) +
GetOrRegisterGlobalDeviceTypeId(place.GetDeviceType())); GetOrRegisterGlobalDeviceTypeId(place.GetDeviceType()));
} else { } else {
return Backend::UNDEFINED; PADDLE_THROW(phi::errors::InvalidArgument(
"Unsupported transform %s to phi Backend.", place));
} }
} }
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册