未验证 提交 2d9d8f57 编写于 作者: B Baibaifan 提交者: GitHub

solove_matmulv2_npu_bugs (#32896)

上级 e48091db
......@@ -135,8 +135,21 @@ class MatMulV2GradNPUKernel : public framework::OpKernel<T> {
}
if (dy) {
dy->mutable_data<T>(ctx.GetPlace());
auto runner_dy = NpuOpRunner("BatchMatMul", {*x, *dout}, {*dy},
{{"adj_x1", true}, {"adj_x2", false}});
framework::Tensor dout_;
TensorCopySync(*dout, ctx.GetPlace(), &dout_);
std::vector<int> vec_dim = framework::vectorize<int>(dout_.dims());
std::vector<int> vec_dim_v{vec_dim[0] * vec_dim[1], vec_dim[2]};
dout_.Resize(framework::make_ddim(vec_dim_v));
framework::Tensor x_;
TensorCopySync(*x, ctx.GetPlace(), &x_);
std::vector<int> vec_dim_x = framework::vectorize<int>(x_.dims());
std::vector<int> vec_dim_x_v{vec_dim_x[0] * vec_dim_x[1],
vec_dim_x[2]};
x_.Resize(framework::make_ddim(vec_dim_x_v));
auto runner_dy =
NpuOpRunner("MatMul", {x_, dout_}, {*dy},
{{"transpose_x1", true}, {"transpose_x2", false}});
runner_dy.Run(stream);
}
}
......
......@@ -927,6 +927,7 @@ def _linear(x, weight, bias=None, name=None):
else:
helper = LayerHelper('linear', **locals())
dtype = x.dtype
assert x.ndim < 4, "X latitude is not supported greater than 3 now."
check_variable_and_dtype(x, 'x', ['float16', 'float32', 'float64'],
'linear')
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册