未验证 提交 ed9ec699 编写于 作者: W wanghuancoder 提交者: GitHub

fix stride legacy inplace bug (#56418)

上级 e5b71671
...@@ -410,6 +410,9 @@ void Tracer::TraceOp(const std::string& type, ...@@ -410,6 +410,9 @@ void Tracer::TraceOp(const std::string& type,
VLOG(6) << "Running On Eager TraceOp with use_default_attr_map: " VLOG(6) << "Running On Eager TraceOp with use_default_attr_map: "
<< use_default_attr_map; << use_default_attr_map;
std::map<phi::DenseTensor*, phi::DenseTensor*> need_backup_inputs2outputs; std::map<phi::DenseTensor*, phi::DenseTensor*> need_backup_inputs2outputs;
std::map<phi::DenseTensor*, std::shared_ptr<phi::Allocation>>
need_backup_inputs2holder;
std::map<phi::DenseTensor*, phi::DDim> need_backup_inputs2strides;
if (FLAGS_use_stride_kernel) { if (FLAGS_use_stride_kernel) {
for (auto& iter : inplace_map) { for (auto& iter : inplace_map) {
auto inputs_iter = ins.find(iter.first); auto inputs_iter = ins.find(iter.first);
...@@ -426,11 +429,12 @@ void Tracer::TraceOp(const std::string& type, ...@@ -426,11 +429,12 @@ void Tracer::TraceOp(const std::string& type,
outputs_iter->second[i] outputs_iter->second[i]
->MutableVar() ->MutableVar()
->GetMutable<phi::DenseTensor>(); ->GetMutable<phi::DenseTensor>();
need_backup_inputs2holder[dense_tensor] = dense_tensor->Holder();
need_backup_inputs2strides[dense_tensor] = dense_tensor->strides();
} }
} }
} }
} }
TraceOpImpl<egr::EagerVariable>(type, TraceOpImpl<egr::EagerVariable>(type,
ins, ins,
outs, outs,
...@@ -443,7 +447,11 @@ void Tracer::TraceOp(const std::string& type, ...@@ -443,7 +447,11 @@ void Tracer::TraceOp(const std::string& type,
auto dev_ctx = paddle::platform::DeviceContextPool::Instance().Get(place); auto dev_ctx = paddle::platform::DeviceContextPool::Instance().Get(place);
for (auto& iter : need_backup_inputs2outputs) { for (auto& iter : need_backup_inputs2outputs) {
paddle::experimental::TransStride(dev_ctx, iter.second, iter.first); iter.first->ResetHolder(need_backup_inputs2holder[iter.first]);
iter.first->set_strides(need_backup_inputs2strides[iter.first]);
paddle::experimental::TransStrideLegacy(dev_ctx, iter.second, iter.first);
iter.second->ResetHolder(need_backup_inputs2holder[iter.first]);
iter.second->set_strides(need_backup_inputs2strides[iter.first]);
} }
} else { } else {
TraceOpImpl<egr::EagerVariable>(type, TraceOpImpl<egr::EagerVariable>(type,
......
...@@ -423,6 +423,56 @@ void TransStride(phi::DeviceContext* dev_ctx, ...@@ -423,6 +423,56 @@ void TransStride(phi::DeviceContext* dev_ctx,
} }
} }
// Legacy strided-copy helper: writes the contents of `from` into `to`,
// laying the data out according to `to`'s dims/strides/offset via
// phi::StridedCopyKernel. The concrete backend is chosen by dynamic_cast
// on the device context; when `to` is null, or no compiled-in backend
// matches the context, the call is a no-op.
void TransStrideLegacy(phi::DeviceContext* dev_ctx,
                       phi::DenseTensor* from,
                       phi::DenseTensor* to) {
  if (to == nullptr) {
    return;
  }

  // CPU backend.
  if (auto* cpu_dev = dynamic_cast<phi::CPUContext*>(dev_ctx)) {
    PD_VISIT_ALL_TYPES(to->dtype(), "StridedCopyKernel", ([&] {
                         phi::StridedCopyKernel<data_t, phi::CPUContext>(
                             *cpu_dev,
                             *from,
                             phi::vectorize<int64_t>(to->dims()),
                             phi::vectorize<int64_t>(to->strides()),
                             to->offset(),
                             to);
                       }));
    return;
  }
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
  // GPU backend (CUDA or ROCm builds).
  if (auto* gpu_dev = dynamic_cast<phi::GPUContext*>(dev_ctx)) {
    PD_VISIT_ALL_TYPES(to->dtype(), "StridedCopyKernel", ([&] {
                         phi::StridedCopyKernel<data_t, phi::GPUContext>(
                             *gpu_dev,
                             *from,
                             phi::vectorize<int64_t>(to->dims()),
                             phi::vectorize<int64_t>(to->strides()),
                             to->offset(),
                             to);
                       }));
    return;
  }
#endif
#ifdef PADDLE_WITH_XPU
  // XPU backend.
  if (auto* xpu_dev = dynamic_cast<phi::XPUContext*>(dev_ctx)) {
    PD_VISIT_ALL_TYPES(to->dtype(), "StridedCopyKernel", ([&] {
                         phi::StridedCopyKernel<data_t, phi::XPUContext>(
                             *xpu_dev,
                             *from,
                             phi::vectorize<int64_t>(to->dims()),
                             phi::vectorize<int64_t>(to->strides()),
                             to->offset(),
                             to);
                       }));
    return;
  }
#endif
}
void TransStride(phi::DeviceContext* dev_ctx, void TransStride(phi::DeviceContext* dev_ctx,
const std::vector<phi::DenseTensor*>& from, const std::vector<phi::DenseTensor*>& from,
const std::vector<phi::DenseTensor*>& to) { const std::vector<phi::DenseTensor*>& to) {
......
...@@ -133,6 +133,10 @@ void TransStride(phi::DeviceContext* dev_ctx, ...@@ -133,6 +133,10 @@ void TransStride(phi::DeviceContext* dev_ctx,
phi::SelectedRows* from, phi::SelectedRows* from,
phi::SelectedRows* to); phi::SelectedRows* to);
// Legacy strided copy: materializes `from` into `to` using `to`'s
// dims/strides/offset (dispatching StridedCopyKernel per device context);
// no-op when `to` is null or the context matches no compiled-in backend.
void TransStrideLegacy(phi::DeviceContext* dev_ctx,
                       phi::DenseTensor* from,
                       phi::DenseTensor* to);
#ifdef PADDLE_WITH_DISTRIBUTE #ifdef PADDLE_WITH_DISTRIBUTE
/* ------------------ for auto parallel ----------------------- */ /* ------------------ for auto parallel ----------------------- */
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册