未验证 提交 be8e94aa 编写于 作者: B Baibaifan 提交者: GitHub

revert_matmulv2_npu (#33014)

上级 e409c7ce
......
@@ -135,21 +135,8 @@ class MatMulV2GradNPUKernel : public framework::OpKernel<T> {
}
if (dy) {
dy->mutable_data<T>(ctx.GetPlace());
framework::Tensor dout_;
TensorCopySync(*dout, ctx.GetPlace(), &dout_);
std::vector<int> vec_dim = framework::vectorize<int>(dout_.dims());
std::vector<int> vec_dim_v{vec_dim[0] * vec_dim[1], vec_dim[2]};
dout_.Resize(framework::make_ddim(vec_dim_v));
framework::Tensor x_;
TensorCopySync(*x, ctx.GetPlace(), &x_);
std::vector<int> vec_dim_x = framework::vectorize<int>(x_.dims());
std::vector<int> vec_dim_x_v{vec_dim_x[0] * vec_dim_x[1],
vec_dim_x[2]};
x_.Resize(framework::make_ddim(vec_dim_x_v));
auto runner_dy =
NpuOpRunner("MatMul", {x_, dout_}, {*dy},
{{"transpose_x1", true}, {"transpose_x2", false}});
auto runner_dy = NpuOpRunner("BatchMatMul", {*x, *dout}, {*dy},
{{"adj_x1", true}, {"adj_x2", false}});
runner_dy.Run(stream);
}
}
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册