diff --git a/paddle/phi/kernels/onednn/matmul_grad_kernel.cc b/paddle/phi/kernels/onednn/matmul_grad_kernel.cc
index fec008e7a106e0ccf4c6f17f9c75335e42350218..f9b45d4bc441df067cd0dad4ff89df90a3c484ab 100644
--- a/paddle/phi/kernels/onednn/matmul_grad_kernel.cc
+++ b/paddle/phi/kernels/onednn/matmul_grad_kernel.cc
@@ -19,37 +19,64 @@
 namespace phi {
 
-std::vector<int64_t> ExtendDimsWithOnes(const std::vector<int64_t> &dims,
-                                        int new_size) {
-  std::vector<int64_t> new_dims(new_size, 1);
-  for (size_t i = 0; i < dims.size(); ++i) {
-    new_dims[new_size - dims.size() + i] = dims[i];
+void CalculateMatrixDims(const std::vector<int64_t> &x_dims,
+                         const std::vector<int64_t> &y_dims,
+                         const std::vector<int64_t> &out_dims,
+                         std::vector<int64_t> *x_bd_dims,
+                         std::vector<int64_t> *y_bd_dims,
+                         std::vector<int64_t> *out_bd_dims,
+                         bool trans_x,
+                         bool trans_y) {
+  if (x_dims.size() == 1) {
+    (*x_bd_dims)[x_bd_dims->size() - 1] = x_dims[0];
+  } else if (x_dims.size() == 2) {
+    (*x_bd_dims)[x_bd_dims->size() - 1] = x_dims[1];
+    (*x_bd_dims)[x_bd_dims->size() - 2] = x_dims[0];
+  } else {
+    for (size_t i = 0; i < x_dims.size(); ++i) {
+      (*x_bd_dims)[x_bd_dims->size() - x_dims.size() + i] = x_dims[i];
+    }
+  }
+  if (y_dims.size() == 1) {
+    (*y_bd_dims)[x_bd_dims->size() - 2] = y_dims[0];
+  } else if (y_dims.size() == 2) {
+    (*y_bd_dims)[y_bd_dims->size() - 1] = y_dims[1];
+    (*y_bd_dims)[y_bd_dims->size() - 2] = y_dims[0];
+  } else {
+    for (size_t i = 0; i < y_dims.size(); ++i) {
+      (*y_bd_dims)[y_bd_dims->size() - y_dims.size() + i] = y_dims[i];
+    }
+  }
+
+  for (size_t i = 0; i < x_bd_dims->size() - 2; ++i) {
+    (*out_bd_dims)[i] = std::max((*x_bd_dims)[i], (*y_bd_dims)[i]);
   }
+  int h_idx = trans_x ? x_bd_dims->size() - 1 : x_bd_dims->size() - 2;
+  int w_idx = trans_y ? y_bd_dims->size() - 2 : y_bd_dims->size() - 1;
 
-  return new_dims;
+  (*out_bd_dims)[x_bd_dims->size() - 2] = (*x_bd_dims)[h_idx];
+  (*out_bd_dims)[y_bd_dims->size() - 1] = (*y_bd_dims)[w_idx];
 }
 
 template <typename T>
 void CalculateGradMatrixDims(const OneDNNContext &dev_ctx,
                              DenseTensor *dx_tmp,
                              DenseTensor *dy_tmp,
-                             const std::vector<int64_t> &dx_dims,
-                             const std::vector<int64_t> &dy_dims,
                              std::vector<int64_t> *dx_bd_dims,
                              std::vector<int64_t> *dy_bd_dims) {
-  for (size_t i = 0; i < dx_dims.size() - 2; ++i) {
-    if (dx_dims[i] != dy_dims[i]) {
-      if (dx_dims[i] == 1) {
-        (*dx_bd_dims)[i] = dy_dims[i];
+  for (size_t i = 0; i < dx_bd_dims->size() - 2; ++i) {
+    if ((*dx_bd_dims)[i] != (*dy_bd_dims)[i]) {
+      if ((*dx_bd_dims)[i] == 1) {
+        (*dx_bd_dims)[i] = (*dy_bd_dims)[i];
       } else {
-        (*dy_bd_dims)[i] = dx_dims[i];
+        (*dy_bd_dims)[i] = (*dx_bd_dims)[i];
       }
     }
   }
 
-  dx_tmp->Resize(make_ddim((*dx_bd_dims)));
+  dx_tmp->Resize(make_ddim(*dx_bd_dims));
   dev_ctx.template Alloc<T>(dx_tmp);
-  dy_tmp->Resize(make_ddim((*dy_bd_dims)));
+  dy_tmp->Resize(make_ddim(*dy_bd_dims));
   dev_ctx.template Alloc<T>(dy_tmp);
 }
 
@@ -58,7 +85,7 @@ void ReduceSumForMatmulGradOutput(const OneDNNContext &dev_ctx,
                                   const DenseTensor *dx_tmp,
                                   DenseTensor *dx,
                                   const std::vector<int64_t> &dx_dims,
-                                  const std::vector<int64_t> &squeezed_dims) {
+                                  const std::vector<int64_t> &x_dims) {
   funcs::ReductionOneDNNHandler<T> handler(dnnl::algorithm::reduction_sum,
                                            0.0f,
                                            0.0f,
@@ -66,7 +93,7 @@ void ReduceSumForMatmulGradOutput(const OneDNNContext &dev_ctx,
                                            dev_ctx.GetPlace(),
                                            dx_tmp,
                                            dx,
-                                           dx_dims);
+                                           x_dims);
 
   auto src_memory_p = handler.AcquireSrcMemory(dx_tmp);
   auto dst_memory_p = handler.AcquireDstMemory(dx);
@@ -79,8 +106,6 @@ void ReduceSumForMatmulGradOutput(const OneDNNContext &dev_ctx,
   reduction_p->execute(astream, reduction_args);
   astream.wait();
-
-  dx->set_mem_desc(dst_memory_p->get_desc().reshape(squeezed_dims));
 }
 
 template <typename T, typename Context>
@@ -99,64 +124,67 @@ void MatmulGradKernel(const Context &dev_ctx,
   size_t ndims = std::max(x_dims.size(), y_dims.size());
   ndims = std::max<size_t>(ndims, 3);
 
-  if (x_dims.size() != ndims) {
-    x_dims = ExtendDimsWithOnes(x_dims, ndims);
-  }
-  if (y_dims.size() != ndims) {
-    y_dims = ExtendDimsWithOnes(y_dims, ndims);
-  }
-  if (dout_dims.size() != ndims) {
-    dout_dims = ExtendDimsWithOnes(dout_dims, ndims);
-  }
-
   // in broadcasting scenario new memory is required because
   // reduce sum must be calculated upon broadcasted dims
   DenseTensor dx_tmp, dy_tmp;
-  std::vector<int64_t> dx_bd_dims(x_dims);
-  std::vector<int64_t> dy_bd_dims(y_dims);
+  std::vector<int64_t> dout_bd_dims(ndims, 1);
+  std::vector<int64_t> x_bd_dims(ndims, 1);
+  std::vector<int64_t> y_bd_dims(ndims, 1);
+
+  CalculateMatrixDims(x_dims,
+                      y_dims,
+                      dout_dims,
+                      &x_bd_dims,
+                      &y_bd_dims,
+                      &dout_bd_dims,
+                      transpose_x,
+                      transpose_y);
+
+  std::vector<int64_t> dx_bd_dims(x_bd_dims);
+  std::vector<int64_t> dy_bd_dims(y_bd_dims);
 
   CalculateGradMatrixDims<T>(
-      dev_ctx, &dx_tmp, &dy_tmp, x_dims, y_dims, &dx_bd_dims, &dy_bd_dims);
+      dev_ctx, &dx_tmp, &dy_tmp, &dx_bd_dims, &dy_bd_dims);
 
   if (transpose_x && transpose_y) {
     funcs::ExecuteMatmul<T, T>(
-        dev_ctx, y, dout, y_dims, dout_dims, true, true, &dx_tmp);
+        dev_ctx, y, dout, y_bd_dims, dout_bd_dims, true, true, &dx_tmp);
     funcs::ExecuteMatmul<T, T>(
-        dev_ctx, dout, x, dout_dims, x_dims, true, true, &dy_tmp);
+        dev_ctx, dout, x, dout_bd_dims, x_bd_dims, true, true, &dy_tmp);
   } else if (transpose_x) {
     funcs::ExecuteMatmul<T, T>(
-        dev_ctx, y, dout, y_dims, dout_dims, false, true, &dx_tmp);
+        dev_ctx, y, dout, y_bd_dims, dout_bd_dims, false, true, &dx_tmp);
     funcs::ExecuteMatmul<T, T>(
-        dev_ctx, x, dout, x_dims, dout_dims, false, false, &dy_tmp);
+        dev_ctx, x, dout, x_bd_dims, dout_bd_dims, false, false, &dy_tmp);
   } else if (transpose_y) {
     funcs::ExecuteMatmul<T, T>(
-        dev_ctx, dout, y, dout_dims, y_dims, false, false, &dx_tmp);
+        dev_ctx, dout, y, dout_bd_dims, y_bd_dims, false, false, &dx_tmp);
     funcs::ExecuteMatmul<T, T>(
-        dev_ctx, dout, x, dout_dims, x_dims, true, false, &dy_tmp);
+        dev_ctx, dout, x, dout_bd_dims, x_bd_dims, true, false, &dy_tmp);
   } else {
     funcs::ExecuteMatmul<T, T>(
-        dev_ctx, dout, y, dout_dims, y_dims, false, true, &dx_tmp);
+        dev_ctx, dout, y, dout_bd_dims, y_bd_dims, false, true, &dx_tmp);
     funcs::ExecuteMatmul<T, T>(
-        dev_ctx, x, dout, x_dims, dout_dims, true, false, &dy_tmp);
+        dev_ctx, x, dout, x_bd_dims, dout_bd_dims, true, false, &dy_tmp);
   }
 
-  if (x_dims != dx_bd_dims) {
+  if (x_bd_dims != dx_bd_dims) {
     ReduceSumForMatmulGradOutput<T>(
-        dev_ctx, &dx_tmp, dx, x_dims, vectorize(x.dims()));
+        dev_ctx, &dx_tmp, dx, dx_bd_dims, x_bd_dims);
   } else {
     *dx = std::move(dx_tmp);
   }
-  if (y_dims != dy_bd_dims) {
+  if (y_bd_dims != dy_bd_dims) {
     ReduceSumForMatmulGradOutput<T>(
-        dev_ctx, &dy_tmp, dy, y_dims, vectorize(y.dims()));
+        dev_ctx, &dy_tmp, dy, dy_bd_dims, y_bd_dims);
   } else {
     *dy = std::move(dy_tmp);
   }
 
+  dx->set_mem_desc(x.mem_desc());
   dx->Resize(x.dims());
-  dx->set_mem_desc(x.mem_desc().reshape(vectorize(x.dims())));
+  dy->set_mem_desc(y.mem_desc());
   dy->Resize(y.dims());
-  dy->set_mem_desc(y.mem_desc().reshape(vectorize(y.dims())));
 }
 
 template <typename T, typename Context>
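Note: as a sanity check on the shape logic above, here is a minimal, self-contained C++ sketch (not Paddle code) of what CalculateMatrixDims computes for inputs of rank >= 2: each shape is right-aligned into ndims slots padded with 1s, the batch dims broadcast via max(), and the two matrix dims come from the (possibly transposed) inputs. The helper names PadWithOnes and BroadcastMatmulDims are hypothetical, for illustration only; the kernel additionally special-cases rank-1 inputs, which this sketch omits.

#include <algorithm>
#include <cstdint>
#include <iostream>
#include <vector>

// Right-align `dims` into `ndims` slots, padding the leading slots with 1s.
// Assumes dims.size() <= ndims.
std::vector<int64_t> PadWithOnes(const std::vector<int64_t> &dims,
                                 size_t ndims) {
  std::vector<int64_t> padded(ndims, 1);
  std::copy(dims.begin(), dims.end(), padded.begin() + (ndims - dims.size()));
  return padded;
}

// Broadcast output dims of matmul(x, y) with optional transposes.
// Assumes both inputs have rank >= 2.
std::vector<int64_t> BroadcastMatmulDims(std::vector<int64_t> x,
                                         std::vector<int64_t> y,
                                         bool trans_x,
                                         bool trans_y) {
  size_t ndims = std::max({x.size(), y.size(), size_t{3}});
  x = PadWithOnes(x, ndims);
  y = PadWithOnes(y, ndims);

  std::vector<int64_t> out(ndims, 1);
  // Batch dims broadcast element-wise (a 1 broadcasts against anything).
  for (size_t i = 0; i < ndims - 2; ++i) {
    out[i] = std::max(x[i], y[i]);
  }
  // Height comes from x, width from y, honoring the transpose flags.
  out[ndims - 2] = trans_x ? x[ndims - 1] : x[ndims - 2];
  out[ndims - 1] = trans_y ? y[ndims - 2] : y[ndims - 1];
  return out;
}

int main() {
  // x: [2, 3, 4], y: [4, 5] -> y pads to [1, 4, 5], the batch dim
  // broadcasts to 2, and the result is [2, 3, 5].
  for (int64_t d : BroadcastMatmulDims({2, 3, 4}, {4, 5}, false, false)) {
    std::cout << d << ' ';
  }
  std::cout << '\n';  // prints "2 3 5"
  return 0;
}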
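For reference, the four ExecuteMatmul branches follow the standard matmul gradient identities. With out = op(x) * op(y) and dout the upstream gradient:

- neither transposed:          dx = dout * y^T,    dy = x^T * dout
- transpose_x only:            dx = y * dout^T,    dy = x * dout
- transpose_y only:            dx = dout * y,      dy = dout^T * x
- transpose_x && transpose_y:  dx = y^T * dout^T,  dy = dout^T * x^T

These are exactly the operand orders and (trans_x, trans_y) flag pairs passed in each branch; the ReduceSumForMatmulGradOutput pass afterwards collapses any batch dims that were broadcast, so dx and dy end up matching the original shapes of x and y.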