未验证 提交 83b942f3 编写于 作者: W wanghuancoder 提交者: GitHub

fix set value inplace strided bug (#56892)

上级 e2af9d56
......@@ -413,6 +413,7 @@ void Tracer::TraceOp(const std::string& type,
std::map<phi::DenseTensor*, std::shared_ptr<phi::Allocation>>
need_backup_inputs2holder;
std::map<phi::DenseTensor*, phi::DDim> need_backup_inputs2strides;
std::map<phi::DenseTensor*, size_t> need_backup_inputs2offset;
if (FLAGS_use_stride_kernel) {
for (auto& iter : inplace_map) {
auto inputs_iter = ins.find(iter.first);
......@@ -431,6 +432,7 @@ void Tracer::TraceOp(const std::string& type,
->GetMutable<phi::DenseTensor>();
need_backup_inputs2holder[dense_tensor] = dense_tensor->Holder();
need_backup_inputs2strides[dense_tensor] = dense_tensor->strides();
need_backup_inputs2offset[dense_tensor] = dense_tensor->offset();
}
}
}
......@@ -449,9 +451,11 @@ void Tracer::TraceOp(const std::string& type,
for (auto& iter : need_backup_inputs2outputs) {
iter.first->ResetHolder(need_backup_inputs2holder[iter.first]);
iter.first->set_strides(need_backup_inputs2strides[iter.first]);
iter.first->set_offset(need_backup_inputs2offset[iter.first]);
paddle::experimental::TransStrideLegacy(dev_ctx, iter.second, iter.first);
iter.second->ResetHolder(need_backup_inputs2holder[iter.first]);
iter.second->set_strides(need_backup_inputs2strides[iter.first]);
iter.second->set_offset(need_backup_inputs2offset[iter.first]);
}
} else {
TraceOpImpl<egr::EagerVariable>(type,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册