Unverified · Commit da441363, authored by crystal, committed by GitHub

FixEighOP; Unified MatrixEighFunctor function (#35812)

Parent a1b6ae26
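Note (annotation, not part of the commit): the change unifies the former MatrixEighFunctorCPU and the GPU-only MatrixEighFunctor into a single MatrixEighFunctor primary template that is partially specialized on the device context, so one device-generic EighKernel can pick the right backend at compile time. A minimal self-contained sketch of the pattern, with placeholder tag types standing in for platform::CPUDeviceContext / platform::CUDADeviceContext:

#include <iostream>

// Placeholder device-context tags (assumptions for this sketch).
struct CPUDeviceContext {};
struct CUDADeviceContext {};

// Primary template: declared once, defined only through per-device
// partial specializations, mirroring the diff below.
template <typename DeviceContext, typename ValueType, typename T>
struct MatrixEighFunctor;

template <typename ValueType, typename T>
struct MatrixEighFunctor<CPUDeviceContext, ValueType, T> {
  void operator()() { std::cout << "CPU eigh path\n"; }  // Eigen-based solve
};

template <typename ValueType, typename T>
struct MatrixEighFunctor<CUDADeviceContext, ValueType, T> {
  void operator()() { std::cout << "GPU eigh path\n"; }  // cuSOLVER-based solve
};

// A device-generic kernel body dispatches at compile time:
template <typename DeviceContext, typename ValueType, typename T>
void EighKernelCompute() {
  MatrixEighFunctor<DeviceContext, ValueType, T> functor;
  functor();
}

int main() {
  EighKernelCompute<CPUDeviceContext, float, float>();    // "CPU eigh path"
  EighKernelCompute<CUDADeviceContext, double, double>(); // "GPU eigh path"
}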
@@ -47,13 +47,10 @@ class EighOp : public framework::OperatorWithKernel {
                           input_dim[rank - 2], input_dim[rank - 1]));
     std::vector<int64_t> values_dim;
-    if (rank > 2) {
-      for (auto i = 0; i < rank - 1; i++) {
-        values_dim.emplace_back(input_dim[i]);
-      }
-    } else {
-      values_dim = {input_dim[1]};
-    }
+    for (auto i = 0; i < rank - 1; i++) {
+      values_dim.emplace_back(input_dim[i]);
+    }
     ctx->SetOutputDim("Eigenvalues", framework::make_ddim(values_dim));
     ctx->SetOutputDim("Eigenvectors", input_dim);
@@ -99,9 +96,9 @@ class EighGradOp : public framework::OperatorWithKernel {
                    "EighGrad");
     OP_INOUT_CHECK(ctx->HasInput("Eigenvectors"), "Input", "Eigenvectors",
                    "EighGrad");
-    OP_INOUT_CHECK(ctx->HasInputs(framework::GradVarName("Eigenvalues")),
+    OP_INOUT_CHECK(ctx->HasInput(framework::GradVarName("Eigenvalues")),
                    "Input", "Eigenvalues@GRAD", "EighGrad");
-    OP_INOUT_CHECK(ctx->HasInputs(framework::GradVarName("Eigenvectors")),
+    OP_INOUT_CHECK(ctx->HasInput(framework::GradVarName("Eigenvectors")),
                    "Input", "Eigenvectors@GRAD", "EighGrad");
     auto dims = ctx->GetInputDim("Eigenvectors");
     auto x_grad_name = framework::GradVarName("X");
......
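Note (annotation, not part of the commit): the HasInputs → HasInput change appears to be the "FixEighOP" part — Eigenvalues@GRAD and Eigenvectors@GRAD are single, non-duplicable inputs, so HasInput is the correct check. The shape-inference branch was redundant because the earlier check guarantees input_dim[rank - 2] == input_dim[rank - 1]: for a square batched input [..., m, m] the eigenvalues always have shape [..., m]. A standalone sketch of that rule (hypothetical helper, not Paddle API):

#include <cassert>
#include <cstdint>
#include <vector>

// Hypothetical helper mirroring the unified InferShape logic:
// for a square input [..., m, m], eigenvalues have shape [..., m].
std::vector<int64_t> EigenvaluesShape(const std::vector<int64_t>& input_dim) {
  const int rank = static_cast<int>(input_dim.size());
  assert(rank >= 2 && input_dim[rank - 2] == input_dim[rank - 1]);
  return std::vector<int64_t>(input_dim.begin(), input_dim.end() - 1);
}
// EigenvaluesShape({3, 4, 4}) -> {3, 4}; EigenvaluesShape({5, 5}) -> {5},
// which equals the old rank == 2 special case {input_dim[1]}.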
@@ -14,34 +14,14 @@ limitations under the License. */

 #include "paddle/fluid/operators/eigh_op.h"

-namespace paddle {
-namespace operators {
-
-using Tensor = framework::Tensor;
-
-template <typename ValueType, typename T>
-class EighGPUKernel : public framework::OpKernel<T> {
- public:
-  void Compute(const framework::ExecutionContext &ctx) const override {
-    auto input_var = ctx.Input<Tensor>("X");
-    auto output_w_var = ctx.Output<Tensor>("Eigenvalues");
-    auto output_v_var = ctx.Output<Tensor>("Eigenvectors");
-    std::string lower = ctx.Attr<std::string>("UPLO");
-    bool is_lower = (lower == "L");
-    math::MatrixEighFunctor<ValueType, T> functor;
-    functor(ctx, *input_var, output_w_var, output_v_var, is_lower, true);
-  }
-};
-
-}  // namespace operators
-}  // namespace paddle
-
 namespace ops = paddle::operators;

 REGISTER_OP_CUDA_KERNEL(
-    eigh, ops::EighGPUKernel<float, float>, ops::EighGPUKernel<double, double>,
-    ops::EighGPUKernel<float, paddle::platform::complex<float>>,
-    ops::EighGPUKernel<double, paddle::platform::complex<double>>);
+    eigh, ops::EighKernel<paddle::platform::CUDADeviceContext, float, float>,
+    ops::EighKernel<paddle::platform::CUDADeviceContext, double, double>,
+    ops::EighKernel<paddle::platform::CUDADeviceContext, float,
+                    paddle::platform::complex<float>>,
+    ops::EighKernel<paddle::platform::CUDADeviceContext, double,
+                    paddle::platform::complex<double>>);

 REGISTER_OP_CUDA_KERNEL(
     eigh_grad,
......
@@ -22,24 +22,17 @@ namespace operators {

 using Tensor = framework::Tensor;

-template <typename T, size_t D, int MajorType = Eigen::RowMajor,
-          typename IndexType = Eigen::DenseIndex>
-using EigenTensor = framework::EigenTensor<T, D, MajorType, IndexType>;
-
-template <typename T, int MajorType = Eigen::RowMajor,
-          typename IndexType = Eigen::DenseIndex>
-using EigenVector = framework::EigenVector<T, MajorType, IndexType>;
-
 template <typename DeviceContext, typename ValueType, typename T>
 class EighKernel : public framework::OpKernel<T> {
  public:
   void Compute(const framework::ExecutionContext& ctx) const override {
-    auto input_var = ctx.Input<Tensor>("X");
-    auto output_w_var = ctx.Output<Tensor>("Eigenvalues");
-    auto output_v_var = ctx.Output<Tensor>("Eigenvectors");
+    auto input = ctx.Input<Tensor>("X");
+    auto output_w = ctx.Output<Tensor>("Eigenvalues");
+    auto output_v = ctx.Output<Tensor>("Eigenvectors");
     std::string lower = ctx.Attr<std::string>("UPLO");
     bool is_lower = (lower == "L");
-    math::MatrixEighFunctorCPU<DeviceContext, ValueType, T> functor;
-    functor(ctx, *input_var, output_w_var, output_v_var, is_lower, true);
+    math::MatrixEighFunctor<DeviceContext, ValueType, T> functor;
+    functor(ctx, *input, output_w, output_v, is_lower, true);
   }
 };
@@ -49,30 +42,30 @@ class EighGradKernel : public framework::OpKernel<T> {
   void Compute(const framework::ExecutionContext& ctx) const override {
     auto& x_grad = *ctx.Output<framework::Tensor>(framework::GradVarName("X"));
     x_grad.mutable_data<T>(ctx.GetPlace());
-    auto& output_w_var = *ctx.Input<Tensor>("Eigenvalues");
-    auto& output_v_var = *ctx.Input<Tensor>("Eigenvectors");
+    auto& output_w = *ctx.Input<Tensor>("Eigenvalues");
+    auto& output_v = *ctx.Input<Tensor>("Eigenvectors");
     auto& output_w_grad =
         *ctx.Input<Tensor>(framework::GradVarName("Eigenvalues"));
     auto& output_v_grad =
         *ctx.Input<Tensor>(framework::GradVarName("Eigenvectors"));
-    auto& dims = output_v_var.dims();
+    auto& dims = output_v.dims();
     const int m = dims[dims.size() - 1];
     auto dito =
         math::DeviceIndependenceTensorOperations<DeviceContext, T, ValueType>(
             ctx);
-    auto tV = dito.Transpose(dito.Conj(output_v_var));
-    auto W = dito.Sub_(dito.Unsqueeze(output_w_var, -2),
-                       dito.Unsqueeze(output_w_var, -1));
+    auto tV = dito.Transpose(dito.Conj(output_v));
+    auto W = dito.template Sub<ValueType>(dito.Unsqueeze(output_w, -2),
+                                          dito.Unsqueeze(output_w, -1));
     Tensor result = dito.Matmul(tV, output_v_grad);
     result.mutable_data<T>(dims, ctx.GetPlace());
     std::vector<int> out_shape = framework::vectorize<int>(dims);
     auto constant = dito.Fill(out_shape, 0.5);
     result = dito.Sub(result, dito.Conj(dito.Transpose(result)));
     result = dito.Mul(result, constant);
-    result = dito.Div_(result, W);
+    result = dito.Div(result, W);
     result = dito.DiagFill(m, m, m, 0, output_w_grad, result);
-    x_grad = dito.Matmul(output_v_var, dito.Matmul(result, tV));
+    x_grad = dito.Matmul(output_v, dito.Matmul(result, tV));
   }
 };
......
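Note (annotation, not part of the commit): the grad kernel above implements the usual backward rule for a Hermitian eigendecomposition A = V diag(w) V^H. Writing S = V^H (dL/dV) and W_ij = w_j - w_i (the Sub<ValueType> of the two Unsqueezes), the code computes

\[
G \;=\; \operatorname{diag}\!\Big(\frac{\partial L}{\partial w}\Big)
        \;+\; \frac{S - S^{H}}{2} \oslash W ,
\qquad
\frac{\partial L}{\partial X} \;=\; V \, G \, V^{H},
\]

where \(\oslash\) is elementwise division (dito.Div), the antisymmetrization \((S - S^{H})/2\) is the Sub / Conj / Transpose / Mul-by-0.5 sequence, and DiagFill overwrites the diagonal of \(G\) with \(\partial L / \partial w\) (the diagonal of \(W\) is zero, so it must not be divided through).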
@@ -16,7 +16,6 @@

 #include "Eigen/Core"
 #include "paddle/fluid/memory/memory.h"
-#include "paddle/fluid/operators/math/complex_functors.h"
 #include "paddle/fluid/operators/svd_helper.h"
 #ifdef PADDLE_WITH_CUDA
 #include "paddle/fluid/platform/dynload/cusolver.h"
@@ -26,10 +25,6 @@ namespace paddle {
 namespace operators {
 namespace math {

-template <typename T, size_t D, int MajorType = Eigen::RowMajor,
-          typename IndexType = Eigen::DenseIndex>
-using EigenTensor = framework::EigenTensor<T, D, MajorType, IndexType>;
-
 template <typename T, int MajorType = Eigen::RowMajor,
           typename IndexType = Eigen::DenseIndex>
 using InputMatrixMap = Eigen::Map<
@@ -67,7 +62,7 @@ inline void ComputeFloatEigenvaluesAndVectors(ValueType *x_data,
     eigenvalues = eigen_solver.eigenvalues().transpose();
     if (has_vectors) {
-      eigenvectors = eigen_solver.eigenvectors().transpose();
+      eigenvectors = eigen_solver.eigenvectors();
     }
   }
 }
@@ -103,7 +98,7 @@ inline void ComputeComplexEigenvaluesAndVectors(T *x_data,
     eigenvalues = eigen_solver.eigenvalues().transpose();
     if (has_vectors) {
-      eigenvectors = eigen_solver.eigenvectors().transpose();
+      eigenvectors = eigen_solver.eigenvectors();
     }
   }
 }
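Note (annotation, not part of the commit): the two hunks above drop the per-matrix .transpose(), and the CPU functor below drops the compensating batch Transpose. For reference, Eigen's SelfAdjointEigenSolver stores eigenvectors as the *columns* of eigenvectors(): A * V.col(i) == w(i) * V.col(i). A standalone check of that convention:

#include <Eigen/Dense>
#include <iostream>

int main() {
  Eigen::Matrix2d A;
  A << 2.0, 1.0,
       1.0, 2.0;  // symmetric; eigenvalues are 1 and 3
  Eigen::SelfAdjointEigenSolver<Eigen::Matrix2d> solver(A);
  const auto& w = solver.eigenvalues();   // ascending eigenvalues
  const auto& V = solver.eigenvectors();  // eigenvectors as columns
  for (int i = 0; i < 2; ++i) {
    // residual of A v = lambda v should be ~0 for each column
    double residual = (A * V.col(i) - w(i) * V.col(i)).norm();
    std::cout << "lambda_" << i << " = " << w(i)
              << ", residual = " << residual << "\n";
  }
}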
@@ -117,11 +112,18 @@ inline int64_t GetBatchSize(framework::DDim dims) {
   return batch_size;
 }

+template <typename DeviceContext, typename ValueType, typename T>
+struct MatrixEighFunctor {
+  void operator()(const framework::ExecutionContext &ctx, const Tensor &input,
+                  Tensor *eigen_values, Tensor *eigen_vectors, bool is_lower,
+                  bool has_vectors);
+};
+
 // Calculates the eigenvalues and eigenvectors of Hermitian or real
 // symmetric matrices, and uses the variable has_vectors to
 // control whether to return the eigenvectors.
-template <typename DeviceContext, typename ValueType, typename T>
-struct MatrixEighFunctorCPU {
+template <typename ValueType, typename T>
+struct MatrixEighFunctor<platform::CPUDeviceContext, ValueType, T> {
  public:
   void operator()(const framework::ExecutionContext &ctx, const Tensor &input,
                   Tensor *eigen_values, Tensor *eigen_vectors, bool is_lower,
@@ -134,7 +136,8 @@ struct MatrixEighFunctorCPU {
     for (int64_t i = 0; i < dim_size - 2; i++) {
       batch_size *= dims[i];
     }
-    auto dito = DeviceIndependenceTensorOperations<DeviceContext, T>(ctx);
+    auto dito =
+        DeviceIndependenceTensorOperations<platform::CPUDeviceContext, T>(ctx);
     Tensor input_tensor;
     TensorCopy(input, ctx.GetPlace(), &input_tensor);
     if (!is_lower) {
@@ -157,9 +160,6 @@ struct MatrixEighFunctorCPU {
       ComputeFloatEigenvaluesAndVectors<ValueType>(
           x_data, value_data, vector_data, batch_size, rows, rows, has_vectors);
     }
-    if (has_vectors) {
-      *eigen_vectors = dito.Transpose(*eigen_vectors);
-    }
   }
 };
@@ -169,7 +169,7 @@ struct MatrixEighFunctorCPU {
 // symmetric matrices on GPU, and uses the variable has_vectors
 // to control whether to return the eigenvectors.
 template <typename ValueType, typename T>
-struct MatrixEighFunctor {
+struct MatrixEighFunctor<platform::CUDADeviceContext, ValueType, T> {
  public:
   void operator()(const framework::ExecutionContext &ctx, const Tensor &input,
                   Tensor *eigen_values, Tensor *eigen_vectors, bool is_lower,
@@ -278,7 +278,8 @@ struct MatrixEighFunctor {
 #define EVDBUFFER_INSTANCE(ValueType, T, C, CastType)                        \
   template <>                                                                \
-  inline void MatrixEighFunctor<ValueType, T>::EvdBuffer(                    \
+  inline void                                                                \
+  MatrixEighFunctor<platform::CUDADeviceContext, ValueType, T>::EvdBuffer(   \
       cusolverDnHandle_t handle, cusolverEigMode_t jobz,                     \
       cublasFillMode_t uplo, int n, const T *A, int lda, const ValueType *W, \
       int *lwork) const {                                                    \
@@ -292,7 +293,8 @@ FUNC_WITH_TYPES(EVDBUFFER_INSTANCE);
 #define EVD_INSTANCE(ValueType, T, C, CastType)                              \
   template <>                                                                \
-  inline void MatrixEighFunctor<ValueType, T>::Evd(                          \
+  inline void                                                                \
+  MatrixEighFunctor<platform::CUDADeviceContext, ValueType, T>::Evd(         \
      cusolverDnHandle_t handle, cusolverEigMode_t jobz,                      \
      cublasFillMode_t uplo, int n, T *A, int lda, ValueType *W, T *work,     \
      int lwork, int *devInfo) const {                                        \
......
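Note (annotation, not part of the commit): the EvdBuffer/Evd specializations above wrap cuSOLVER's buffer-query-then-solve pair. A standalone sketch of that sequence for float (error handling omitted; assumes a CUDA toolkit with cusolverDn available):

#include <cuda_runtime.h>
#include <cusolverDn.h>
#include <cstdio>

int main() {
  const int n = 2;
  const float h_A[n * n] = {2.f, 1.f, 1.f, 2.f};  // symmetric input
  float *d_A, *d_W, *d_work;
  int *devInfo, lwork = 0;
  cudaMalloc(&d_A, sizeof(h_A));
  cudaMalloc(&d_W, n * sizeof(float));
  cudaMalloc(&devInfo, sizeof(int));
  cudaMemcpy(d_A, h_A, sizeof(h_A), cudaMemcpyHostToDevice);

  cusolverDnHandle_t handle;
  cusolverDnCreate(&handle);
  // 1) workspace query -- this is what the EvdBuffer instances wrap.
  cusolverDnSsyevd_bufferSize(handle, CUSOLVER_EIG_MODE_VECTOR,
                              CUBLAS_FILL_MODE_LOWER, n, d_A, n, d_W, &lwork);
  cudaMalloc(&d_work, lwork * sizeof(float));
  // 2) in-place solve -- what the Evd instances wrap:
  //    d_A is overwritten with eigenvectors, d_W with ascending eigenvalues.
  cusolverDnSsyevd(handle, CUSOLVER_EIG_MODE_VECTOR, CUBLAS_FILL_MODE_LOWER,
                   n, d_A, n, d_W, d_work, lwork, devInfo);

  float h_W[n];
  cudaMemcpy(h_W, d_W, sizeof(h_W), cudaMemcpyDeviceToHost);
  printf("eigenvalues: %f %f\n", h_W[0], h_W[1]);  // expect 1 and 3

  cusolverDnDestroy(handle);
  cudaFree(d_A); cudaFree(d_W); cudaFree(d_work); cudaFree(devInfo);
}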
@@ -289,10 +289,20 @@ struct DeviceIndependenceTensorOperations {
   framework::Tensor Div(const framework::Tensor& x,
                         const framework::Tensor& y) {
     framework::Tensor ret;
+    if (x.type() != y.type()) {
+      ret.mutable_data<T>(x.dims(), context.GetPlace());
+      auto x_vector = EigenVector<T>::Flatten(x);
+      auto y_vector = EigenVector<ValueType>::Flatten(y);
+      auto out_vector = EigenVector<T>::Flatten(ret);
+      auto& place =
+          *context.template device_context<DeviceContext>().eigen_device();
+      out_vector.device(place) = x_vector / y_vector;
+    } else {
     std::vector<int> out_shape = GetBroadcastShape({&x, &y});
     ret.Resize(framework::make_ddim(out_shape));
     ElementwiseComputeEx<DivFunctor<T>, DeviceContext, T>(
         context, &x, &y, -1, DivFunctor<T>(), &ret);
+    }
     return ret;
   }
   framework::Tensor Add(const framework::Tensor& x,
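Note (annotation, not part of the commit): the new branch lets Div divide a complex-typed tensor by a real-typed one elementwise. In eigh_grad, the result has complex scalar type T while the eigenvalue-gap matrix W holds real ValueType entries, which a single-type ElementwiseComputeEx<DivFunctor<T>, ..., T> cannot mix. A trivial standalone illustration of the intended semantics (hypothetical helper, not Paddle API):

#include <complex>
#include <vector>

// Divide complex entries x[i] by real entries y[i], elementwise.
std::vector<std::complex<float>> DivComplexByReal(
    const std::vector<std::complex<float>>& x, const std::vector<float>& y) {
  std::vector<std::complex<float>> out(x.size());
  for (size_t i = 0; i < x.size(); ++i) {
    out[i] = x[i] / y[i];  // complex / real -> complex
  }
  return out;
}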
@@ -330,7 +340,8 @@ struct DeviceIndependenceTensorOperations {
     NameInTensorMap inputs({{"X", {&x}}});
     return CreateOpRunAndReturnTensor("reduce_max", inputs, attrs, out_dim);
   }
-
+  // Support float and complex type subtraction,the default is T type
+  template <typename InT = T>
   framework::Tensor Sub(const framework::Tensor& x,
                         const framework::Tensor& y) {
     framework::Tensor ret;
@@ -340,18 +351,18 @@ struct DeviceIndependenceTensorOperations {
 #if defined(__NVCC__) || defined(__HIPCC__)
       // For GPU, there is no need to define XxxInverseFunctor and call
       // ElementwiseComputeEx in two branches.
-      ElementwiseComputeEx<SubFunctor<T>, DeviceContext, T>(
-          context, &x, &y, -1, SubFunctor<T>(), &ret);
+      ElementwiseComputeEx<SubFunctor<InT>, DeviceContext, InT>(
+          context, &x, &y, -1, SubFunctor<InT>(), &ret);
 #endif
     } else {
       if (x.dims().size() >= y.dims().size()) {
-        ElementwiseComputeEx<SubFunctor<T>, DeviceContext, T>(
-            context, &x, &y, -1, SubFunctor<T>(), &ret);
+        ElementwiseComputeEx<SubFunctor<InT>, DeviceContext, InT>(
+            context, &x, &y, -1, SubFunctor<InT>(), &ret);
       } else {
-        ElementwiseComputeEx<InverseSubFunctor<T>, DeviceContext, T>(
-            // This is copyed from elementwise_sub, which means we
-            // need reverse will xrank < yrank
-            context, &x, &y, -1, InverseSubFunctor<T>(), &ret);
+        // This is copyed from elementwise_sub, which means we
+        // need reverse will xrank < yrank
+        ElementwiseComputeEx<InverseSubFunctor<InT>, DeviceContext, InT>(
+            context, &x, &y, -1, InverseSubFunctor<InT>(), &ret);
       }
     }
     return ret;
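Note (annotation, not part of the commit): templating Sub on InT (defaulting to T) is what lets the grad kernel call dito.template Sub<ValueType>(...) so real eigenvalue gaps are subtracted in ValueType even when T is complex, making the duplicated Sub_ below removable. A minimal sketch of the default-template-argument pattern:

#include <complex>
#include <iostream>
#include <typeinfo>

struct Ops {
  using T = std::complex<float>;  // the operator's main scalar type

  template <typename InT = T>  // callers may override the scalar type
  void Sub() {
    std::cout << "Sub over scalar type " << typeid(InT).name() << "\n";
  }
};

int main() {
  Ops ops;
  ops.Sub();         // InT defaults to std::complex<float>
  ops.Sub<float>();  // InT = float, e.g. real-valued eigenvalue gaps
}

In the kernel itself the call is spelled dito.template Sub<ValueType>(...) because dito has a dependent type, so the template disambiguator is required.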
@@ -461,37 +472,6 @@ struct DeviceIndependenceTensorOperations {
     return out;
   }

-  // Support x and y are different data types
-  Tensor Div_(const Tensor& x, const Tensor& y) {
-    Tensor out;
-    out.mutable_data<T>(x.dims(), context.GetPlace());
-    auto x_vector = EigenVector<T>::Flatten(x);
-    auto y_vector = EigenVector<ValueType>::Flatten(y);
-    auto out_vector = EigenVector<T>::Flatten(out);
-    auto& place =
-        *context.template device_context<DeviceContext>().eigen_device();
-    out_vector.device(place) = x_vector / y_vector;
-    return out;
-  }
-
-  framework::Tensor Sub_(const framework::Tensor& x,
-                         const framework::Tensor& y) {
-    framework::Tensor ret;
-    std::vector<int> out_shape = GetBroadcastShape({&x, &y});
-    ret.Resize(framework::make_ddim(out_shape));
-    if (x.dims().size() >= y.dims().size()) {
-      ElementwiseComputeEx<SubFunctor<ValueType>, DeviceContext, ValueType>(
-          context, &x, &y, -1, SubFunctor<ValueType>(), &ret);
-    } else {
-      ElementwiseComputeEx<InverseSubFunctor<ValueType>, DeviceContext,
-                           ValueType>(
-          // This is copyed from elementwise_sub, which means we
-          // need reverse will xrank < yrank
-          context, &x, &y, -1, InverseSubFunctor<ValueType>(), &ret);
-    }
-    return ret;
-  }
-
  private:
   const framework::ExecutionContext& context;
   BlasT<DeviceContext, T> GetBlas() {
......
@@ -140,7 +140,7 @@ class TestEighAPI(unittest.TestCase):
         self.check_static_complex_result()

     def test_in_dynamic_mode(self):
-        paddle.disable_static(self.place)
+        paddle.disable_static()
         input_real_data = paddle.to_tensor(self.real_data)
         expected_w, expected_v = np.linalg.eigh(self.real_data)
         actual_w, actual_v = paddle.linalg.eigh(input_real_data)
@@ -152,7 +152,7 @@ class TestEighAPI(unittest.TestCase):
         self.compare_result(actual_w, actual_v.numpy(), expected_w, expected_v)

     def test_eigh_grad(self):
-        paddle.disable_static(self.place)
+        paddle.disable_static()
         x = paddle.to_tensor(self.complex_data, stop_gradient=False)
         w, v = paddle.linalg.eigh(x)
         (w.sum() + paddle.abs(v).sum()).backward()
......