From aeaf69b36de112e57e8e5bd01caa0e43a497c31b Mon Sep 17 00:00:00 2001
From: Chen Weihang
Date: Wed, 9 Mar 2022 14:20:12 +0800
Subject: [PATCH] remove determinant deps for svd helper (#40235)

---
 paddle/fluid/operators/determinant_op.h | 71 +++++++++++++------------
 paddle/phi/kernels/full_kernel.h        | 12 +++++
 2 files changed, 49 insertions(+), 34 deletions(-)

diff --git a/paddle/fluid/operators/determinant_op.h b/paddle/fluid/operators/determinant_op.h
index 463a707ccf..f89ecd3722 100644
--- a/paddle/fluid/operators/determinant_op.h
+++ b/paddle/fluid/operators/determinant_op.h
@@ -19,11 +19,17 @@
 #include <cmath>
 #include <vector>
 #include "paddle/fluid/framework/op_registry.h"
-#include "paddle/fluid/operators/svd_helper.h"
 #include "paddle/fluid/platform/enforce.h"
 #include "paddle/fluid/platform/for_range.h"
-#include "paddle/phi/kernels/funcs/complex_functors.h"
+#include "paddle/phi/kernels/complex_kernel.h"
+#include "paddle/phi/kernels/full_kernel.h"
+#include "paddle/phi/kernels/funcs/diag_functor.h"
+#include "paddle/phi/kernels/funcs/math_function.h"
 #include "paddle/phi/kernels/funcs/matrix_inverse.h"
+#include "paddle/phi/kernels/funcs/unsqueeze.h"
+#include "paddle/phi/kernels/math_kernel.h"
+#include "paddle/phi/kernels/matmul_kernel.h"
+#include "paddle/phi/kernels/transpose_kernel.h"
 
 namespace paddle {
 namespace operators {
@@ -172,7 +178,7 @@ template <typename DeviceContext, typename T>
 class DeterminantGradKernel : public framework::OpKernel<T> {
  public:
   void Compute(const framework::ExecutionContext& context) const override {
-    auto& dev_ctx = context.template device_context<DeviceContext>();
+    auto& orig_dev_ctx = context.template device_context<DeviceContext>();
     const auto* input = context.Input<framework::Tensor>("Input");
     const auto* det = context.Input<framework::Tensor>("Out");
     const auto* grad =
@@ -200,15 +206,18 @@ class DeterminantGradKernel : public framework::OpKernel<T> {
       // checked in forward, pass
     }
 
+    auto& dev_ctx = static_cast<
+        const typename framework::ConvertToPhiContext<DeviceContext>::TYPE&>(
+        orig_dev_ctx);
+
     // Check Whether the matrix is invertible
     // (matrix A not invertible) == (det(A)=0)
     if (!CheckMatrixInvertible<T>(context, det)) {
       // The matrix is not invertible
       VLOG(3) << "The input matrix not invertible!";
       ddet->Resize(input->dims());
-      ddet->mutable_data<T>(context.GetPlace());
-      phi::funcs::SetConstant<DeviceContext, T> zero;
-      zero(dev_ctx, ddet, static_cast<T>(0.0f));
+      phi::Full<T>(dev_ctx, phi::vectorize(input->dims()), static_cast<T>(0.0f),
+                   ddet);
       return;
     }
 
@@ -218,8 +227,6 @@ class DeterminantGradKernel : public framework::OpKernel<T> {
     // we set d|A| = unsqueeze(dA * |A|, [-1, -2]) * inverse(A).transpose(-2,
     // -1)
 
-    math::DeviceIndependenceTensorOperations<DeviceContext, T> helper(context);
-
     // First: inverse(A)
     framework::Tensor inverse_A;
     // A must be square matrices!
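For reference, the rule the hunks above compose — d|A| = unsqueeze(dA * |A|, [-1, -2]) * inverse(A).transpose(-2, -1) — follows from the identity d det(A)/dA = det(A) * inverse(A)^T. Below is a minimal standalone sketch checking that identity against finite differences; it is not part of the patch and assumes Eigen in place of Paddle's tensor types.

    // Finite-difference check of d|A|/dA = |A| * inverse(A)^T (sketch only;
    // Eigen is an assumption of this example, not a dependency of the patch).
    #include <cstdio>
    #include <Eigen/Dense>

    int main() {
      Eigen::Matrix3d A;
      A << 4, 1, 0,
           1, 3, 1,
           0, 1, 2;
      const double det = A.determinant();
      const Eigen::Matrix3d analytic = det * A.inverse().transpose();

      const double eps = 1e-6;
      for (int i = 0; i < 3; ++i) {
        for (int j = 0; j < 3; ++j) {
          Eigen::Matrix3d Ap = A;
          Ap(i, j) += eps;  // perturb one entry, re-evaluate the determinant
          const double numeric = (Ap.determinant() - det) / eps;
          std::printf("(%d,%d) analytic=%.6f numeric=%.6f\n", i, j,
                      analytic(i, j), numeric);
        }
      }
      return 0;
    }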
@@ -227,26 +234,28 @@ class DeterminantGradKernel : public framework::OpKernel<T> {
     inverse_A.mutable_data<T>(context.GetPlace());
 
     phi::funcs::MatrixInverseFunctor<DeviceContext, T> mat_inv;
-    mat_inv(dev_ctx, *input, &inverse_A);
+    mat_inv(orig_dev_ctx, *input, &inverse_A);
 
     VLOG(3) << "inverse(A) dims: " << inverse_A.dims();
 
     // Second: inverse(A).transpose(-2, -1)
-    framework::Tensor transpose_inverse_A = helper.Transpose(inverse_A);
+    framework::Tensor transpose_inverse_A =
+        phi::TransposeLast2Dim<T>(dev_ctx, inverse_A);
+
     VLOG(3) << "(dA * |A|).transpose(-2, -1) dims: "
             << transpose_inverse_A.dims();
 
     // Third: dA * |A|
-    auto mul_dA_detA = helper.Mul(*grad, *det);
+    auto mul_dA_detA = phi::Multiply<T>(dev_ctx, *grad, *det);
     VLOG(3) << "dA * |A| dims: " << mul_dA_detA.dims();
 
     // Fourth: unsqueeze(dA * |A|, [-1, -2])
-    auto unsqueeze1 = helper.Unsqueeze(mul_dA_detA, -1);
-    auto unsqueeze2 = helper.Unsqueeze(unsqueeze1, -2);
+    auto unsqueeze1 = phi::funcs::Unsqueeze(mul_dA_detA, -1);
+    auto unsqueeze2 = phi::funcs::Unsqueeze(unsqueeze1, -2);
     VLOG(3) << "unsqueezed(dA * |A|) dims: " << unsqueeze2.dims();
 
     // Finally: unsqueeze(dA * |A|) * inverse(A)
-    auto res = helper.Mul(unsqueeze2, transpose_inverse_A);
+    auto res = phi::Multiply<T>(dev_ctx, unsqueeze2, transpose_inverse_A);
 
     VLOG(3) << "unsqueeze(dA * |A|) * inverse(A) dims: " << res.dims();
 
@@ -331,7 +340,7 @@ template <typename DeviceContext, typename T>
 class SlogDeterminantGradKernel : public framework::OpKernel<T> {
  public:
   void Compute(const framework::ExecutionContext& context) const override {
-    auto& dev_ctx = context.template device_context<DeviceContext>();
+    auto& orig_dev_ctx = context.template device_context<DeviceContext>();
     const auto* input = context.Input<framework::Tensor>("Input");
     const auto* slogdet = context.Input<framework::Tensor>("Out");
     const auto* grad =
@@ -353,6 +362,10 @@ class SlogDeterminantGradKernel : public framework::OpKernel<T> {
           input->dims().size() - grad->dims().size()));
     }
 
+    auto& dev_ctx = static_cast<
+        const typename framework::ConvertToPhiContext<DeviceContext>::TYPE&>(
+        orig_dev_ctx);
+
     // Check Whether the matrix is invertible
     // (matrix A not invertible) == (absslogdet(A)=0)
     auto slogdet_vec = slogdet->Split(1, 0);
@@ -361,9 +374,8 @@ class SlogDeterminantGradKernel : public framework::OpKernel<T> {
       // The matrix is not invertible
       VLOG(3) << "The input matrix not invertible!";
       dslogdet->Resize(input->dims());
-      dslogdet->mutable_data<T>(context.GetPlace());
-      phi::funcs::SetConstant<DeviceContext, T> zero;
-      zero(dev_ctx, dslogdet, std::numeric_limits<T>::quiet_NaN());
+      phi::Full<T>(dev_ctx, phi::vectorize(input->dims()),
+                   std::numeric_limits<T>::quiet_NaN(), dslogdet);
       return;
     }
 
@@ -373,8 +385,6 @@ class SlogDeterminantGradKernel : public framework::OpKernel<T> {
     // we set dsl|A| = unsqueeze(dslA, [-1, -2]) *
     // inverse(A).conj().transpose(-2, -1)
 
-    math::DeviceIndependenceTensorOperations<DeviceContext, T> helper(context);
-
     // First: inverse(A)
     framework::Tensor inverse_A;
     // A must be square matrices!
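A note on the unsqueeze steps above: dA * |A| is one scalar per matrix, shape [batch], while the transposed inverse has shape [batch, n, n]. Unsqueezing at [-1, -2] turns the scalars into [batch, 1, 1] so the elementwise multiply broadcasts each scalar over its whole n x n matrix. A plain-loop sketch of that broadcast (illustrative only; simple loops stand in for phi::Multiply):

    #include <cstdio>
    #include <vector>

    int main() {
      const int batch = 2, n = 2;
      const std::vector<double> grad_det = {0.5, -1.0};  // shape [batch]
      std::vector<double> inv_t(batch * n * n, 1.0);     // shape [batch, n, n]

      // grad_det viewed as [batch, 1, 1]: entry b scales the whole matrix b.
      std::vector<double> out(batch * n * n);
      for (int b = 0; b < batch; ++b)
        for (int i = 0; i < n; ++i)
          for (int j = 0; j < n; ++j)
            out[(b * n + i) * n + j] = grad_det[b] * inv_t[(b * n + i) * n + j];

      std::printf("out[0]=%.2f out[%d]=%.2f\n", out[0], n * n, out[n * n]);
      return 0;
    }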
@@ -382,25 +392,18 @@ class SlogDeterminantGradKernel : public framework::OpKernel<T> {
     inverse_A.mutable_data<T>(context.GetPlace());
 
     phi::funcs::MatrixInverseFunctor<DeviceContext, T> mat_inv;
-    mat_inv(dev_ctx, *input, &inverse_A);
+    mat_inv(orig_dev_ctx, *input, &inverse_A);
 
     VLOG(3) << "inverse(A) dims: " << inverse_A.dims();
 
     // Second: inverse(A).conj()
-    framework::Tensor conj_inverse_A;
-    conj_inverse_A.Resize(inverse_A.dims());
-    auto numel = input->numel();
-    auto* conj_data = conj_inverse_A.mutable_data<T>(context.GetPlace(),
-                                                     size_t(numel * sizeof(T)));
-
-    platform::ForRange<DeviceContext> for_range(dev_ctx, numel);
-    phi::funcs::ConjFunctor<T> functor(inverse_A.data<T>(), numel, conj_data);
-    for_range(functor);
+    auto conj_inverse_A = phi::Conj<T>(dev_ctx, inverse_A);
 
     VLOG(3) << "inverse(A).conj() dims: " << conj_inverse_A.dims();
 
     // Third: inverse(A).conj().transpose(-2, -1)
-    framework::Tensor transpose_inverse_A = helper.Transpose(conj_inverse_A);
+    framework::Tensor transpose_inverse_A =
+        phi::TransposeLast2Dim<T>(dev_ctx, conj_inverse_A);
     VLOG(3) << "inverse(A).conj().transpose(-2, -1) dims: "
             << transpose_inverse_A.dims();
 
@@ -417,12 +420,12 @@ class SlogDeterminantGradKernel : public framework::OpKernel<T> {
     det_grad.Resize(det_grad.dims().reshape(det_grad_vec));
 
     // Fifth: unsqueeze(dslA, [-1, -2])
-    auto unsqueeze1 = helper.Unsqueeze(det_grad, -1);
-    auto unsqueeze2 = helper.Unsqueeze(unsqueeze1, -2);
+    auto unsqueeze1 = phi::funcs::Unsqueeze(det_grad, -1);
+    auto unsqueeze2 = phi::funcs::Unsqueeze(unsqueeze1, -2);
     VLOG(3) << "unsqueezed(dslA, [-1, -2]) dims: " << unsqueeze2.dims();
 
     // Finally: unsqueeze(dslA) * inverse(A)
-    auto res = helper.Mul(unsqueeze2, transpose_inverse_A);
+    auto res = phi::Multiply<T>(dev_ctx, unsqueeze2, transpose_inverse_A);
     VLOG(3) << "unsqueeze(dslA) * inverse(A) dims: " << res.dims();
 
     framework::TensorCopy(res, context.GetPlace(), dslogdet);
diff --git a/paddle/phi/kernels/full_kernel.h b/paddle/phi/kernels/full_kernel.h
index c44f048051..41fc96b6db 100644
--- a/paddle/phi/kernels/full_kernel.h
+++ b/paddle/phi/kernels/full_kernel.h
@@ -37,6 +37,18 @@ void FullLikeKernel(const Context& dev_ctx,
                     DataType dtype,
                     DenseTensor* out);
 
+template <typename T, typename Context>
+void Full(const Context& dev_ctx,
+          const ScalarArray& shape,
+          const Scalar& val,
+          DenseTensor* out) {
+  FullKernel<T, Context>(dev_ctx,
+                         shape,
+                         val,
+                         paddle::experimental::CppTypeToDataType<T>::Type(),
+                         out);
+}
+
 template <typename T, typename Context>
 DenseTensor Full(const Context& dev_ctx,
                  const ScalarArray& shape,
--
GitLab
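The second file adds a void-returning Full overload so kernels can fill an existing output tensor in place, as both kernels above do on their non-invertible paths. A hypothetical caller is sketched below; ZeroGrad and its parameters are invented for illustration, and only the phi::Full call mirrors the patch's own usage in determinant_op.h.

    #include "paddle/phi/kernels/full_kernel.h"

    // Hypothetical helper: zero-fill a caller-owned gradient tensor. The
    // new overload derives the DataType from T via CppTypeToDataType, so
    // the caller no longer spells out mutable_data + SetConstant.
    template <typename T, typename Context>
    void ZeroGrad(const Context& dev_ctx, const phi::DDim& dims,
                  phi::DenseTensor* dx) {
      phi::Full<T>(dev_ctx, phi::vectorize(dims), static_cast<T>(0.0f), dx);
    }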