diff --git a/paddle/fluid/operators/matmul_v2_op.h b/paddle/fluid/operators/matmul_v2_op.h
index fb6c6b98695fc6234088fbb98553fdbcdcd0b6f9..8a83a29d4847d18de94c5207cc336a43bb6cb9e2 100644
--- a/paddle/fluid/operators/matmul_v2_op.h
+++ b/paddle/fluid/operators/matmul_v2_op.h
@@ -44,7 +44,6 @@ template <typename DeviceContext, typename T>
 void ReduceSumForMatmulGrad(const Tensor* input, Tensor* output,
                             const std::vector<int>& reduce_dims,
                             const paddle::framework::ExecutionContext& ctx) {
-  if (reduce_dims.empty()) return;
 #ifdef __NVCC__
   auto stream = ctx.cuda_device_context().stream();
   TensorReduce<T, T, cub::Sum, IdentityFunctor<T>>(
@@ -602,47 +601,48 @@ class MatMulV2GradKernel : public framework::OpKernel<T> {
       // So we should avoid the case in reality.
       VLOG(3) << "It need cost much time to reduce sum for the broadcast and "
                  "wastes the memory. So we should avoid the case in reality";
+      Tensor dx_help, dy_help;
       if (transpose_x) {
         if (transpose_y) {
           // X'Y': dA = Y'G', dB = G'X'
           if (dx)
-            MatMulFunction<DeviceContext, T>(&y, &dout, y_dims, dout_dims, dx,
-                                             true, true, ctx);
+            MatMulFunction<DeviceContext, T>(&y, &dout, y_dims, dout_dims,
+                                             &dx_help, true, true, ctx);
           if (dy)
-            MatMulFunction<DeviceContext, T>(&dout, &x, dout_dims, x_dims, dy,
-                                             true, true, ctx);
+            MatMulFunction<DeviceContext, T>(&dout, &x, dout_dims, x_dims,
+                                             &dy_help, true, true, ctx);
         } else {
           // X'Y: dX = YG', dY = XG
           if (dx)
-            MatMulFunction<DeviceContext, T>(&y, &dout, y_dims, dout_dims, dx,
-                                             false, true, ctx);
+            MatMulFunction<DeviceContext, T>(&y, &dout, y_dims, dout_dims,
+                                             &dx_help, false, true, ctx);
           if (dy)
-            MatMulFunction<DeviceContext, T>(&x, &dout, x_dims, dout_dims, dy,
-                                             false, false, ctx);
+            MatMulFunction<DeviceContext, T>(&x, &dout, x_dims, dout_dims,
+                                             &dy_help, false, false, ctx);
         }
       } else {
         if (transpose_y) {
           // XY': dX = GY, dY = G'X
           if (dx)
-            MatMulFunction<DeviceContext, T>(&dout, &y, dout_dims, y_dims, dx,
-                                             false, false, ctx);
+            MatMulFunction<DeviceContext, T>(&dout, &y, dout_dims, y_dims,
+                                             &dx_help, false, false, ctx);
           if (dy)
-            MatMulFunction<DeviceContext, T>(&dout, &x, dout_dims, x_dims, dy,
-                                             true, false, ctx);
+            MatMulFunction<DeviceContext, T>(&dout, &x, dout_dims, x_dims,
+                                             &dy_help, true, false, ctx);
         } else {
           // XY: dX = GY', dY = X'G
           if (dx)
-            MatMulFunction<DeviceContext, T>(&dout, &y, dout_dims, y_dims, dx,
-                                             false, true, ctx);
+            MatMulFunction<DeviceContext, T>(&dout, &y, dout_dims, y_dims,
+                                             &dx_help, false, true, ctx);
           if (dy)
-            MatMulFunction<DeviceContext, T>(&x, &dout, x_dims, dout_dims, dy,
-                                             true, false, ctx);
+            MatMulFunction<DeviceContext, T>(&x, &dout, x_dims, dout_dims,
+                                             &dy_help, true, false, ctx);
         }
       }
 
       // get help dims
-      const std::vector<std::int64_t> dx_help_dims = vectorize(dx->dims());
-      const std::vector<std::int64_t> dy_help_dims = vectorize(dy->dims());
+      const std::vector<std::int64_t> dx_help_dims = vectorize(dx_help.dims());
+      const std::vector<std::int64_t> dy_help_dims = vectorize(dy_help.dims());
 
       std::vector<std::int64_t> dx_broadcast_dims(ndim);
       std::vector<std::int64_t> dy_broadcast_dims(ndim);
@@ -668,11 +668,21 @@ class MatMulV2GradKernel : public framework::OpKernel<T> {
       }
       // reduce sum to get grad by ReduceSum
       if (dx) {
-        ReduceSumForMatmulGrad<DeviceContext, T>(dx, dx, dx_reduce_dims, ctx);
+        if (dx_reduce_dims.empty()) {
+          *dx = std::move(dx_help);
+        } else {
+          ReduceSumForMatmulGrad<DeviceContext, T>(&dx_help, dx,
+                                                   dx_reduce_dims, ctx);
+        }
         dx->Resize(x.dims());
       }
       if (dy) {
-        ReduceSumForMatmulGrad<DeviceContext, T>(dy, dy, dy_reduce_dims, ctx);
+        if (dy_reduce_dims.empty()) {
+          *dy = std::move(dy_help);
+        } else {
+          ReduceSumForMatmulGrad<DeviceContext, T>(&dy_help, dy,
+                                                   dy_reduce_dims, ctx);
+        }
         dy->Resize(y.dims());
       }
     }
diff --git a/python/paddle/fluid/tests/unittests/test_matmul_v2_op.py b/python/paddle/fluid/tests/unittests/test_matmul_v2_op.py
index 1695058f7b3a2a559c3aa2f6d8871fcaa9fbed02..76172632c717118e1561cd1103f95c71067d2451 100644
--- a/python/paddle/fluid/tests/unittests/test_matmul_v2_op.py
+++ b/python/paddle/fluid/tests/unittests/test_matmul_v2_op.py
@@ -286,6 +286,30 @@ class TestMatMuklOp17(TestMatMulV2Op):
         self.trans_y = False
 
 
+class TestMatMuklOpBroadcast1(TestMatMulV2Op):
+    """
+    case 14_3
+    """
+
+    def config(self):
+        self.x_shape = (3, 1, 10, 10)
+        self.y_shape = (1, 2, 10, 10)
+        self.trans_x = True
+        self.trans_y = True
+
+
+class TestMatMuklOpBroadcast2(TestMatMulV2Op):
+    """
+    case 14_4
+    """
+
+    def config(self):
+        self.x_shape = (3, 1, 10, 10)
+        self.y_shape = (1, 2, 10, 10)
+        self.trans_x = False
+        self.trans_y = True
+
+
 #--------------------test matmul fp16--------------------
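
Note (reviewer aid, not part of the patch): below is a minimal dygraph repro sketch of the doubly-broadcast case the new tests exercise, assuming a Paddle 2.x-style paddle.matmul / paddle.to_tensor API. Both batch dimensions broadcast ((3, 1) x (1, 2) -> (3, 2)), so dX and dY must each be reduce-summed from the (3, 2, ...) intermediate back to the input shapes; before this patch that path read and wrote the same grad tensor in place.

    import numpy as np
    import paddle

    # Both batch dims broadcast: (3, 1) x (1, 2) -> (3, 2).
    x = paddle.to_tensor(np.random.rand(3, 1, 10, 10), stop_gradient=False)
    y = paddle.to_tensor(np.random.rand(1, 2, 10, 10), stop_gradient=False)

    # Same configuration as TestMatMuklOpBroadcast2 above.
    out = paddle.matmul(x, y, transpose_x=False, transpose_y=True)
    out.sum().backward()

    # dX and dY are reduce-summed back over their broadcast batch axes.
    print(x.grad.shape)  # expect (3, 1, 10, 10)
    print(y.grad.shape)  # expect (1, 2, 10, 10)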