未验证 提交 78200976 编写于 作者: C Chen Weihang 提交者: GitHub

[Phi] Fix no pinned transform (#41300)

* fix no pinned trans

* fix cond error
上级 b0398c8e
......@@ -37,9 +37,10 @@ inline bool NeedTransformDataType(const DataType& input,
inline bool NeedTransformPlace(const paddle::platform::Place& input,
const Backend& target,
const TransformFlag& transform_flag) {
bool ret = transform_flag.need_trans_backend() &&
target != Backend::ALL_BACKEND &&
phi::TransToPhiBackend(input) != target;
bool ret =
input.GetType() == AllocationType::GPUPINNED ||
(transform_flag.need_trans_backend() && target != Backend::ALL_BACKEND &&
phi::TransToPhiBackend(input) != target);
return ret;
}
......
......@@ -18,6 +18,7 @@ limitations under the License. */
#include "paddle/phi/backends/xpu/xpu_info.h"
#include "paddle/phi/common/place.h"
#include "paddle/phi/core/compat/op_utils.h"
#include "paddle/phi/core/enforce.h"
#ifdef PADDLE_WITH_CUSTOM_DEVICE
#include "paddle/phi/backends/device_manager.h"
......@@ -31,6 +32,8 @@ Backend TransToPhiBackend(const phi::Place& place) {
return Backend::CPU;
} else if (allocation_type == phi::AllocationType::GPU) {
return Backend::GPU;
} else if (allocation_type == phi::AllocationType::GPUPINNED) {
return Backend::GPU;
} else if (allocation_type == phi::AllocationType::XPU) {
return Backend::XPU;
} else if (allocation_type == phi::AllocationType::NPU) {
......@@ -40,7 +43,8 @@ Backend TransToPhiBackend(const phi::Place& place) {
static_cast<size_t>(Backend::NUM_BACKENDS) +
GetOrRegisterGlobalDeviceTypeId(place.GetDeviceType()));
} else {
return Backend::UNDEFINED;
PADDLE_THROW(phi::errors::InvalidArgument(
"Unsupported transform %s to phi Backend.", place));
}
}
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册