Unverified commit 2d9d8f57, authored by Baibaifan, committed by GitHub

solove_matmulv2_npu_bugs (#32896)

Parent: e48091db
@@ -135,8 +135,21 @@ class MatMulV2GradNPUKernel : public framework::OpKernel<T> {
       }
       if (dy) {
         dy->mutable_data<T>(ctx.GetPlace());
-        auto runner_dy = NpuOpRunner("BatchMatMul", {*x, *dout}, {*dy},
-                                     {{"adj_x1", true}, {"adj_x2", false}});
+        framework::Tensor dout_;
+        TensorCopySync(*dout, ctx.GetPlace(), &dout_);
+        std::vector<int> vec_dim = framework::vectorize<int>(dout_.dims());
+        std::vector<int> vec_dim_v{vec_dim[0] * vec_dim[1], vec_dim[2]};
+        dout_.Resize(framework::make_ddim(vec_dim_v));
+
+        framework::Tensor x_;
+        TensorCopySync(*x, ctx.GetPlace(), &x_);
+        std::vector<int> vec_dim_x = framework::vectorize<int>(x_.dims());
+        std::vector<int> vec_dim_x_v{vec_dim_x[0] * vec_dim_x[1],
+                                     vec_dim_x[2]};
+        x_.Resize(framework::make_ddim(vec_dim_x_v));
+
+        auto runner_dy =
+            NpuOpRunner("MatMul", {x_, dout_}, {*dy},
+                        {{"transpose_x1", true}, {"transpose_x2", false}});
         runner_dy.Run(stream);
       }
     }
...
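Why the kernel switches from BatchMatMul to a single 2-D MatMul: dy is the gradient of a weight that is shared across the batch, so its batch dimension has to be summed out, whereas BatchMatMul with adj_x1=true produces a per-batch result whose shape does not match dy. Flattening x and dout from [B, M, *] to [B*M, *] and running one MatMul with the first operand transposed gives the summed gradient in a single call. Below is a minimal NumPy sketch of that equivalence; the shapes B, M, K, N are assumptions for illustration, not taken from the commit.

```python
import numpy as np

B, M, K, N = 2, 3, 4, 5                              # assumed shapes for illustration
x = np.random.rand(B, M, K).astype(np.float32)       # 3-D input of the linear op
dout = np.random.rand(B, M, N).astype(np.float32)    # upstream gradient of out = x @ w

# Reference weight gradient: per-batch x_b^T @ dout_b, summed over the batch.
dy_ref = sum(x[b].T @ dout[b] for b in range(B))     # shape [K, N]

# What the patched kernel does: collapse [B, M, *] to [B*M, *] and run a
# single 2-D matmul with the first operand transposed (transpose_x1 = true).
x_2d = x.reshape(B * M, K)
dout_2d = dout.reshape(B * M, N)
dy = x_2d.T @ dout_2d                                # shape [K, N]

assert np.allclose(dy, dy_ref, atol=1e-5)
```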
@@ -927,6 +927,7 @@ def _linear(x, weight, bias=None, name=None):
     else:
         helper = LayerHelper('linear', **locals())
         dtype = x.dtype
+        assert x.ndim < 4, "X latitude is not supported greater than 3 now."
         check_variable_and_dtype(x, 'x', ['float16', 'float32', 'float64'],
                                  'linear')
...
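The second hunk only adds a rank guard to _linear before the matmul is issued. A hypothetical NumPy stand-in (linear_ref is not Paddle's API; it simply mirrors the new assertion) showing its effect:

```python
import numpy as np

def linear_ref(x, weight, bias=None):
    # Mirror of the new guard: inputs with more than 3 dimensions are
    # rejected up front (assertion message copied from the commit).
    assert x.ndim < 4, "X latitude is not supported greater than 3 now."
    out = x @ weight                      # contract over the last axis of x
    return out if bias is None else out + bias

w = np.random.rand(4, 5)
print(linear_ref(np.random.rand(2, 3, 4), w).shape)   # (2, 3, 5): rank 3 is allowed

try:
    linear_ref(np.random.rand(2, 2, 3, 4), w)          # rank 4 now fails fast
except AssertionError as err:
    print("rejected:", err)
```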