diff --git a/paddle/fluid/operators/matmul_op.cc b/paddle/fluid/operators/matmul_op.cc
index 809164df2056cb4f4856a0b70ea5076351603199..129298edafcf9a42d5f2058786e946faffa6618b 100644
--- a/paddle/fluid/operators/matmul_op.cc
+++ b/paddle/fluid/operators/matmul_op.cc
@@ -348,6 +348,181 @@ framework::DDim GetDimForInput(const framework::InferShapeContext &ctx,
   return dim;
 }
 
+template <typename DeviceContext, typename T>
+class MatMulDoubleGradKernel : public framework::OpKernel<T> {
+ public:
+  void MatMul(const framework::ExecutionContext &context,
+              const framework::Tensor &a, bool trans_a,
+              const framework::Tensor &b, bool trans_b, bool flag,
+              framework::Tensor *out) const {
+    out->mutable_data<T>(context.GetPlace());
+    auto blas = math::GetBlas<DeviceContext, T>(context);
+    auto mat_dim_a = math::CreateMatrixDescriptor(a.dims(), 0, trans_a);
+    auto mat_dim_b = math::CreateMatrixDescriptor(b.dims(), 0, trans_b);
+
+    int head_number = 1;
+#if defined(PADDLE_WITH_MKLML) && !defined(PADDLE_WITH_CUDA)
+    head_number = context.Attr<int>("head_number");
+#endif
+
+    if (head_number <= 1 && a.dims().size() == 3 && b.dims().size() <= 2) {
+      // the transpose_X must be false, if is true, the transpose cost much time
+      if (!trans_a) {
+        mat_dim_a.height_ *= mat_dim_a.batch_size_;
+        mat_dim_a.batch_size_ = 0;
+      }
+    }
+    blas.MatMul(a, mat_dim_a, b, mat_dim_b,
+                static_cast<T>(context.Attr<float>("alpha")), out,
+                static_cast<T>(flag));
+  }
+
+  void CalcInputGrad(const framework::ExecutionContext &context,
+                     const framework::Tensor &a, bool trans_a,
+                     bool is_fold_init_dims_a, const framework::Tensor &b,
+                     bool trans_b, bool is_fold_init_dims_b, bool flag,
+                     framework::Tensor *out) const {
+    if (out == nullptr) return;
+    bool need_combine = (a.dims().size() == 3 || b.dims().size() == 3) &&
+                        out->dims().size() == 2;
+    if (!need_combine) {
+      MatMul(context, a, trans_a, b, trans_b, flag, out);
+    } else {
+      auto &ctx = context.template device_context<DeviceContext>();
+      MatMul(context, is_fold_init_dims_a
+                          ? FoldInitDims(a)
+                          : FoldHeadAndLastDims<DeviceContext, T>(ctx, a),
+             trans_a, is_fold_init_dims_b
+                          ? FoldInitDims(b)
+                          : FoldHeadAndLastDims<DeviceContext, T>(ctx, b),
+             trans_b, flag, out);
+    }
+  }
+
+  void Compute(const framework::ExecutionContext &context) const override {
+    auto x = *context.Input<framework::Tensor>("X");
+    auto y = *context.Input<framework::Tensor>("Y");
+    auto dout = *context.Input<framework::Tensor>("DOut");
+    auto *ddx = context.Input<framework::Tensor>("DDX");
+    auto *ddy = context.Input<framework::Tensor>("DDY");
+
+    auto *dx = context.Output<framework::Tensor>("DX");
+    auto *dy = context.Output<framework::Tensor>("DY");
+    auto *ddout = context.Output<framework::Tensor>("DDOut");
+
+    bool transpose_x = context.Attr<bool>("transpose_X");
+    bool transpose_y = context.Attr<bool>("transpose_Y");
+
+    ReshapeXYOutIntoMatrixSequence(&x, &y, &dout, transpose_x, transpose_y);
+
+    framework::DDim dx_dims;
+    if (dx) {
+      dx_dims = dx->dims();
+      if (dx_dims != x.dims()) {
+        dx->Resize(x.dims());
+      }
+    }
+
+    framework::DDim dy_dims;
+    if (dy) {
+      dy_dims = dy->dims();
+      if (dy_dims != y.dims()) {
+        dy->Resize(y.dims());
+      }
+    }
+
+    framework::DDim ddout_dims;
+    if (ddout) {
+      ddout_dims = ddout->dims();
+      if (ddout_dims != dout.dims()) {
+        ddout->Resize(dout.dims());
+      }
+    }
+
+    bool ddout_flag = false;
+    if (ddx) {
+      auto ddx_mat = *ddx;
+      if (ddx_mat.dims() != x.dims()) {
+        ddx_mat.Resize(x.dims());
+      }
+      if (dy) {
+        if (transpose_x && transpose_y) {
+          // dy = dout' * ddx'
+          CalcInputGrad(context, dout, true, true, ddx_mat, true, false, false,
+                        dy);
+        } else if (transpose_x) {
+          // dy = ddx * dout
+          CalcInputGrad(context, ddx_mat, false, false, dout, false, true,
+                        false, dy);
+        } else if (transpose_y) {
+          // dy = dout' * ddx
+          CalcInputGrad(context, dout, true, true, ddx_mat, false, true, false,
+                        dy);
+        } else {
+          // dy = ddx' * dout
+          CalcInputGrad(context, ddx_mat, true, true, dout, false, true, false,
+                        dy);
+        }
+      }
+
+      if (ddout) {
+        CalcInputGrad(context, ddx_mat, transpose_x, true, y, transpose_y,
+                      false, ddout_flag, ddout);
+        ddout_flag = true;
+      }
+    }
+
+    if (ddy) {
+      auto ddy_mat = *ddy;
+      if (ddy_mat.dims() != y.dims()) {
+        ddy_mat.Resize(y.dims());
+      }
+      if (dx) {
+        if (transpose_x && transpose_y) {
+          // dx = ddy' * dout'
+          CalcInputGrad(context, ddy_mat, true, true, dout, true, false, false,
+                        dx);
+        } else if (transpose_x) {
+          // dx = ddy * dout'
+          CalcInputGrad(context, ddy_mat, false, false, dout, true, false,
+                        false, dx);
+        } else if (transpose_y) {
+          // dx = dout * ddy
+          CalcInputGrad(context, dout, false, false, ddy_mat, false, true,
+                        false, dx);
+        } else {
+          // dx = dout * ddy'
+          CalcInputGrad(context, dout, false, false, ddy_mat, true, false,
+                        false, dx);
+        }
+      }
+
+      if (ddout) {
+        CalcInputGrad(context, x, transpose_x, true, ddy_mat, transpose_y,
+                      false, ddout_flag, ddout);
+      }
+    }
+
+    if (dx) {
+      if (dx_dims != x.dims()) {
+        dx->Resize(dx_dims);
+      }
+    }
+
+    if (dy) {
+      if (dy_dims != y.dims()) {
+        dy->Resize(dy_dims);
+      }
+    }
+
+    if (ddout) {
+      if (ddout_dims != dout.dims()) {
+        ddout->Resize(ddout_dims);
+      }
+    }
+  }
+};
+
 class MatMulOp : public framework::OperatorWithKernel {
  public:
   using framework::OperatorWithKernel::OperatorWithKernel;
@@ -647,6 +822,61 @@ class MatMulOpGradMaker : public framework::SingleGradOpMaker<T> {
     retv->SetAttrMap(this->Attrs());
   }
 };
+
+class MatMulOpDoubleGrad : public framework::OperatorWithKernel {
+ public:
+  using framework::OperatorWithKernel::OperatorWithKernel;
+
+ protected:
+  void InferShape(framework::InferShapeContext *context) const override {
+    OP_INOUT_CHECK(context->HasInput("X"), "Input", "X", "matmul");
+    OP_INOUT_CHECK(context->HasInput("Y"), "Input", "Y", "matmul");
+    OP_INOUT_CHECK(context->HasInput("DOut"), "Input", "DOut", "matmul");
+
+    if (context->HasOutput("DX") && context->HasInput("DDY")) {
+      context->ShareDim("X", "DX");
+    }
+
+    if (context->HasOutput("DY") && context->HasInput("DDX")) {
+      context->ShareDim("Y", "DY");
+    }
+
+    if (context->HasOutput("DDOut") &&
+        (context->HasInput("DDY") || context->HasInput("DDX"))) {
+      context->ShareDim("DOut", "DDOut");
+    }
+  }
+};
+
+template <typename T>
+class MatMulOpDoubleGradMaker : public framework::SingleGradOpMaker<T> {
+ public:
+  using framework::SingleGradOpMaker<T>::SingleGradOpMaker;
+
+ protected:
+  void Apply(GradOpPtr<T> retv) const override {
+    retv->SetType("matmul_grad_grad");
+    retv->SetInput("X", this->Input("X"));
+    retv->SetInput("Y", this->Input("Y"));
+    retv->SetInput("DOut", this->Input(framework::GradVarName("Out")));
+    retv->SetInput("DDX", this->OutputGrad(framework::GradVarName("X")));
+    retv->SetInput("DDY", this->OutputGrad(framework::GradVarName("Y")));
+
+    auto ddx = this->OutputGrad(framework::GradVarName("X"));
+    auto ddy = this->OutputGrad(framework::GradVarName("Y"));
+
+    if (!ddx.empty() || !ddy.empty()) {
+      retv->SetOutput("DDOut", this->InputGrad(framework::GradVarName("Out")));
+    }
+    retv->SetOutput(
+        "DX", ddy.empty() ? this->EmptyInputGrad() : this->InputGrad("X"));
+    retv->SetOutput(
+        "DY", ddx.empty() ? this->EmptyInputGrad() : this->InputGrad("Y"));
+
+    retv->SetAttrMap(this->Attrs());
+  }
+};
+
 }  // namespace operators
 }  // namespace paddle
 
@@ -654,7 +884,10 @@ namespace ops = paddle::operators;
 REGISTER_OPERATOR(matmul, ops::MatMulOp, ops::MatMulOpMaker,
                   ops::MatMulOpGradMaker<paddle::framework::OpDesc>,
                   ops::MatMulOpGradMaker<paddle::imperative::OpBase>);
-REGISTER_OPERATOR(matmul_grad, ops::MatMulOpGrad);
+REGISTER_OPERATOR(matmul_grad, ops::MatMulOpGrad,
+                  ops::MatMulOpDoubleGradMaker<paddle::framework::OpDesc>,
+                  ops::MatMulOpDoubleGradMaker<paddle::imperative::OpBase>);
+REGISTER_OPERATOR(matmul_grad_grad, ops::MatMulOpDoubleGrad);
 REGISTER_OP_CPU_KERNEL(
     matmul, ops::MatMulKernel<paddle::platform::CPUDeviceContext, float>,
     ops::MatMulKernel<paddle::platform::CPUDeviceContext, double>);
@@ -663,6 +896,11 @@ REGISTER_OP_CPU_KERNEL(
     ops::MatMulGradKernel<paddle::platform::CPUDeviceContext, float>,
     ops::MatMulGradKernel<paddle::platform::CPUDeviceContext, double>);
 
+REGISTER_OP_CPU_KERNEL(
+    matmul_grad_grad,
+    ops::MatMulDoubleGradKernel<paddle::platform::CPUDeviceContext, float>,
+    ops::MatMulDoubleGradKernel<paddle::platform::CPUDeviceContext, double>);
+
 #ifdef PADDLE_WITH_CUDA
 REGISTER_OP_CUDA_KERNEL(
     matmul, ops::MatMulKernel<paddle::platform::CUDADeviceContext, float>,
@@ -675,4 +913,8 @@ REGISTER_OP_CUDA_KERNEL(
     ops::MatMulGradKernel<paddle::platform::CUDADeviceContext, float>,
     ops::MatMulGradKernel<paddle::platform::CUDADeviceContext, double>);
 
+REGISTER_OP_CUDA_KERNEL(
+    matmul_grad_grad,
+    ops::MatMulDoubleGradKernel<paddle::platform::CUDADeviceContext, float>,
+    ops::MatMulDoubleGradKernel<paddle::platform::CUDADeviceContext, double>);
 #endif
diff --git a/python/paddle/fluid/tests/unittests/test_nn_grad.py b/python/paddle/fluid/tests/unittests/test_nn_grad.py
index 5d1e016287e07a8505336e6cb447c0e1b29a2ec2..bf1955c5711f52b9478a137d647aa83d304dd03b 100644
--- a/python/paddle/fluid/tests/unittests/test_nn_grad.py
+++ b/python/paddle/fluid/tests/unittests/test_nn_grad.py
@@ -153,6 +153,38 @@ class TestMulDoubleGradCheck(unittest.TestCase):
             self.func(p)
 
 
+class TestMatmulDoubleGradCheck(unittest.TestCase):
+    @prog_scope()
+    def func(self, place):
+        eps = 0.005
+        x_shapes = [[2], [2, 3], [2, 4, 3], [2, 3, 4, 5], [2, 3, 4]]
+        y_shapes = [[2], [3, 2], [2, 4, 5], [2, 3, 3, 5], [4, 3]]
+        transpose_xs = [False, True, True, False, False]
+        transpose_ys = [False, True, False, True, False]
+        dtypes = [np.float64, np.float64, np.float32, np.float32, np.float64]
+        typenames = ["float64", "float64", "float32", "float32", "float64"]
+        for i, (x_shape, y_shape, transpose_x, transpose_y, dtype, typename) \
+            in enumerate(zip(x_shapes, y_shapes, transpose_xs, transpose_ys, dtypes, typenames)):
+            x = layers.create_parameter(
+                dtype=typename, shape=x_shape, name='x{}'.format(i))
+            y = layers.create_parameter(
+                dtype=typename, shape=y_shape, name='y{}'.format(i))
+            out = layers.matmul(
+                x, y, transpose_x, transpose_y, name='out{}'.format(i))
+
+            x_arr = np.random.uniform(-1, 1, x_shape).astype(dtype)
+            y_arr = np.random.uniform(-1, 1, y_shape).astype(dtype)
+            gradient_checker.double_grad_check(
+                [x, y], out, x_init=[x_arr, y_arr], place=place, eps=eps)
+
+    def test_grad(self):
+        places = [fluid.CPUPlace()]
+        if core.is_compiled_with_cuda():
+            places.append(fluid.CUDAPlace(0))
+        for p in places:
+            self.func(p)
+
+
 class TestReshapeDoubleGradCheck(unittest.TestCase):
     @prog_scope()
     def func(self, place):
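
For reference, the second-order gradient relations that the kernel's inline comments encode can be stated outside Paddle. Below is a minimal NumPy sketch (not part of the patch; the array names are illustrative) of the default transpose_X = transpose_Y = False case, where out = x @ y:

# Minimal NumPy sketch of the matmul double-grad identities for out = x @ y
# with transpose_X = transpose_Y = False. ddx/ddy play the role of DDX/DDY;
# the kernel computes:
#   ddout = ddx @ y + x @ ddy   (the two accumulating CalcInputGrad calls on DDOut)
#   dx    = dout @ ddy.T        ("dx = dout * ddy'")
#   dy    = ddx.T @ dout        ("dy = ddx' * dout")
import numpy as np

rng = np.random.default_rng(0)
x, y = rng.standard_normal((2, 3)), rng.standard_normal((3, 4))
dout = rng.standard_normal((2, 4))          # upstream gradient of out
ddx, ddy = rng.standard_normal((2, 3)), rng.standard_normal((3, 4))

ddout = ddx @ y + x @ ddy
dx = dout @ ddy.T
dy = ddx.T @ dout
print(ddout.shape, dx.shape, dy.shape)      # (2, 4) (2, 3) (3, 4)

The transposed branches in Compute are the same relations with the corresponding operands transposed, which is why each case simply swaps arguments and transpose flags in its CalcInputGrad call.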