From e7afa3917799b67e44044c95c10603bf626133cf Mon Sep 17 00:00:00 2001 From: Chen Weihang Date: Sat, 5 Mar 2022 19:02:26 +0800 Subject: [PATCH] [Phi] Remove eig op depend for svd_helper (#40174) * remove eig dep for svd helper * fix win failed --- paddle/fluid/operators/eig_op.h | 92 ++++++++++------- paddle/phi/kernels/complex_kernel.h | 60 ++++++++++- paddle/phi/kernels/funcs/diag_functor.h | 99 ++++++++++++++++++ paddle/phi/kernels/funcs/slice.h | 127 ++++++++++++++++++++++++ paddle/phi/kernels/funcs/unsqueeze.h | 41 ++++++++ paddle/phi/kernels/matmul_kernel.h | 4 +- 6 files changed, 379 insertions(+), 44 deletions(-) create mode 100644 paddle/phi/kernels/funcs/slice.h create mode 100644 paddle/phi/kernels/funcs/unsqueeze.h diff --git a/paddle/fluid/operators/eig_op.h b/paddle/fluid/operators/eig_op.h index e9c6c1eb7e..5e4c83e1a4 100644 --- a/paddle/fluid/operators/eig_op.h +++ b/paddle/fluid/operators/eig_op.h @@ -18,12 +18,19 @@ #include #include #include "paddle/fluid/operators/math/matrix_solve.h" -#include "paddle/fluid/operators/svd_helper.h" #include "paddle/fluid/operators/transpose_op.h" #include "paddle/fluid/platform/for_range.h" +#include "paddle/phi/kernels/complex_kernel.h" #include "paddle/phi/kernels/funcs/complex_functors.h" +#include "paddle/phi/kernels/funcs/diag_functor.h" #include "paddle/phi/kernels/funcs/lapack/lapack_function.h" #include "paddle/phi/kernels/funcs/math_function.h" +#include "paddle/phi/kernels/funcs/slice.h" +#include "paddle/phi/kernels/funcs/unsqueeze.h" +#include "paddle/phi/kernels/math_kernel.h" +#include "paddle/phi/kernels/matmul_kernel.h" +#include "paddle/phi/kernels/transpose_kernel.h" + #define EPSILON 1e-6 namespace paddle { @@ -214,12 +221,17 @@ class EigKernel : public framework::OpKernel { ApplyEigKernel>( *x, &real_values, &real_vectors, context); - auto dito = math::DeviceIndependenceTensorOperations< - DeviceContext, phi::dtype::Real, Tout>(context); + + auto& orig_dev_ctx = context.template device_context(); + auto& dev_ctx = static_cast< + const typename framework::ConvertToPhiContext::TYPE&>( + orig_dev_ctx); // 1. extract real part & imag part from real_values - Tensor real_part = dito.Slice(real_values, {-1}, {0}, {order}); - Tensor imag_part = dito.Slice(real_values, {-1}, {order}, {order * 2}); + Tensor real_part = + phi::funcs::Slice(dev_ctx, real_values, {-1}, {0}, {order}); + Tensor imag_part = phi::funcs::Slice(dev_ctx, real_values, {-1}, + {order}, {order * 2}); // 2. construct complex values auto* real_part_data = real_part.data>(); @@ -233,7 +245,8 @@ class EigKernel : public framework::OpKernel { for_range(functor); // 3. construct complex vectors - Tensor real_vector_trans = dito.Transpose(real_vectors); + Tensor real_vector_trans = + phi::TransposeLast2Dim(dev_ctx, real_vectors); Tensor out_vectors_trans; out_vectors_trans.mutable_data(x->dims(), context.GetPlace()); ConstructComplexVectors, Tout>( @@ -251,45 +264,48 @@ class EigKernel : public framework::OpKernel { } }; -template +template void ComputeBackwardForComplexInput( const Tensor& V, const Tensor& L, const Tensor& gL, const Tensor& gV, - Tout* x_grad_data, int batch_count, int order, + T* x_grad_data, int batch_count, int order, const framework::ExecutionContext& context) { - auto dito = - math::DeviceIndependenceTensorOperations( - context); - - Tensor trans_v = dito.Transpose(V); - Tensor Vh = dito.Conj(trans_v); - Tensor Lconj = dito.Conj(L); - Tensor Econj = dito.Sub(dito.Unsqueeze(Lconj, -2), dito.Unsqueeze(Lconj, -1)); - Tensor VhgV = dito.Matmul(Vh, gV); - Tensor diag_real = dito.Real(VhgV); - Tensor diag_res = dito.BatchDiag(diag_real, batch_count); - Tensor diag_unsqueezed = dito.Unsqueeze(diag_res, -2); + auto& orig_dev_ctx = context.template device_context(); + auto& dev_ctx = static_cast< + const typename framework::ConvertToPhiContext::TYPE&>( + orig_dev_ctx); + + Tensor trans_v = phi::TransposeLast2Dim(dev_ctx, V); + Tensor Vh = phi::Conj(dev_ctx, trans_v); + Tensor Lconj = phi::Conj(dev_ctx, L); + Tensor Econj = phi::Subtract(dev_ctx, phi::funcs::Unsqueeze(Lconj, -2), + phi::funcs::Unsqueeze(Lconj, -1)); + Tensor VhgV = phi::Matmul(dev_ctx, Vh, gV); + Tensor diag_real = phi::Real(dev_ctx, VhgV); + Tensor diag_res = phi::funcs::BatchDiag(dev_ctx, diag_real, batch_count); + Tensor diag_unsqueezed = phi::funcs::Unsqueeze(diag_res, -2); // turn diag_unsqueezed into complex auto numel = diag_unsqueezed.numel(); Tensor diag_unsqueezed_complex; - auto* data_diag_un = diag_unsqueezed.data>(); - auto* data_diag_un_com = diag_unsqueezed_complex.mutable_data( + auto* data_diag_un = diag_unsqueezed.data>(); + auto* data_diag_un_com = diag_unsqueezed_complex.mutable_data( diag_unsqueezed.dims(), context.GetPlace(), - static_cast(numel * sizeof(Tout))); - auto& dev_ctx = context.template device_context(); - platform::ForRange for_range(dev_ctx, numel); - phi::funcs::RealToComplexFunctor functor(data_diag_un, data_diag_un_com, - numel); + static_cast(numel * sizeof(T))); + + platform::ForRange for_range(orig_dev_ctx, numel); + phi::funcs::RealToComplexFunctor functor(data_diag_un, data_diag_un_com, + numel); for_range(functor); // real tensor multiply complex tensor in broadcast manner - Tensor res1 = dito.RealMulComplex(V, diag_unsqueezed_complex); - Tensor res2 = dito.Matmul(Vh, res1); - Tensor result = dito.Sub(VhgV, res2); + Tensor res1 = phi::Multiply(dev_ctx, V, diag_unsqueezed_complex); + Tensor res2 = phi::Matmul(dev_ctx, Vh, res1); + Tensor result = phi::Subtract(dev_ctx, VhgV, res2); - result.mutable_data(V.dims(), context.GetPlace()); - result = dito.Div(result, Econj); - result = dito.DiagFill(order, order, order, 0, gL, result); - Tensor rhs = dito.Matmul(result, Vh); + result.mutable_data(V.dims(), context.GetPlace()); + result = phi::Divide(dev_ctx, result, Econj); + result = + phi::funcs::DiagFill(dev_ctx, order, order, order, 0, gL, result); + Tensor rhs = phi::Matmul(dev_ctx, result, Vh); // solve linear system // solve(Vh, rhs, out, m, k) @@ -298,10 +314,10 @@ void ComputeBackwardForComplexInput( // x_grad: out int m = Vh.dims()[Vh.dims().size() - 1]; int k = rhs.dims()[rhs.dims().size() - 1]; - auto* matrix_data = Vh.data(); - auto* rhs_data = rhs.data(); - math::SolveLinearSystem(matrix_data, rhs_data, x_grad_data, m, k, - batch_count); + auto* matrix_data = Vh.data(); + auto* rhs_data = rhs.data(); + math::SolveLinearSystem(matrix_data, rhs_data, x_grad_data, m, k, + batch_count); } template diff --git a/paddle/phi/kernels/complex_kernel.h b/paddle/phi/kernels/complex_kernel.h index 3b3003392d..2c52001ece 100644 --- a/paddle/phi/kernels/complex_kernel.h +++ b/paddle/phi/kernels/complex_kernel.h @@ -24,6 +24,12 @@ namespace phi { template void ConjKernel(const Context& dev_ctx, const DenseTensor& x, DenseTensor* out); +template +void RealKernel(const Context& dev_ctx, const DenseTensor& x, DenseTensor* out); + +template +void ImagKernel(const Context& dev_ctx, const DenseTensor& x, DenseTensor* out); + // If T is complex template < typename T, @@ -50,10 +56,56 @@ DenseTensor Conj(const Context& dev_ctx, const DenseTensor& x) { return x; } -template -void RealKernel(const Context& dev_ctx, const DenseTensor& x, DenseTensor* out); +// If T is complex +template < + typename T, + typename Context, + std::enable_if_t>::value || + std::is_same>::value, + bool> = true> +DenseTensor Real(const Context& dev_ctx, const DenseTensor& x) { + auto dense_out = phi::Empty(dev_ctx); + MetaTensor meta_out(&dense_out); + RealAndImagInferMeta(x, &meta_out); + RealKernel(dev_ctx, x, &dense_out); + return dense_out; +} -template -void ImagKernel(const Context& dev_ctx, const DenseTensor& x, DenseTensor* out); +// If T is not complex +template < + typename T, + typename Context, + std::enable_if_t>::value && + !std::is_same>::value, + bool> = true> +DenseTensor Real(const Context& dev_ctx, const DenseTensor& x) { + return x; +} + +// If T is complex +template < + typename T, + typename Context, + std::enable_if_t>::value || + std::is_same>::value, + bool> = true> +DenseTensor Imag(const Context& dev_ctx, const DenseTensor& x) { + auto dense_out = phi::Empty(dev_ctx); + MetaTensor meta_out(&dense_out); + RealAndImagInferMeta(x, &meta_out); + ImagKernel(dev_ctx, x, &dense_out); + return dense_out; +} + +// If T is not complex +template < + typename T, + typename Context, + std::enable_if_t>::value && + !std::is_same>::value, + bool> = true> +DenseTensor Imag(const Context& dev_ctx, const DenseTensor& x) { + return x; +} } // namespace phi diff --git a/paddle/phi/kernels/funcs/diag_functor.h b/paddle/phi/kernels/funcs/diag_functor.h index a806d1583a..1862f5ec91 100644 --- a/paddle/phi/kernels/funcs/diag_functor.h +++ b/paddle/phi/kernels/funcs/diag_functor.h @@ -14,6 +14,14 @@ #pragma once +#include "paddle/phi/common/type_traits.h" +#include "paddle/phi/core/dense_tensor.h" +#include "paddle/phi/core/hostdevice.h" +#include "paddle/phi/kernels/full_kernel.h" +#include "paddle/phi/kernels/funcs/for_range.h" + +// TODO(paddle-dev): Remove this file when we can call related Kernel directly + namespace phi { namespace funcs { @@ -25,5 +33,96 @@ inline int ComputeStride(int axis, phi::DDim dims) { return size; } +template +struct DiagAndFillFunctor { + DiagAndFillFunctor(const int m, + const int n, + const int num_lower_diags, + const int num_upper_diags, + const ValueType* scale, + const T* input, + T* output) + : m_(m), + n_(n), + num_lower_diags_(num_lower_diags), + num_upper_diags_(num_upper_diags), + scale_(scale), + input_(input), + output_(output) {} + + HOSTDEVICE void operator()(size_t index) const { + const int col = index % n_; + const int row = (index / n_) % m_; + const int band_start = (num_lower_diags_ < 0 ? 0 : row - num_lower_diags_); + const int band_end = + (num_upper_diags_ < 0 ? n_ : row + num_upper_diags_ + 1); + if (col < band_start || col >= band_end) { + output_[index] = input_[index]; + } else if (col == band_end - 1) { + output_[index] = static_cast(scale_[index % m_]); + } else { + output_[index] = input_[index]; + } + } + + private: + const int m_, n_, num_lower_diags_, num_upper_diags_; + const ValueType* scale_; + const T* input_; + T* output_; +}; + +template +DenseTensor DiagFill(const Context& dev_ctx, + const int m, + const int n, + const int num_lower_diags, + const int num_upper_diags, + const DenseTensor& scale, + const DenseTensor& input) { + DenseTensor out; + out.Resize(input.dims()); + dev_ctx.template Alloc(&out); + funcs::ForRange for_range(dev_ctx, input.numel()); + DiagAndFillFunctor diag_and_copy_functor( + m, + n, + num_lower_diags, + num_upper_diags, + scale.data(), + input.data(), + out.data()); + for_range(diag_and_copy_functor); + return out; +} + +template +DenseTensor BatchDiag(const Context& dev_ctx, const DenseTensor& x, int batch) { + DenseTensor out; + auto* x_data = x.data>(); + auto numel = x.numel(); + out.Resize(x.dims()); + auto* out_data = dev_ctx.template HostAlloc>( + &out, static_cast(numel * sizeof(phi::dtype::Real))); + + auto x_dims = x.dims(); + int num_dims = x_dims.size(); + std::vector out_shape; + + for (int i = 0; i < num_dims - 1; ++i) { + out_shape.push_back(x.dims()[i]); + } + out.Resize(phi::make_ddim(out_shape)); + int order = x.dims()[num_dims - 1]; + int stride_out = order * order; + int stride_in = order + 1; + for (int i = 0; i < batch; ++i) { + for (int j = 0; j < order; ++j) { + out_data[i * order + j] = x_data[stride_out * i + stride_in * j]; + } + } + return out; +} + } // namespace funcs } // namespace phi diff --git a/paddle/phi/kernels/funcs/slice.h b/paddle/phi/kernels/funcs/slice.h new file mode 100644 index 0000000000..0a50dceb0a --- /dev/null +++ b/paddle/phi/kernels/funcs/slice.h @@ -0,0 +1,127 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "paddle/phi/core/ddim.h" +#include "paddle/phi/core/dense_tensor.h" +#include "paddle/phi/core/enforce.h" +#include "paddle/phi/kernels/funcs/eigen/common.h" +#include "paddle/phi/kernels/funcs/eigen/eigen_function.h" + +// TODO(paddle-dev): Remove this file when we can call related Kernel directly + +namespace phi { +namespace funcs { + +template +void EigenSliceWrapper(const Context& dev_ctx, + const DenseTensor* in, + const std::vector& start, + const std::vector& end, + DenseTensor* out) { + // Slice by call Eigen Tensor Function `.slice()` + size_t rank = in->dims().size(); + PADDLE_ENFORCE_EQ(start.size(), + rank, + errors::InvalidArgument( + "EigenSliceWrapper function start " + "argument must have the same length as input rank.")); + PADDLE_ENFORCE_EQ(end.size(), + rank, + errors::InvalidArgument( + "EigenSliceWrapper function end " + "argument must have the same length as input rank.")); + auto eigen_place_ptr = dev_ctx.eigen_device(); + auto eigen_place = *eigen_place_ptr; + auto out_t = phi::EigenTensor::From(*out, out->dims()); + auto in_t = phi::EigenTensor::From(*in, in->dims()); + Eigen::DSizes offsets_32bit, extents_32bit; + for (size_t i = 0; i < D; i++) { + offsets_32bit[i] = start[i]; + extents_32bit[i] = end[i]; + } + EigenSlice, T, D>::Eval( + eigen_place, + phi::To32BitIndex(out_t), + phi::To32BitIndex(in_t), + offsets_32bit, + extents_32bit); +} + +#define SLICE_RANK_CASE(N) \ + case N: { \ + EigenSliceWrapper(dev_ctx, &x, offset, extends, &ret); \ + break; \ + } + +template +DenseTensor Slice(const Context& dev_ctx, + const DenseTensor& x, + std::vector axes, + std::vector starts, + std::vector ends) { + DenseTensor ret; + std::vector new_axes = axes; + std::vector out_shape = phi::vectorize(x.dims()); + size_t rank = out_shape.size(); + PADDLE_ENFORCE_EQ( + axes.size(), + starts.size(), + errors::InvalidArgument("Slice Operator Argument Invalided")); + PADDLE_ENFORCE_EQ( + ends.size(), + starts.size(), + errors::InvalidArgument("Slice Operator Argument Invalided")); + for (unsigned int i = 0; i < axes.size(); ++i) { + int axis = axes[i]; + if (axis < 0) axis = rank + axis; + new_axes[i] = axis; // change negative to positive + int st = starts[i]; + int ed = ends[i]; + PADDLE_ENFORCE_GT( + ed, + st, + errors::InvalidArgument("C++ Slice Operation Not Support End < Start")); + out_shape[axis] = ed - st; + } + std::vector offset(rank), extends(rank); + for (size_t i = 0; i < rank; ++i) { + offset[i] = 0; + extends[i] = x.dims()[i]; + } + for (size_t i = 0; i < new_axes.size(); ++i) { + offset[new_axes[i]] = starts[i]; + extends[new_axes[i]] = ends[i] - starts[i]; + } + ret.Resize(phi::make_ddim(out_shape)); + dev_ctx.template Alloc(&ret); + switch (rank) { + SLICE_RANK_CASE(1); + SLICE_RANK_CASE(2); + SLICE_RANK_CASE(3); + SLICE_RANK_CASE(4); + SLICE_RANK_CASE(5); + SLICE_RANK_CASE(6); + default: { + PADDLE_THROW( + errors::InvalidArgument("Invalid Rank number, " + "currently only support rank between 2~6")); + } + } + return ret; +} + +} // namespace funcs +} // namespace phi diff --git a/paddle/phi/kernels/funcs/unsqueeze.h b/paddle/phi/kernels/funcs/unsqueeze.h new file mode 100644 index 0000000000..7b8a81471e --- /dev/null +++ b/paddle/phi/kernels/funcs/unsqueeze.h @@ -0,0 +1,41 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "paddle/phi/core/ddim.h" +#include "paddle/phi/core/dense_tensor.h" + +// TODO(paddle-dev): Remove this file when we can call related Kernel directly + +namespace phi { +namespace funcs { + +inline const DenseTensor Unsqueeze(const DenseTensor& x, int axis = 0) { + // don't copy data, only change the dims + DenseTensor out(x); + std::vector out_shape = phi::vectorize(x.dims()); + if (axis >= 0) { + auto index = (out_shape.begin() + axis); + out_shape.insert(index, 1); + } else if (axis < 0) { + auto index = (out_shape.end() + axis + 1); + out_shape.insert(index, 1); + } + out.Resize(phi::make_ddim(out_shape)); + return out; +} + +} // namespace funcs +} // namespace phi diff --git a/paddle/phi/kernels/matmul_kernel.h b/paddle/phi/kernels/matmul_kernel.h index 8fc060d2e3..1f1cb22c27 100644 --- a/paddle/phi/kernels/matmul_kernel.h +++ b/paddle/phi/kernels/matmul_kernel.h @@ -33,8 +33,8 @@ template DenseTensor Matmul(const Context& dev_ctx, const DenseTensor& x, const DenseTensor& y, - bool transpose_x, - bool transpose_y) { + bool transpose_x = false, + bool transpose_y = false) { auto dense_out = Empty(dev_ctx); MetaTensor meta_out(&dense_out); MatmulInferMeta(x, y, transpose_x, transpose_y, &meta_out); -- GitLab