From ed9ec69993c6007e3ac87b29b29d19774cc0d3bc Mon Sep 17 00:00:00 2001 From: wanghuancoder Date: Fri, 18 Aug 2023 15:35:25 +0800 Subject: [PATCH] fix stride legacy inplace bug (#56418) --- paddle/fluid/imperative/tracer.cc | 12 +++++-- paddle/phi/api/lib/api_gen_utils.cc | 50 +++++++++++++++++++++++++++++ paddle/phi/api/lib/api_gen_utils.h | 4 +++ 3 files changed, 64 insertions(+), 2 deletions(-) diff --git a/paddle/fluid/imperative/tracer.cc b/paddle/fluid/imperative/tracer.cc index b20e27bebbe..f1374bc8f7b 100644 --- a/paddle/fluid/imperative/tracer.cc +++ b/paddle/fluid/imperative/tracer.cc @@ -410,6 +410,9 @@ void Tracer::TraceOp(const std::string& type, VLOG(6) << "Running On Eager TraceOp with use_default_attr_map: " << use_default_attr_map; std::map need_backup_inputs2outputs; + std::map> + need_backup_inputs2holder; + std::map need_backup_inputs2strides; if (FLAGS_use_stride_kernel) { for (auto& iter : inplace_map) { auto inputs_iter = ins.find(iter.first); @@ -426,11 +429,12 @@ void Tracer::TraceOp(const std::string& type, outputs_iter->second[i] ->MutableVar() ->GetMutable(); + need_backup_inputs2holder[dense_tensor] = dense_tensor->Holder(); + need_backup_inputs2strides[dense_tensor] = dense_tensor->strides(); } } } } - TraceOpImpl(type, ins, outs, @@ -443,7 +447,11 @@ void Tracer::TraceOp(const std::string& type, auto dev_ctx = paddle::platform::DeviceContextPool::Instance().Get(place); for (auto& iter : need_backup_inputs2outputs) { - paddle::experimental::TransStride(dev_ctx, iter.second, iter.first); + iter.first->ResetHolder(need_backup_inputs2holder[iter.first]); + iter.first->set_strides(need_backup_inputs2strides[iter.first]); + paddle::experimental::TransStrideLegacy(dev_ctx, iter.second, iter.first); + iter.second->ResetHolder(need_backup_inputs2holder[iter.first]); + iter.second->set_strides(need_backup_inputs2strides[iter.first]); } } else { TraceOpImpl(type, diff --git a/paddle/phi/api/lib/api_gen_utils.cc b/paddle/phi/api/lib/api_gen_utils.cc index d1e549c91c4..e3f58683d98 100644 --- a/paddle/phi/api/lib/api_gen_utils.cc +++ b/paddle/phi/api/lib/api_gen_utils.cc @@ -423,6 +423,56 @@ void TransStride(phi::DeviceContext* dev_ctx, } } +void TransStrideLegacy(phi::DeviceContext* dev_ctx, + phi::DenseTensor* from, + phi::DenseTensor* to) { + if (to) { + auto* cpu_ctx = dynamic_cast(dev_ctx); + if (cpu_ctx) { + PD_VISIT_ALL_TYPES(to->dtype(), "StridedCopyKernel", ([&] { + phi::StridedCopyKernel( + *cpu_ctx, + *from, + phi::vectorize(to->dims()), + phi::vectorize(to->strides()), + to->offset(), + to); + })); + return; + } +#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) + auto* gpu_ctx = dynamic_cast(dev_ctx); + if (gpu_ctx) { + PD_VISIT_ALL_TYPES(to->dtype(), "StridedCopyKernel", ([&] { + phi::StridedCopyKernel( + *gpu_ctx, + *from, + phi::vectorize(to->dims()), + phi::vectorize(to->strides()), + to->offset(), + to); + })); + return; + } +#endif +#ifdef PADDLE_WITH_XPU + auto* xpu_ctx = dynamic_cast(dev_ctx); + if (xpu_ctx) { + PD_VISIT_ALL_TYPES(to->dtype(), "StridedCopyKernel", ([&] { + phi::StridedCopyKernel( + *xpu_ctx, + *from, + phi::vectorize(to->dims()), + phi::vectorize(to->strides()), + to->offset(), + to); + })); + return; + } +#endif + } +} + void TransStride(phi::DeviceContext* dev_ctx, const std::vector& from, const std::vector& to) { diff --git a/paddle/phi/api/lib/api_gen_utils.h b/paddle/phi/api/lib/api_gen_utils.h index afe312b7096..1b552bf94ea 100644 --- a/paddle/phi/api/lib/api_gen_utils.h +++ b/paddle/phi/api/lib/api_gen_utils.h @@ -133,6 +133,10 @@ void TransStride(phi::DeviceContext* dev_ctx, phi::SelectedRows* from, phi::SelectedRows* to); +void TransStrideLegacy(phi::DeviceContext* dev_ctx, + phi::DenseTensor* from, + phi::DenseTensor* to); + #ifdef PADDLE_WITH_DISTRIBUTE /* ------------------ for auto parallel ----------------------- */ -- GitLab