未验证 提交 98d25314 编写于 作者: P pangyoki 提交者: GitHub

change TensorCopy to ShareDataWith in matmul_grad op (#33755)

上级 3946afc4
......@@ -141,17 +141,13 @@ class MatMulV2GradNPUKernel : public framework::OpKernel<T> {
if ((x->dims().size() == 3) && (dout->dims().size() == 3) &&
(dy->dims().size() == 2)) {
framework::Tensor dout_;
TensorCopy(*dout, ctx.GetPlace(), &dout_);
ctx.template device_context<paddle::platform::NPUDeviceContext>()
.Wait();
dout_.ShareDataWith(*dout);
std::vector<int> vec_dim = framework::vectorize<int>(dout_.dims());
std::vector<int> vec_dim_v{vec_dim[0] * vec_dim[1], vec_dim[2]};
dout_.Resize(framework::make_ddim(vec_dim_v));
framework::Tensor x_;
TensorCopy(*x, ctx.GetPlace(), &x_);
ctx.template device_context<paddle::platform::NPUDeviceContext>()
.Wait();
x_.ShareDataWith(*x);
std::vector<int> vec_dim_x = framework::vectorize<int>(x_.dims());
std::vector<int> vec_dim_x_v{vec_dim_x[0] * vec_dim_x[1],
vec_dim_x[2]};
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册