diff --git a/paddle/phi/api/lib/data_transform.cc b/paddle/phi/api/lib/data_transform.cc index 90d47977cdf604b81366fe5a0d18a4bed9b6db01..82d2e741e9de852823726f91a6f2d7370c8d0b0e 100644 --- a/paddle/phi/api/lib/data_transform.cc +++ b/paddle/phi/api/lib/data_transform.cc @@ -40,7 +40,8 @@ inline bool NeedTransformPlace(const paddle::platform::Place& input, bool ret = input.GetType() == AllocationType::GPUPINNED || (transform_flag.need_trans_backend() && target != Backend::ALL_BACKEND && - phi::TransToPhiBackend(input) != target); + phi::TransToPhiBackend(input) != + (target != Backend::GPUDNN ? target : Backend::GPU)); return ret; }