未验证 提交 b436e5fa 编写于 作者: L Leo Chen 提交者: GitHub

[NPU] Refine variable name (#34330)

* fix variable name

* fix variable name
上级 08c5b1d1
...@@ -140,20 +140,22 @@ class MatMulGradNPUKernel : public framework::OpKernel<T> { ...@@ -140,20 +140,22 @@ class MatMulGradNPUKernel : public framework::OpKernel<T> {
dy->mutable_data<T>(ctx.GetPlace()); dy->mutable_data<T>(ctx.GetPlace());
if ((x->dims().size() == 3) && (dout->dims().size() == 3) && if ((x->dims().size() == 3) && (dout->dims().size() == 3) &&
(dy->dims().size() == 2)) { (dy->dims().size() == 2)) {
framework::Tensor dout_; framework::Tensor dout_tmp;
dout_.ShareDataWith(*dout); dout_tmp.ShareDataWith(*dout);
std::vector<int> vec_dim = framework::vectorize<int>(dout_.dims()); std::vector<int> vec_dim =
framework::vectorize<int>(dout_tmp.dims());
std::vector<int> vec_dim_v{vec_dim[0] * vec_dim[1], vec_dim[2]}; std::vector<int> vec_dim_v{vec_dim[0] * vec_dim[1], vec_dim[2]};
dout_.Resize(framework::make_ddim(vec_dim_v)); dout_tmp.Resize(framework::make_ddim(vec_dim_v));
framework::Tensor x_; framework::Tensor x_tmp;
x_.ShareDataWith(*x); x_tmp.ShareDataWith(*x);
std::vector<int> vec_dim_x = framework::vectorize<int>(x_.dims()); std::vector<int> vec_dim_x =
framework::vectorize<int>(x_tmp.dims());
std::vector<int> vec_dim_x_v{vec_dim_x[0] * vec_dim_x[1], std::vector<int> vec_dim_x_v{vec_dim_x[0] * vec_dim_x[1],
vec_dim_x[2]}; vec_dim_x[2]};
x_.Resize(framework::make_ddim(vec_dim_x_v)); x_tmp.Resize(framework::make_ddim(vec_dim_x_v));
const auto& runner_dy = const auto& runner_dy =
NpuOpRunner("MatMul", {x_, dout_}, {*dy}, NpuOpRunner("MatMul", {x_tmp, dout_tmp}, {*dy},
{{"transpose_x1", true}, {"transpose_x2", false}}); {{"transpose_x1", true}, {"transpose_x2", false}});
runner_dy.Run(stream); runner_dy.Run(stream);
} else { } else {
......
...@@ -140,20 +140,22 @@ class MatMulV2GradNPUKernel : public framework::OpKernel<T> { ...@@ -140,20 +140,22 @@ class MatMulV2GradNPUKernel : public framework::OpKernel<T> {
dy->mutable_data<T>(ctx.GetPlace()); dy->mutable_data<T>(ctx.GetPlace());
if ((x->dims().size() == 3) && (dout->dims().size() == 3) && if ((x->dims().size() == 3) && (dout->dims().size() == 3) &&
(dy->dims().size() == 2)) { (dy->dims().size() == 2)) {
framework::Tensor dout_; framework::Tensor dout_tmp;
dout_.ShareDataWith(*dout); dout_tmp.ShareDataWith(*dout);
std::vector<int> vec_dim = framework::vectorize<int>(dout_.dims()); std::vector<int> vec_dim =
framework::vectorize<int>(dout_tmp.dims());
std::vector<int> vec_dim_v{vec_dim[0] * vec_dim[1], vec_dim[2]}; std::vector<int> vec_dim_v{vec_dim[0] * vec_dim[1], vec_dim[2]};
dout_.Resize(framework::make_ddim(vec_dim_v)); dout_tmp.Resize(framework::make_ddim(vec_dim_v));
framework::Tensor x_; framework::Tensor x_tmp;
x_.ShareDataWith(*x); x_tmp.ShareDataWith(*x);
std::vector<int> vec_dim_x = framework::vectorize<int>(x_.dims()); std::vector<int> vec_dim_x =
framework::vectorize<int>(x_tmp.dims());
std::vector<int> vec_dim_x_v{vec_dim_x[0] * vec_dim_x[1], std::vector<int> vec_dim_x_v{vec_dim_x[0] * vec_dim_x[1],
vec_dim_x[2]}; vec_dim_x[2]};
x_.Resize(framework::make_ddim(vec_dim_x_v)); x_tmp.Resize(framework::make_ddim(vec_dim_x_v));
const auto& runner_dy = const auto& runner_dy =
NpuOpRunner("MatMul", {x_, dout_}, {*dy}, NpuOpRunner("MatMul", {x_tmp, dout_tmp}, {*dy},
{{"transpose_x1", true}, {"transpose_x2", false}}); {{"transpose_x1", true}, {"transpose_x2", false}});
runner_dy.Run(stream); runner_dy.Run(stream);
} else { } else {
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册