From da4413636bdbe217b36c6df691ec4411d91ad166 Mon Sep 17 00:00:00 2001 From: crystal <62974595+Zjq9409@users.noreply.github.com> Date: Sat, 18 Sep 2021 17:26:10 +0800 Subject: [PATCH] FixEighOP; Unified MatrixEighFunctor function (#35812) --- paddle/fluid/operators/eigh_op.cc | 13 ++-- paddle/fluid/operators/eigh_op.cu | 32 ++------- paddle/fluid/operators/eigh_op.h | 33 ++++----- .../operators/math/eigen_values_vectors.h | 34 +++++----- paddle/fluid/operators/svd_helper.h | 68 +++++++------------ .../fluid/tests/unittests/test_eigh_op.py | 4 +- 6 files changed, 68 insertions(+), 116 deletions(-) diff --git a/paddle/fluid/operators/eigh_op.cc b/paddle/fluid/operators/eigh_op.cc index b3056bd43b..5577dfb8f8 100644 --- a/paddle/fluid/operators/eigh_op.cc +++ b/paddle/fluid/operators/eigh_op.cc @@ -47,12 +47,9 @@ class EighOp : public framework::OperatorWithKernel { input_dim[rank - 2], input_dim[rank - 1])); std::vector values_dim; - if (rank > 2) { - for (auto i = 0; i < rank - 1; i++) { - values_dim.emplace_back(input_dim[i]); - } - } else { - values_dim = {input_dim[1]}; + + for (auto i = 0; i < rank - 1; i++) { + values_dim.emplace_back(input_dim[i]); } ctx->SetOutputDim("Eigenvalues", framework::make_ddim(values_dim)); @@ -99,9 +96,9 @@ class EighGradOp : public framework::OperatorWithKernel { "EighGrad"); OP_INOUT_CHECK(ctx->HasInput("Eigenvectors"), "Input", "Eigenvectors", "EighGrad"); - OP_INOUT_CHECK(ctx->HasInputs(framework::GradVarName("Eigenvalues")), + OP_INOUT_CHECK(ctx->HasInput(framework::GradVarName("Eigenvalues")), "Input", "Eigenvalues@GRAD", "EighGrad"); - OP_INOUT_CHECK(ctx->HasInputs(framework::GradVarName("Eigenvectors")), + OP_INOUT_CHECK(ctx->HasInput(framework::GradVarName("Eigenvectors")), "Input", "Eigenvectors@GRAD", "EighGrad"); auto dims = ctx->GetInputDim("Eigenvectors"); auto x_grad_name = framework::GradVarName("X"); diff --git a/paddle/fluid/operators/eigh_op.cu b/paddle/fluid/operators/eigh_op.cu index cfc9eba450..61d2b66ea5 100644 --- a/paddle/fluid/operators/eigh_op.cu +++ b/paddle/fluid/operators/eigh_op.cu @@ -14,34 +14,14 @@ limitations under the License. */ #include "paddle/fluid/operators/eigh_op.h" -namespace paddle { -namespace operators { - -using Tensor = framework::Tensor; - -template -class EighGPUKernel : public framework::OpKernel { - public: - void Compute(const framework::ExecutionContext &ctx) const override { - auto input_var = ctx.Input("X"); - auto output_w_var = ctx.Output("Eigenvalues"); - auto output_v_var = ctx.Output("Eigenvectors"); - std::string lower = ctx.Attr("UPLO"); - bool is_lower = (lower == "L"); - math::MatrixEighFunctor functor; - functor(ctx, *input_var, output_w_var, output_v_var, is_lower, true); - } -}; - -} // namespace operators -} // namespace paddle - namespace ops = paddle::operators; - REGISTER_OP_CUDA_KERNEL( - eigh, ops::EighGPUKernel, ops::EighGPUKernel, - ops::EighGPUKernel>, - ops::EighGPUKernel>); + eigh, ops::EighKernel, + ops::EighKernel, + ops::EighKernel>, + ops::EighKernel>); REGISTER_OP_CUDA_KERNEL( eigh_grad, diff --git a/paddle/fluid/operators/eigh_op.h b/paddle/fluid/operators/eigh_op.h index 0af38d44e5..085e7531dd 100644 --- a/paddle/fluid/operators/eigh_op.h +++ b/paddle/fluid/operators/eigh_op.h @@ -22,24 +22,17 @@ namespace operators { using Tensor = framework::Tensor; -template -using EigenTensor = framework::EigenTensor; -template -using EigenVector = framework::EigenVector; - template class EighKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext& ctx) const override { - auto input_var = ctx.Input("X"); - auto output_w_var = ctx.Output("Eigenvalues"); - auto output_v_var = ctx.Output("Eigenvectors"); + auto input = ctx.Input("X"); + auto output_w = ctx.Output("Eigenvalues"); + auto output_v = ctx.Output("Eigenvectors"); std::string lower = ctx.Attr("UPLO"); bool is_lower = (lower == "L"); - math::MatrixEighFunctorCPU functor; - functor(ctx, *input_var, output_w_var, output_v_var, is_lower, true); + math::MatrixEighFunctor functor; + functor(ctx, *input, output_w, output_v, is_lower, true); } }; @@ -49,30 +42,30 @@ class EighGradKernel : public framework::OpKernel { void Compute(const framework::ExecutionContext& ctx) const override { auto& x_grad = *ctx.Output(framework::GradVarName("X")); x_grad.mutable_data(ctx.GetPlace()); - auto& output_w_var = *ctx.Input("Eigenvalues"); - auto& output_v_var = *ctx.Input("Eigenvectors"); + auto& output_w = *ctx.Input("Eigenvalues"); + auto& output_v = *ctx.Input("Eigenvectors"); auto& output_w_grad = *ctx.Input(framework::GradVarName("Eigenvalues")); auto& output_v_grad = *ctx.Input(framework::GradVarName("Eigenvectors")); - auto& dims = output_v_var.dims(); + auto& dims = output_v.dims(); const int m = dims[dims.size() - 1]; auto dito = math::DeviceIndependenceTensorOperations( ctx); - auto tV = dito.Transpose(dito.Conj(output_v_var)); - auto W = dito.Sub_(dito.Unsqueeze(output_w_var, -2), - dito.Unsqueeze(output_w_var, -1)); + auto tV = dito.Transpose(dito.Conj(output_v)); + auto W = dito.template Sub(dito.Unsqueeze(output_w, -2), + dito.Unsqueeze(output_w, -1)); Tensor result = dito.Matmul(tV, output_v_grad); result.mutable_data(dims, ctx.GetPlace()); std::vector out_shape = framework::vectorize(dims); auto constant = dito.Fill(out_shape, 0.5); result = dito.Sub(result, dito.Conj(dito.Transpose(result))); result = dito.Mul(result, constant); - result = dito.Div_(result, W); + result = dito.Div(result, W); result = dito.DiagFill(m, m, m, 0, output_w_grad, result); - x_grad = dito.Matmul(output_v_var, dito.Matmul(result, tV)); + x_grad = dito.Matmul(output_v, dito.Matmul(result, tV)); } }; diff --git a/paddle/fluid/operators/math/eigen_values_vectors.h b/paddle/fluid/operators/math/eigen_values_vectors.h index 4e2d180e33..3c793c8906 100644 --- a/paddle/fluid/operators/math/eigen_values_vectors.h +++ b/paddle/fluid/operators/math/eigen_values_vectors.h @@ -16,7 +16,6 @@ #include "Eigen/Core" #include "paddle/fluid/memory/memory.h" -#include "paddle/fluid/operators/math/complex_functors.h" #include "paddle/fluid/operators/svd_helper.h" #ifdef PADDLE_WITH_CUDA #include "paddle/fluid/platform/dynload/cusolver.h" @@ -26,10 +25,6 @@ namespace paddle { namespace operators { namespace math { -template -using EigenTensor = framework::EigenTensor; - template using InputMatrixMap = Eigen::Map< @@ -67,7 +62,7 @@ inline void ComputeFloatEigenvaluesAndVectors(ValueType *x_data, eigenvalues = eigen_solver.eigenvalues().transpose(); if (has_vectors) { - eigenvectors = eigen_solver.eigenvectors().transpose(); + eigenvectors = eigen_solver.eigenvectors(); } } } @@ -103,7 +98,7 @@ inline void ComputeComplexEigenvaluesAndVectors(T *x_data, eigenvalues = eigen_solver.eigenvalues().transpose(); if (has_vectors) { - eigenvectors = eigen_solver.eigenvectors().transpose(); + eigenvectors = eigen_solver.eigenvectors(); } } } @@ -117,11 +112,18 @@ inline int64_t GetBatchSize(framework::DDim dims) { return batch_size; } +template +struct MatrixEighFunctor { + void operator()(const framework::ExecutionContext &ctx, const Tensor &input, + Tensor *eigen_values, Tensor *eigen_vectors, bool is_lower, + bool has_vectors); +}; + // Calculates the eigenvalues ​​and eigenvectors of Hermitian or real // symmetric matrices, and uses the variable has_vectors to // control whether to return the eigenvectors. -template -struct MatrixEighFunctorCPU { +template +struct MatrixEighFunctor { public: void operator()(const framework::ExecutionContext &ctx, const Tensor &input, Tensor *eigen_values, Tensor *eigen_vectors, bool is_lower, @@ -134,7 +136,8 @@ struct MatrixEighFunctorCPU { for (int64_t i = 0; i < dim_size - 2; i++) { batch_size *= dims[i]; } - auto dito = DeviceIndependenceTensorOperations(ctx); + auto dito = + DeviceIndependenceTensorOperations(ctx); Tensor input_tensor; TensorCopy(input, ctx.GetPlace(), &input_tensor); if (!is_lower) { @@ -157,9 +160,6 @@ struct MatrixEighFunctorCPU { ComputeFloatEigenvaluesAndVectors( x_data, value_data, vector_data, batch_size, rows, rows, has_vectors); } - if (has_vectors) { - *eigen_vectors = dito.Transpose(*eigen_vectors); - } } }; @@ -169,7 +169,7 @@ struct MatrixEighFunctorCPU { // symmetric matrices on GPU, and uses the variable has_vectors // to control whether to return the eigenvectors. template -struct MatrixEighFunctor { +struct MatrixEighFunctor { public: void operator()(const framework::ExecutionContext &ctx, const Tensor &input, Tensor *eigen_values, Tensor *eigen_vectors, bool is_lower, @@ -278,7 +278,8 @@ struct MatrixEighFunctor { #define EVDBUFFER_INSTANCE(ValueType, T, C, CastType) \ template <> \ - inline void MatrixEighFunctor::EvdBuffer( \ + inline void \ + MatrixEighFunctor::EvdBuffer( \ cusolverDnHandle_t handle, cusolverEigMode_t jobz, \ cublasFillMode_t uplo, int n, const T *A, int lda, const ValueType *W, \ int *lwork) const { \ @@ -292,7 +293,8 @@ FUNC_WITH_TYPES(EVDBUFFER_INSTANCE); #define EVD_INSTANCE(ValueType, T, C, CastType) \ template <> \ - inline void MatrixEighFunctor::Evd( \ + inline void \ + MatrixEighFunctor::Evd( \ cusolverDnHandle_t handle, cusolverEigMode_t jobz, \ cublasFillMode_t uplo, int n, T *A, int lda, ValueType *W, T *work, \ int lwork, int *devInfo) const { \ diff --git a/paddle/fluid/operators/svd_helper.h b/paddle/fluid/operators/svd_helper.h index 71d106c211..d592c62d49 100644 --- a/paddle/fluid/operators/svd_helper.h +++ b/paddle/fluid/operators/svd_helper.h @@ -289,10 +289,20 @@ struct DeviceIndependenceTensorOperations { framework::Tensor Div(const framework::Tensor& x, const framework::Tensor& y) { framework::Tensor ret; - std::vector out_shape = GetBroadcastShape({&x, &y}); - ret.Resize(framework::make_ddim(out_shape)); - ElementwiseComputeEx, DeviceContext, T>( - context, &x, &y, -1, DivFunctor(), &ret); + if (x.type() != y.type()) { + ret.mutable_data(x.dims(), context.GetPlace()); + auto x_vector = EigenVector::Flatten(x); + auto y_vector = EigenVector::Flatten(y); + auto out_vector = EigenVector::Flatten(ret); + auto& place = + *context.template device_context().eigen_device(); + out_vector.device(place) = x_vector / y_vector; + } else { + std::vector out_shape = GetBroadcastShape({&x, &y}); + ret.Resize(framework::make_ddim(out_shape)); + ElementwiseComputeEx, DeviceContext, T>( + context, &x, &y, -1, DivFunctor(), &ret); + } return ret; } framework::Tensor Add(const framework::Tensor& x, @@ -330,7 +340,8 @@ struct DeviceIndependenceTensorOperations { NameInTensorMap inputs({{"X", {&x}}}); return CreateOpRunAndReturnTensor("reduce_max", inputs, attrs, out_dim); } - + // Support float and complex type subtraction,the default is T type + template framework::Tensor Sub(const framework::Tensor& x, const framework::Tensor& y) { framework::Tensor ret; @@ -340,18 +351,18 @@ struct DeviceIndependenceTensorOperations { #if defined(__NVCC__) || defined(__HIPCC__) // For GPU, there is no need to define XxxInverseFunctor and call // ElementwiseComputeEx in two branches. - ElementwiseComputeEx, DeviceContext, T>( - context, &x, &y, -1, SubFunctor(), &ret); + ElementwiseComputeEx, DeviceContext, InT>( + context, &x, &y, -1, SubFunctor(), &ret); #endif } else { if (x.dims().size() >= y.dims().size()) { - ElementwiseComputeEx, DeviceContext, T>( - context, &x, &y, -1, SubFunctor(), &ret); + ElementwiseComputeEx, DeviceContext, InT>( + context, &x, &y, -1, SubFunctor(), &ret); } else { - ElementwiseComputeEx, DeviceContext, T>( - // This is copyed from elementwise_sub, which means we - // need reverse will xrank < yrank - context, &x, &y, -1, InverseSubFunctor(), &ret); + // This is copyed from elementwise_sub, which means we + // need reverse will xrank < yrank + ElementwiseComputeEx, DeviceContext, InT>( + context, &x, &y, -1, InverseSubFunctor(), &ret); } } return ret; @@ -461,37 +472,6 @@ struct DeviceIndependenceTensorOperations { return out; } - // Support x and y are different data types - Tensor Div_(const Tensor& x, const Tensor& y) { - Tensor out; - out.mutable_data(x.dims(), context.GetPlace()); - auto x_vector = EigenVector::Flatten(x); - auto y_vector = EigenVector::Flatten(y); - auto out_vector = EigenVector::Flatten(out); - auto& place = - *context.template device_context().eigen_device(); - out_vector.device(place) = x_vector / y_vector; - return out; - } - - framework::Tensor Sub_(const framework::Tensor& x, - const framework::Tensor& y) { - framework::Tensor ret; - std::vector out_shape = GetBroadcastShape({&x, &y}); - ret.Resize(framework::make_ddim(out_shape)); - if (x.dims().size() >= y.dims().size()) { - ElementwiseComputeEx, DeviceContext, ValueType>( - context, &x, &y, -1, SubFunctor(), &ret); - } else { - ElementwiseComputeEx, DeviceContext, - ValueType>( - // This is copyed from elementwise_sub, which means we - // need reverse will xrank < yrank - context, &x, &y, -1, InverseSubFunctor(), &ret); - } - return ret; - } - private: const framework::ExecutionContext& context; BlasT GetBlas() { diff --git a/python/paddle/fluid/tests/unittests/test_eigh_op.py b/python/paddle/fluid/tests/unittests/test_eigh_op.py index e434364702..8e8c9df199 100644 --- a/python/paddle/fluid/tests/unittests/test_eigh_op.py +++ b/python/paddle/fluid/tests/unittests/test_eigh_op.py @@ -140,7 +140,7 @@ class TestEighAPI(unittest.TestCase): self.check_static_complex_result() def test_in_dynamic_mode(self): - paddle.disable_static(self.place) + paddle.disable_static() input_real_data = paddle.to_tensor(self.real_data) expected_w, expected_v = np.linalg.eigh(self.real_data) actual_w, actual_v = paddle.linalg.eigh(input_real_data) @@ -152,7 +152,7 @@ class TestEighAPI(unittest.TestCase): self.compare_result(actual_w, actual_v.numpy(), expected_w, expected_v) def test_eigh_grad(self): - paddle.disable_static(self.place) + paddle.disable_static() x = paddle.to_tensor(self.complex_data, stop_gradient=False) w, v = paddle.linalg.eigh(x) (w.sum() + paddle.abs(v).sum()).backward() -- GitLab