未验证 提交 98d25314 编写于 作者: P pangyoki 提交者: GitHub

change TensorCopy to ShareDataWith in matmul_grad op (#33755)

上级 3946afc4
...@@ -141,17 +141,13 @@ class MatMulV2GradNPUKernel : public framework::OpKernel<T> { ...@@ -141,17 +141,13 @@ class MatMulV2GradNPUKernel : public framework::OpKernel<T> {
if ((x->dims().size() == 3) && (dout->dims().size() == 3) && if ((x->dims().size() == 3) && (dout->dims().size() == 3) &&
(dy->dims().size() == 2)) { (dy->dims().size() == 2)) {
framework::Tensor dout_; framework::Tensor dout_;
TensorCopy(*dout, ctx.GetPlace(), &dout_); dout_.ShareDataWith(*dout);
ctx.template device_context<paddle::platform::NPUDeviceContext>()
.Wait();
std::vector<int> vec_dim = framework::vectorize<int>(dout_.dims()); std::vector<int> vec_dim = framework::vectorize<int>(dout_.dims());
std::vector<int> vec_dim_v{vec_dim[0] * vec_dim[1], vec_dim[2]}; std::vector<int> vec_dim_v{vec_dim[0] * vec_dim[1], vec_dim[2]};
dout_.Resize(framework::make_ddim(vec_dim_v)); dout_.Resize(framework::make_ddim(vec_dim_v));
framework::Tensor x_; framework::Tensor x_;
TensorCopy(*x, ctx.GetPlace(), &x_); x_.ShareDataWith(*x);
ctx.template device_context<paddle::platform::NPUDeviceContext>()
.Wait();
std::vector<int> vec_dim_x = framework::vectorize<int>(x_.dims()); std::vector<int> vec_dim_x = framework::vectorize<int>(x_.dims());
std::vector<int> vec_dim_x_v{vec_dim_x[0] * vec_dim_x[1], std::vector<int> vec_dim_x_v{vec_dim_x[0] * vec_dim_x[1],
vec_dim_x[2]}; vec_dim_x[2]};
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册