diff --git a/paddle/fluid/operators/lu_op.cc b/paddle/fluid/operators/lu_op.cc index d3997f848e012c2ffdd905642f3eb7677aa955e3..aff6a77762fa389cadf30c9d042d078d979bf31c 100644 --- a/paddle/fluid/operators/lu_op.cc +++ b/paddle/fluid/operators/lu_op.cc @@ -149,7 +149,67 @@ class LUKernel : public framework::OpKernel { } }; +template +class LUOpGradMaker : public framework::SingleGradOpMaker { + public: + using framework::SingleGradOpMaker::SingleGradOpMaker; + + protected: + void Apply(GradOpPtr retv) const override { + retv->SetType("lu_grad"); + retv->SetInput("X", this->Input("X")); + retv->SetInput("Out", this->Output("Out")); + retv->SetInput("Pivots", this->Output("Pivots")); + retv->SetInput(framework::GradVarName("Out"), this->OutputGrad("Out")); + retv->SetOutput(framework::GradVarName("X"), this->InputGrad("X")); + retv->SetAttrMap(this->Attrs()); + } +}; + +class LUGradOpVarTypeInference : public framework::VarTypeInference { + public: + void operator()(framework::InferVarTypeContext *ctx) const override { + auto var_type = ctx->GetInputType("X", 0); + auto data_type = ctx->GetInputDataType("X", 0); + + ctx->SetOutputType(framework::GradVarName("X"), var_type, + framework::ALL_ELEMENTS); + ctx->SetOutputDataType(framework::GradVarName("X"), data_type, + framework::ALL_ELEMENTS); + } +}; + +class LUGradOp : public framework::OperatorWithKernel { + public: + using framework::OperatorWithKernel::OperatorWithKernel; + + void InferShape(framework::InferShapeContext *ctx) const override { + OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "lu"); + OP_INOUT_CHECK(ctx->HasInput("Out"), "Input", "Out", "lu"); + OP_INOUT_CHECK(ctx->HasInput("Pivots"), "Input", "Pivots", "lu"); + OP_INOUT_CHECK(ctx->HasInput(framework::GradVarName("Out")), "Input", + "Out@GRAD", "lu"); + + auto x_dims = ctx->GetInputDim("X"); + auto x_grad_name = framework::GradVarName("X"); + + if (ctx->HasOutput(x_grad_name)) { + ctx->SetOutputDim(x_grad_name, x_dims); + } + } + + protected: + framework::OpKernelType GetExpectedKernelType( + const framework::ExecutionContext &ctx) const override { + auto dtype = OperatorWithKernel::IndicateVarDataType(ctx, "X"); + return framework::OpKernelType(dtype, ctx.GetPlace()); + } +}; + DECLARE_INPLACE_OP_INFERER(LUOpInplaceInferer, {"X", "Out"}); +DECLARE_INPLACE_OP_INFERER(LUGradOpInplaceInferer, + {framework::GradVarName("Out"), + framework::GradVarName("X")}); } // namespace operators } // namespace paddle @@ -157,6 +217,13 @@ namespace ops = paddle::operators; namespace plat = paddle::platform; REGISTER_OPERATOR(lu, ops::LUOp, ops::LUOpMaker, ops::LUOpVarTypeInference, + ops::LUOpGradMaker, + ops::LUOpGradMaker, ops::LUOpInplaceInferer); +REGISTER_OPERATOR(lu_grad, ops::LUGradOp, ops::LUGradOpVarTypeInference, + ops::LUGradOpInplaceInferer); REGISTER_OP_CPU_KERNEL(lu, ops::LUKernel, ops::LUKernel); +REGISTER_OP_CPU_KERNEL(lu_grad, + ops::LUGradKernel, + ops::LUGradKernel); diff --git a/paddle/fluid/operators/lu_op.cu b/paddle/fluid/operators/lu_op.cu index bd6dc7124633eedb20ba16fbb40a2897075983e0..f395b39c17ea97ba1a93eb576b647c102c8142b7 100644 --- a/paddle/fluid/operators/lu_op.cu +++ b/paddle/fluid/operators/lu_op.cu @@ -152,5 +152,8 @@ namespace plat = paddle::platform; REGISTER_OP_CUDA_KERNEL(lu, ops::LUCUDAKernel, ops::LUCUDAKernel); +REGISTER_OP_CUDA_KERNEL(lu_grad, + ops::LUGradKernel, + ops::LUGradKernel); #endif // not PADDLE_WITH_HIP diff --git a/paddle/fluid/operators/lu_op.h b/paddle/fluid/operators/lu_op.h index 57cab052a25437af7538396184c866cc03daae47..256219f2e3befe1e345ecb6a364941f2a36962e0 100644 --- a/paddle/fluid/operators/lu_op.h +++ b/paddle/fluid/operators/lu_op.h @@ -470,5 +470,228 @@ void Unpack_Pivot(const DeviceContext& dev_ctx, const framework::Tensor& Pivot, } } +template +class LUGradKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& ctx) const override { + auto xin = ctx.Input("X"); + auto out = ctx.Input("Out"); + auto P = ctx.Input("Pivots"); + auto dout = ctx.Input(framework::GradVarName("Out")); + auto dx = ctx.Output(framework::GradVarName("X")); + dx->mutable_data(ctx.GetPlace()); + + const auto& dev_ctx = ctx.template device_context(); + math::DeviceIndependenceTensorOperations helper(ctx); + auto blas = math::GetBlas(ctx); + + auto xdims = xin->dims(); + int xrank = xdims.size(); + int64_t m = xdims[xrank - 2]; + int64_t n = xdims[xrank - 1]; + int64_t k = std::min(m, n); + + framework::Tensor L, U, L_narrow, U_narrow, L_narrow_mH, U_narrow_mH, + grad_narrow; + LU_Unpack(dev_ctx, out, &L, &U); + + Tensor_narrow(ctx, &L, &L_narrow, 0, k, 0, k); + Tensor_narrow(ctx, &U, &U_narrow, 0, k, 0, k); + Tensor_narrow(ctx, dout, &grad_narrow, 0, k, 0, k); + auto graddims = grad_narrow.dims(); + + Tensor_Conj(dev_ctx, L_narrow, &L_narrow_mH); + Tensor_Conj(dev_ctx, U_narrow, &U_narrow_mH); + L_narrow_mH = helper.Transpose(L_narrow_mH); + U_narrow_mH = helper.Transpose(U_narrow_mH); + + auto LmHdims = L_narrow_mH.dims(); + auto UmHdims = U_narrow_mH.dims(); + + framework::Tensor phi_L, phi_U, phi, psi; + phi_L.Resize(LmHdims); + phi_L.mutable_data(ctx.GetPlace()); + phi_U.Resize(UmHdims); + phi_U.mutable_data(ctx.GetPlace()); + auto mat_dim_l = math::CreateMatrixDescriptor(LmHdims, 0, false); + auto mat_dim_u = math::CreateMatrixDescriptor(UmHdims, 0, false); + auto mat_dim_g = math::CreateMatrixDescriptor(graddims, 0, false); + blas.MatMul(L_narrow_mH, mat_dim_l, grad_narrow, mat_dim_g, + static_cast(1), &phi_L, static_cast(0)); + + blas.MatMul(grad_narrow, mat_dim_g, U_narrow_mH, mat_dim_u, + static_cast(1), &phi_U, static_cast(0)); + + auto phil_rank = LmHdims.size(); + auto phiu_rank = UmHdims.size(); + platform::ForRange l_for_range(dev_ctx, phi_L.numel()); + TrilTriuCompute tril_computer(phi_L.data(), -1, true, + LmHdims[phil_rank - 2], + LmHdims[phil_rank - 1], phi_L.data()); + l_for_range(tril_computer); + + platform::ForRange u_for_range(dev_ctx, phi_U.numel()); + TrilTriuCompute triu_computer(phi_U.data(), 0, false, + UmHdims[phiu_rank - 2], + UmHdims[phiu_rank - 1], phi_U.data()); + u_for_range(triu_computer); + + Tensor_Add(dev_ctx, phi_L, phi_U, &phi); + psi.Resize(xdims); + psi.mutable_data(ctx.GetPlace()); + math::SetConstant setter; + setter(dev_ctx, &psi, static_cast(0)); + + std::vector axes = {xrank - 2, xrank - 1}; + std::vector slice_starts(2, 0); + std::vector slice_ends(2, 0); + auto valuedims = vectorize(xdims); + + framework::Tensor Pmat; + Unpack_Pivot(dev_ctx, *P, &Pmat, m, k); + if (m <= n) { + if (k < n) { + framework::Tensor U_complement, U_grad_complement, phi_complement, + phi_complement_l; + Tensor_narrow(ctx, &U, &U_complement, 0, k, k, n); + Tensor_narrow(ctx, dout, &U_grad_complement, 0, k, k, + n); + framework::Tensor U_complement_mH = helper.Transpose(U_complement); + + Tensor_Conj(dev_ctx, U_complement_mH, + &U_complement_mH); + + auto mat_dim_g = + math::CreateMatrixDescriptor(U_grad_complement.dims(), 0, false); + auto mat_dim_u = + math::CreateMatrixDescriptor(U_complement_mH.dims(), 0, false); + auto phidims = UmHdims; + phidims[UmHdims.size() - 2] = k; + phidims[UmHdims.size() - 1] = k; + phi_complement.Resize(phidims); + phi_complement.mutable_data(ctx.GetPlace()); + blas.MatMul(U_grad_complement, mat_dim_g, U_complement_mH, mat_dim_u, + static_cast(1), &phi_complement, static_cast(0)); + + phi_complement_l.Resize(phidims); + phi_complement_l.mutable_data(ctx.GetPlace()); + const auto H = phidims[phidims.size() - 2]; + const auto W = phidims[phidims.size() - 1]; + platform::ForRange x_for_range(dev_ctx, + phi_complement.numel()); + TrilTriuCompute tril_computer(phi_complement.data(), -1, true, H, + W, phi_complement_l.data()); + x_for_range(tril_computer); + + Tensor_Sub(dev_ctx, phi, phi_complement_l, &phi); + + slice_starts[0] = 0; + slice_starts[1] = k; + slice_ends[0] = k; + slice_ends[1] = n; + valuedims[xrank - 2] = k; + valuedims[xrank - 1] = n - k; + SetValueCompute_dispatch( + ctx, &psi, &U_grad_complement, &psi, axes, &slice_starts, + &slice_ends, valuedims, xrank); + } + + framework::Tensor psi_principal, phi_mH, psi_tmp; + Tensor_Conj(dev_ctx, phi, &phi_mH); + phi_mH = helper.Transpose(phi_mH); + triangular_solve(dev_ctx, U_narrow, phi_mH, + &psi_principal, true, false, false); + + Tensor_Conj(dev_ctx, psi_principal, &psi_principal); + psi_principal = helper.Transpose(psi_principal); + slice_starts[0] = 0; + slice_starts[1] = 0; + slice_ends[0] = k; + slice_ends[1] = k; + valuedims[xrank - 2] = k; + valuedims[xrank - 1] = k; + + SetValueCompute_dispatch(ctx, &psi, &psi_principal, + &psi, axes, &slice_starts, + &slice_ends, valuedims, xrank); + triangular_solve(dev_ctx, L_narrow_mH, psi, &psi_tmp, + true, false, true); + + auto mat_dim_p = math::CreateMatrixDescriptor(Pmat.dims(), 0, false); + auto mat_dim_b = math::CreateMatrixDescriptor(psi_tmp.dims(), 0, false); + blas.MatMul(Pmat, mat_dim_p, psi_tmp, mat_dim_b, static_cast(1), dx, + static_cast(0)); + } else { + framework::Tensor L_complement, L_grad_complement, phi_complement, + phi_complement_u; + Tensor_narrow(ctx, &L, &L_complement, k, m, 0, k); + Tensor_narrow(ctx, dout, &L_grad_complement, k, m, 0, + k); + framework::Tensor L_complement_mH = helper.Transpose(L_complement); + Tensor_Conj(dev_ctx, L_complement_mH, &L_complement_mH); + + auto mat_dim_g = + math::CreateMatrixDescriptor(L_grad_complement.dims(), 0, false); + auto mat_dim_u = + math::CreateMatrixDescriptor(L_complement_mH.dims(), 0, false); + auto phidims = LmHdims; + phidims[LmHdims.size() - 2] = k; + phidims[LmHdims.size() - 1] = k; + phi_complement.Resize(phidims); + phi_complement.mutable_data(ctx.GetPlace()); + blas.MatMul(L_complement_mH, mat_dim_u, L_grad_complement, mat_dim_g, + static_cast(1), &phi_complement, static_cast(0)); + + phi_complement_u.Resize(phidims); + phi_complement_u.mutable_data(ctx.GetPlace()); + const auto H = phidims[phidims.size() - 2]; + const auto W = phidims[phidims.size() - 1]; + platform::ForRange x_for_range(dev_ctx, + phi_complement.numel()); + TrilTriuCompute triu_computer(phi_complement.data(), 0, false, H, W, + phi_complement_u.data()); + x_for_range(triu_computer); + + Tensor_Sub(dev_ctx, phi, phi_complement_u, &phi); + + slice_starts[0] = k; + slice_starts[1] = 0; + slice_ends[0] = m; + slice_ends[1] = k; + valuedims[xrank - 2] = m - k; + valuedims[xrank - 1] = k; + SetValueCompute_dispatch(ctx, &psi, &L_grad_complement, + &psi, axes, &slice_starts, + &slice_ends, valuedims, xrank); + framework::Tensor psi_principal, phi_mH, psi_tmp, U_narrow_mH; + triangular_solve(dev_ctx, L_narrow_mH, phi, + &psi_principal, true, false, true); + slice_starts[0] = 0; + slice_starts[1] = 0; + slice_ends[0] = k; + slice_ends[1] = k; + valuedims[xrank - 2] = k; + valuedims[xrank - 1] = k; + + SetValueCompute_dispatch(ctx, &psi, &psi_principal, + &psi, axes, &slice_starts, + &slice_ends, valuedims, xrank); + + psi_tmp.Resize(psi.dims()); + psi_tmp.mutable_data(ctx.GetPlace()); + auto mat_dim_p = math::CreateMatrixDescriptor(Pmat.dims(), 0, false); + auto mat_dim_b = math::CreateMatrixDescriptor(psi.dims(), 0, false); + blas.MatMul(Pmat, mat_dim_p, psi, mat_dim_b, static_cast(1), &psi_tmp, + static_cast(0)); + psi_tmp = helper.Transpose(psi_tmp); + + Tensor_Conj(dev_ctx, U_narrow, &U_narrow_mH); + triangular_solve(dev_ctx, U_narrow_mH, psi_tmp, &psi, + true, false, false); + *dx = helper.Transpose(psi); + } + } +}; + } // namespace operators } // namespace paddle diff --git a/python/paddle/fluid/tests/unittests/test_lu_op.py b/python/paddle/fluid/tests/unittests/test_lu_op.py index badd713132cffa880c5092785b42064170ed68a6..1ab1b94f14b9f0996cfddcd01e3a86679a5c900e 100644 --- a/python/paddle/fluid/tests/unittests/test_lu_op.py +++ b/python/paddle/fluid/tests/unittests/test_lu_op.py @@ -140,6 +140,9 @@ class TestLUOp(OpTest): def test_check_output(self): self.check_output() + def test_check_grad(self): + self.check_grad(['X'], ['Out']) + # m = n 2D class TestLUOp2(TestLUOp):