未验证 提交 e7afa391 编写于 作者: C Chen Weihang 提交者: GitHub

[Phi] Remove eig op depend for svd_helper (#40174)

* remove eig dep for svd helper

* fix win failed
上级 4be5448b
...@@ -18,12 +18,19 @@ ...@@ -18,12 +18,19 @@
#include <algorithm> #include <algorithm>
#include <complex> #include <complex>
#include "paddle/fluid/operators/math/matrix_solve.h" #include "paddle/fluid/operators/math/matrix_solve.h"
#include "paddle/fluid/operators/svd_helper.h"
#include "paddle/fluid/operators/transpose_op.h" #include "paddle/fluid/operators/transpose_op.h"
#include "paddle/fluid/platform/for_range.h" #include "paddle/fluid/platform/for_range.h"
#include "paddle/phi/kernels/complex_kernel.h"
#include "paddle/phi/kernels/funcs/complex_functors.h" #include "paddle/phi/kernels/funcs/complex_functors.h"
#include "paddle/phi/kernels/funcs/diag_functor.h"
#include "paddle/phi/kernels/funcs/lapack/lapack_function.h" #include "paddle/phi/kernels/funcs/lapack/lapack_function.h"
#include "paddle/phi/kernels/funcs/math_function.h" #include "paddle/phi/kernels/funcs/math_function.h"
#include "paddle/phi/kernels/funcs/slice.h"
#include "paddle/phi/kernels/funcs/unsqueeze.h"
#include "paddle/phi/kernels/math_kernel.h"
#include "paddle/phi/kernels/matmul_kernel.h"
#include "paddle/phi/kernels/transpose_kernel.h"
#define EPSILON 1e-6 #define EPSILON 1e-6
namespace paddle { namespace paddle {
...@@ -214,12 +221,17 @@ class EigKernel : public framework::OpKernel<T> { ...@@ -214,12 +221,17 @@ class EigKernel : public framework::OpKernel<T> {
ApplyEigKernel<DeviceContext, phi::dtype::Real<T>>( ApplyEigKernel<DeviceContext, phi::dtype::Real<T>>(
*x, &real_values, &real_vectors, context); *x, &real_values, &real_vectors, context);
auto dito = math::DeviceIndependenceTensorOperations<
DeviceContext, phi::dtype::Real<T>, Tout>(context); auto& orig_dev_ctx = context.template device_context<DeviceContext>();
auto& dev_ctx = static_cast<
const typename framework::ConvertToPhiContext<DeviceContext>::TYPE&>(
orig_dev_ctx);
// 1. extract real part & imag part from real_values // 1. extract real part & imag part from real_values
Tensor real_part = dito.Slice(real_values, {-1}, {0}, {order}); Tensor real_part =
Tensor imag_part = dito.Slice(real_values, {-1}, {order}, {order * 2}); phi::funcs::Slice<T>(dev_ctx, real_values, {-1}, {0}, {order});
Tensor imag_part = phi::funcs::Slice<T>(dev_ctx, real_values, {-1},
{order}, {order * 2});
// 2. construct complex values // 2. construct complex values
auto* real_part_data = real_part.data<phi::dtype::Real<T>>(); auto* real_part_data = real_part.data<phi::dtype::Real<T>>();
...@@ -233,7 +245,8 @@ class EigKernel : public framework::OpKernel<T> { ...@@ -233,7 +245,8 @@ class EigKernel : public framework::OpKernel<T> {
for_range(functor); for_range(functor);
// 3. construct complex vectors // 3. construct complex vectors
Tensor real_vector_trans = dito.Transpose(real_vectors); Tensor real_vector_trans =
phi::TransposeLast2Dim<T>(dev_ctx, real_vectors);
Tensor out_vectors_trans; Tensor out_vectors_trans;
out_vectors_trans.mutable_data<Tout>(x->dims(), context.GetPlace()); out_vectors_trans.mutable_data<Tout>(x->dims(), context.GetPlace());
ConstructComplexVectors<phi::dtype::Real<T>, Tout>( ConstructComplexVectors<phi::dtype::Real<T>, Tout>(
...@@ -251,45 +264,48 @@ class EigKernel : public framework::OpKernel<T> { ...@@ -251,45 +264,48 @@ class EigKernel : public framework::OpKernel<T> {
} }
}; };
template <typename DeviceContext, typename Tout> template <typename DeviceContext, typename T>
void ComputeBackwardForComplexInput( void ComputeBackwardForComplexInput(
const Tensor& V, const Tensor& L, const Tensor& gL, const Tensor& gV, const Tensor& V, const Tensor& L, const Tensor& gL, const Tensor& gV,
Tout* x_grad_data, int batch_count, int order, T* x_grad_data, int batch_count, int order,
const framework::ExecutionContext& context) { const framework::ExecutionContext& context) {
auto dito = auto& orig_dev_ctx = context.template device_context<DeviceContext>();
math::DeviceIndependenceTensorOperations<DeviceContext, Tout, Tout>( auto& dev_ctx = static_cast<
context); const typename framework::ConvertToPhiContext<DeviceContext>::TYPE&>(
orig_dev_ctx);
Tensor trans_v = dito.Transpose(V);
Tensor Vh = dito.Conj(trans_v); Tensor trans_v = phi::TransposeLast2Dim<T>(dev_ctx, V);
Tensor Lconj = dito.Conj(L); Tensor Vh = phi::Conj<T>(dev_ctx, trans_v);
Tensor Econj = dito.Sub(dito.Unsqueeze(Lconj, -2), dito.Unsqueeze(Lconj, -1)); Tensor Lconj = phi::Conj<T>(dev_ctx, L);
Tensor VhgV = dito.Matmul(Vh, gV); Tensor Econj = phi::Subtract<T>(dev_ctx, phi::funcs::Unsqueeze(Lconj, -2),
Tensor diag_real = dito.Real(VhgV); phi::funcs::Unsqueeze(Lconj, -1));
Tensor diag_res = dito.BatchDiag(diag_real, batch_count); Tensor VhgV = phi::Matmul<T>(dev_ctx, Vh, gV);
Tensor diag_unsqueezed = dito.Unsqueeze(diag_res, -2); Tensor diag_real = phi::Real<T>(dev_ctx, VhgV);
Tensor diag_res = phi::funcs::BatchDiag<T>(dev_ctx, diag_real, batch_count);
Tensor diag_unsqueezed = phi::funcs::Unsqueeze(diag_res, -2);
// turn diag_unsqueezed into complex // turn diag_unsqueezed into complex
auto numel = diag_unsqueezed.numel(); auto numel = diag_unsqueezed.numel();
Tensor diag_unsqueezed_complex; Tensor diag_unsqueezed_complex;
auto* data_diag_un = diag_unsqueezed.data<phi::dtype::Real<Tout>>(); auto* data_diag_un = diag_unsqueezed.data<phi::dtype::Real<T>>();
auto* data_diag_un_com = diag_unsqueezed_complex.mutable_data<Tout>( auto* data_diag_un_com = diag_unsqueezed_complex.mutable_data<T>(
diag_unsqueezed.dims(), context.GetPlace(), diag_unsqueezed.dims(), context.GetPlace(),
static_cast<size_t>(numel * sizeof(Tout))); static_cast<size_t>(numel * sizeof(T)));
auto& dev_ctx = context.template device_context<DeviceContext>();
platform::ForRange<DeviceContext> for_range(dev_ctx, numel); platform::ForRange<DeviceContext> for_range(orig_dev_ctx, numel);
phi::funcs::RealToComplexFunctor<Tout> functor(data_diag_un, data_diag_un_com, phi::funcs::RealToComplexFunctor<T> functor(data_diag_un, data_diag_un_com,
numel); numel);
for_range(functor); for_range(functor);
// real tensor multiply complex tensor in broadcast manner // real tensor multiply complex tensor in broadcast manner
Tensor res1 = dito.RealMulComplex(V, diag_unsqueezed_complex); Tensor res1 = phi::Multiply<T>(dev_ctx, V, diag_unsqueezed_complex);
Tensor res2 = dito.Matmul(Vh, res1); Tensor res2 = phi::Matmul<T>(dev_ctx, Vh, res1);
Tensor result = dito.Sub(VhgV, res2); Tensor result = phi::Subtract<T>(dev_ctx, VhgV, res2);
result.mutable_data<Tout>(V.dims(), context.GetPlace()); result.mutable_data<T>(V.dims(), context.GetPlace());
result = dito.Div(result, Econj); result = phi::Divide<T>(dev_ctx, result, Econj);
result = dito.DiagFill(order, order, order, 0, gL, result); result =
Tensor rhs = dito.Matmul(result, Vh); phi::funcs::DiagFill<T, T>(dev_ctx, order, order, order, 0, gL, result);
Tensor rhs = phi::Matmul<T>(dev_ctx, result, Vh);
// solve linear system // solve linear system
// solve(Vh, rhs, out, m, k) // solve(Vh, rhs, out, m, k)
...@@ -298,10 +314,10 @@ void ComputeBackwardForComplexInput( ...@@ -298,10 +314,10 @@ void ComputeBackwardForComplexInput(
// x_grad: out // x_grad: out
int m = Vh.dims()[Vh.dims().size() - 1]; int m = Vh.dims()[Vh.dims().size() - 1];
int k = rhs.dims()[rhs.dims().size() - 1]; int k = rhs.dims()[rhs.dims().size() - 1];
auto* matrix_data = Vh.data<Tout>(); auto* matrix_data = Vh.data<T>();
auto* rhs_data = rhs.data<Tout>(); auto* rhs_data = rhs.data<T>();
math::SolveLinearSystem<Tout>(matrix_data, rhs_data, x_grad_data, m, k, math::SolveLinearSystem<T>(matrix_data, rhs_data, x_grad_data, m, k,
batch_count); batch_count);
} }
template <typename DeviceContext, typename T, typename Tout> template <typename DeviceContext, typename T, typename Tout>
......
...@@ -24,6 +24,12 @@ namespace phi { ...@@ -24,6 +24,12 @@ namespace phi {
template <typename T, typename Context> template <typename T, typename Context>
void ConjKernel(const Context& dev_ctx, const DenseTensor& x, DenseTensor* out); void ConjKernel(const Context& dev_ctx, const DenseTensor& x, DenseTensor* out);
template <typename T, typename Context>
void RealKernel(const Context& dev_ctx, const DenseTensor& x, DenseTensor* out);
template <typename T, typename Context>
void ImagKernel(const Context& dev_ctx, const DenseTensor& x, DenseTensor* out);
// If T is complex // If T is complex
template < template <
typename T, typename T,
...@@ -50,10 +56,56 @@ DenseTensor Conj(const Context& dev_ctx, const DenseTensor& x) { ...@@ -50,10 +56,56 @@ DenseTensor Conj(const Context& dev_ctx, const DenseTensor& x) {
return x; return x;
} }
template <typename T, typename Context> // If T is complex
void RealKernel(const Context& dev_ctx, const DenseTensor& x, DenseTensor* out); template <
typename T,
typename Context,
std::enable_if_t<std::is_same<T, phi::dtype::complex<float>>::value ||
std::is_same<T, phi::dtype::complex<double>>::value,
bool> = true>
DenseTensor Real(const Context& dev_ctx, const DenseTensor& x) {
auto dense_out = phi::Empty<T, Context>(dev_ctx);
MetaTensor meta_out(&dense_out);
RealAndImagInferMeta(x, &meta_out);
RealKernel<T>(dev_ctx, x, &dense_out);
return dense_out;
}
template <typename T, typename Context> // If T is not complex
void ImagKernel(const Context& dev_ctx, const DenseTensor& x, DenseTensor* out); template <
typename T,
typename Context,
std::enable_if_t<!std::is_same<T, phi::dtype::complex<float>>::value &&
!std::is_same<T, phi::dtype::complex<double>>::value,
bool> = true>
DenseTensor Real(const Context& dev_ctx, const DenseTensor& x) {
return x;
}
// If T is complex
template <
typename T,
typename Context,
std::enable_if_t<std::is_same<T, phi::dtype::complex<float>>::value ||
std::is_same<T, phi::dtype::complex<double>>::value,
bool> = true>
DenseTensor Imag(const Context& dev_ctx, const DenseTensor& x) {
auto dense_out = phi::Empty<T, Context>(dev_ctx);
MetaTensor meta_out(&dense_out);
RealAndImagInferMeta(x, &meta_out);
ImagKernel<T>(dev_ctx, x, &dense_out);
return dense_out;
}
// If T is not complex
template <
typename T,
typename Context,
std::enable_if_t<!std::is_same<T, phi::dtype::complex<float>>::value &&
!std::is_same<T, phi::dtype::complex<double>>::value,
bool> = true>
DenseTensor Imag(const Context& dev_ctx, const DenseTensor& x) {
return x;
}
} // namespace phi } // namespace phi
...@@ -14,6 +14,14 @@ ...@@ -14,6 +14,14 @@
#pragma once #pragma once
#include "paddle/phi/common/type_traits.h"
#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/core/hostdevice.h"
#include "paddle/phi/kernels/full_kernel.h"
#include "paddle/phi/kernels/funcs/for_range.h"
// TODO(paddle-dev): Remove this file when we can call related Kernel directly
namespace phi { namespace phi {
namespace funcs { namespace funcs {
...@@ -25,5 +33,96 @@ inline int ComputeStride(int axis, phi::DDim dims) { ...@@ -25,5 +33,96 @@ inline int ComputeStride(int axis, phi::DDim dims) {
return size; return size;
} }
template <typename T, typename ValueType>
struct DiagAndFillFunctor {
DiagAndFillFunctor(const int m,
const int n,
const int num_lower_diags,
const int num_upper_diags,
const ValueType* scale,
const T* input,
T* output)
: m_(m),
n_(n),
num_lower_diags_(num_lower_diags),
num_upper_diags_(num_upper_diags),
scale_(scale),
input_(input),
output_(output) {}
HOSTDEVICE void operator()(size_t index) const {
const int col = index % n_;
const int row = (index / n_) % m_;
const int band_start = (num_lower_diags_ < 0 ? 0 : row - num_lower_diags_);
const int band_end =
(num_upper_diags_ < 0 ? n_ : row + num_upper_diags_ + 1);
if (col < band_start || col >= band_end) {
output_[index] = input_[index];
} else if (col == band_end - 1) {
output_[index] = static_cast<T>(scale_[index % m_]);
} else {
output_[index] = input_[index];
}
}
private:
const int m_, n_, num_lower_diags_, num_upper_diags_;
const ValueType* scale_;
const T* input_;
T* output_;
};
template <typename T, typename ValueType, typename Context>
DenseTensor DiagFill(const Context& dev_ctx,
const int m,
const int n,
const int num_lower_diags,
const int num_upper_diags,
const DenseTensor& scale,
const DenseTensor& input) {
DenseTensor out;
out.Resize(input.dims());
dev_ctx.template Alloc<T>(&out);
funcs::ForRange<Context> for_range(dev_ctx, input.numel());
DiagAndFillFunctor<T, ValueType> diag_and_copy_functor(
m,
n,
num_lower_diags,
num_upper_diags,
scale.data<ValueType>(),
input.data<T>(),
out.data<T>());
for_range(diag_and_copy_functor);
return out;
}
template <typename T, typename Context>
DenseTensor BatchDiag(const Context& dev_ctx, const DenseTensor& x, int batch) {
DenseTensor out;
auto* x_data = x.data<phi::dtype::Real<T>>();
auto numel = x.numel();
out.Resize(x.dims());
auto* out_data = dev_ctx.template HostAlloc<phi::dtype::Real<T>>(
&out, static_cast<size_t>(numel * sizeof(phi::dtype::Real<T>)));
auto x_dims = x.dims();
int num_dims = x_dims.size();
std::vector<int> out_shape;
for (int i = 0; i < num_dims - 1; ++i) {
out_shape.push_back(x.dims()[i]);
}
out.Resize(phi::make_ddim(out_shape));
int order = x.dims()[num_dims - 1];
int stride_out = order * order;
int stride_in = order + 1;
for (int i = 0; i < batch; ++i) {
for (int j = 0; j < order; ++j) {
out_data[i * order + j] = x_data[stride_out * i + stride_in * j];
}
}
return out;
}
} // namespace funcs } // namespace funcs
} // namespace phi } // namespace phi
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/phi/core/ddim.h"
#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/core/enforce.h"
#include "paddle/phi/kernels/funcs/eigen/common.h"
#include "paddle/phi/kernels/funcs/eigen/eigen_function.h"
// TODO(paddle-dev): Remove this file when we can call related Kernel directly
namespace phi {
namespace funcs {
template <typename Context, typename T, size_t D>
void EigenSliceWrapper(const Context& dev_ctx,
const DenseTensor* in,
const std::vector<int>& start,
const std::vector<int>& end,
DenseTensor* out) {
// Slice by call Eigen Tensor Function `.slice()`
size_t rank = in->dims().size();
PADDLE_ENFORCE_EQ(start.size(),
rank,
errors::InvalidArgument(
"EigenSliceWrapper function start "
"argument must have the same length as input rank."));
PADDLE_ENFORCE_EQ(end.size(),
rank,
errors::InvalidArgument(
"EigenSliceWrapper function end "
"argument must have the same length as input rank."));
auto eigen_place_ptr = dev_ctx.eigen_device();
auto eigen_place = *eigen_place_ptr;
auto out_t = phi::EigenTensor<T, D>::From(*out, out->dims());
auto in_t = phi::EigenTensor<T, D>::From(*in, in->dims());
Eigen::DSizes<int, D> offsets_32bit, extents_32bit;
for (size_t i = 0; i < D; i++) {
offsets_32bit[i] = start[i];
extents_32bit[i] = end[i];
}
EigenSlice<std::decay_t<decltype(eigen_place)>, T, D>::Eval(
eigen_place,
phi::To32BitIndex(out_t),
phi::To32BitIndex(in_t),
offsets_32bit,
extents_32bit);
}
#define SLICE_RANK_CASE(N) \
case N: { \
EigenSliceWrapper<Context, T, N>(dev_ctx, &x, offset, extends, &ret); \
break; \
}
template <typename T, typename Context>
DenseTensor Slice(const Context& dev_ctx,
const DenseTensor& x,
std::vector<int> axes,
std::vector<int> starts,
std::vector<int> ends) {
DenseTensor ret;
std::vector<int> new_axes = axes;
std::vector<int> out_shape = phi::vectorize<int>(x.dims());
size_t rank = out_shape.size();
PADDLE_ENFORCE_EQ(
axes.size(),
starts.size(),
errors::InvalidArgument("Slice Operator Argument Invalided"));
PADDLE_ENFORCE_EQ(
ends.size(),
starts.size(),
errors::InvalidArgument("Slice Operator Argument Invalided"));
for (unsigned int i = 0; i < axes.size(); ++i) {
int axis = axes[i];
if (axis < 0) axis = rank + axis;
new_axes[i] = axis; // change negative to positive
int st = starts[i];
int ed = ends[i];
PADDLE_ENFORCE_GT(
ed,
st,
errors::InvalidArgument("C++ Slice Operation Not Support End < Start"));
out_shape[axis] = ed - st;
}
std::vector<int> offset(rank), extends(rank);
for (size_t i = 0; i < rank; ++i) {
offset[i] = 0;
extends[i] = x.dims()[i];
}
for (size_t i = 0; i < new_axes.size(); ++i) {
offset[new_axes[i]] = starts[i];
extends[new_axes[i]] = ends[i] - starts[i];
}
ret.Resize(phi::make_ddim(out_shape));
dev_ctx.template Alloc<T>(&ret);
switch (rank) {
SLICE_RANK_CASE(1);
SLICE_RANK_CASE(2);
SLICE_RANK_CASE(3);
SLICE_RANK_CASE(4);
SLICE_RANK_CASE(5);
SLICE_RANK_CASE(6);
default: {
PADDLE_THROW(
errors::InvalidArgument("Invalid Rank number, "
"currently only support rank between 2~6"));
}
}
return ret;
}
} // namespace funcs
} // namespace phi
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/phi/core/ddim.h"
#include "paddle/phi/core/dense_tensor.h"
// TODO(paddle-dev): Remove this file when we can call related Kernel directly
namespace phi {
namespace funcs {
inline const DenseTensor Unsqueeze(const DenseTensor& x, int axis = 0) {
// don't copy data, only change the dims
DenseTensor out(x);
std::vector<int> out_shape = phi::vectorize<int>(x.dims());
if (axis >= 0) {
auto index = (out_shape.begin() + axis);
out_shape.insert(index, 1);
} else if (axis < 0) {
auto index = (out_shape.end() + axis + 1);
out_shape.insert(index, 1);
}
out.Resize(phi::make_ddim(out_shape));
return out;
}
} // namespace funcs
} // namespace phi
...@@ -33,8 +33,8 @@ template <typename T, typename Context> ...@@ -33,8 +33,8 @@ template <typename T, typename Context>
DenseTensor Matmul(const Context& dev_ctx, DenseTensor Matmul(const Context& dev_ctx,
const DenseTensor& x, const DenseTensor& x,
const DenseTensor& y, const DenseTensor& y,
bool transpose_x, bool transpose_x = false,
bool transpose_y) { bool transpose_y = false) {
auto dense_out = Empty<T, Context>(dev_ctx); auto dense_out = Empty<T, Context>(dev_ctx);
MetaTensor meta_out(&dense_out); MetaTensor meta_out(&dense_out);
MatmulInferMeta(x, y, transpose_x, transpose_y, &meta_out); MatmulInferMeta(x, y, transpose_x, transpose_y, &meta_out);
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册