Unverified  Commit ccf99b66  authored by: H Haohongxiang  committed by: GitHub

Add CPU kernel of new API: lstsq (#38585)

* add cpu kernel of lstsq

* update

* modify code style

* modify unittest

* remove support for complex
Parent 667dc9f0
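For context, a minimal usage sketch of the new API added in this change (the paddle.linalg.lstsq call below follows the signature introduced in this PR; input values and shapes are illustrative only):

import numpy as np
import paddle

paddle.device.set_device("cpu")
x = paddle.to_tensor(np.random.random((5, 4)).astype('float64'))
y = paddle.to_tensor(np.random.random((5, 3)).astype('float64'))
# lstsq returns (solution, residuals, rank, singular_values)
solution, residuals, rank, singular_values = paddle.linalg.lstsq(
    x, y, rcond=1e-15, driver="gelsd")
print(solution.shape)  # [4, 3]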
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/operators/lstsq_op.h"
#include <string>
#include <vector>
#include "paddle/fluid/framework/op_registry.h"
namespace paddle {
namespace operators {
class LstsqOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override {
OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "LstsqOp");
OP_INOUT_CHECK(ctx->HasInput("Y"), "Input", "Y", "LstsqOp");
OP_INOUT_CHECK(ctx->HasOutput("Solution"), "Output", "Solution", "LstsqOp");
OP_INOUT_CHECK(ctx->HasOutput("Rank"), "Output", "Rank", "LstsqOp");
OP_INOUT_CHECK(ctx->HasOutput("SingularValues"), "Output", "SingularValues",
"LstsqOp");
auto x_dims = ctx->GetInputDim("X");
auto y_dims = ctx->GetInputDim("Y");
int x_rank = x_dims.size();
int y_rank = y_dims.size();
PADDLE_ENFORCE_GE(x_rank, 2,
platform::errors::InvalidArgument(
"Expects input tensor x to have at least "
"2 dimensions, but got dimension %d",
x_rank));
PADDLE_ENFORCE_GE(y_rank, 2,
platform::errors::InvalidArgument(
"Expects input tensor y to have at least "
"2 dimensions, but got dimension %d",
y_rank));
PADDLE_ENFORCE_EQ(
x_rank, y_rank,
platform::errors::InvalidArgument(
"Expects input tensors x and y to have the same number of dimensions, "
"but got x's dimension [%d] and y's dimension [%d]",
x_rank, y_rank));
std::vector<int> batch_dims_vec{};
for (int i = 0; i < x_rank - 2; ++i) {
PADDLE_ENFORCE_EQ(
x_dims[i], y_dims[i],
platform::errors::InvalidArgument(
"Expects input tensors x and y to have the same batch "
"dimension, but got x's batch dimension [%d] and "
"y's batch dimension [%d] in the %d-th dim",
x_dims[i], y_dims[i], i));
batch_dims_vec.emplace_back(x_dims[i]);
}
PADDLE_ENFORCE_EQ(
x_dims[x_rank - 2], y_dims[y_rank - 2],
platform::errors::InvalidArgument(
"Expects input tensors x and y to have the same row dimension "
"of the innermost 2-D matrices, "
"but got x's row dimension [%d] and y's row dimension [%d]",
x_dims[x_rank - 2], y_dims[y_rank - 2]));
ctx->SetOutputDim("Rank", framework::make_ddim(batch_dims_vec));
batch_dims_vec.emplace_back(
std::min(x_dims[x_rank - 2], x_dims[x_rank - 1]));
ctx->SetOutputDim("SingularValues", framework::make_ddim(batch_dims_vec));
batch_dims_vec[x_rank - 2] = x_dims[x_rank - 1];
batch_dims_vec.emplace_back(y_dims[x_rank - 1]);
ctx->SetOutputDim("Solution", framework::make_ddim(batch_dims_vec));
}
protected:
// Complex dtypes are not supported; only float32 and float64 inputs are accepted
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
auto dtype = OperatorWithKernel::IndicateVarDataType(ctx, "X");
if (dtype != framework::proto::VarType::FP32 &&
dtype != framework::proto::VarType::FP64) {
PADDLE_THROW(platform::errors::InvalidArgument(
"Unsupported data type %s; only float32 and float64 are supported.", dtype));
}
return framework::OpKernelType(dtype, ctx.GetPlace());
}
};
class LstsqOpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() override {
AddInput("X",
"(Tensor), A real-valued tensor with shape (*, m, n). "
"The accepted datatype is one of float32, float64");
AddInput("Y",
"(Tensor), A real-valued tensor with shape (*, m, k). "
"The accepted datatype is one of float32, float64");
AddAttr<float>(
"rcond",
"(float, default 0.0), A float value used to determine the effective "
"rank of A.")
.SetDefault(0.0f);
AddAttr<std::string>("driver",
"(string, default \"gels\"). "
"name of the LAPACK method to be used.")
.SetDefault("gels");
AddOutput("Solution",
"(Tensor), The output Solution tensor with shape (*, n, k).");
AddOutput("Rank", "(Tensor), The output Rank tensor with shape (*).");
AddOutput(
"SingularValues",
"(Tensor), The output SingularValues tensor with shape (*, min(m,n)).");
AddComment(R"DOC(
Lstsq Operator.
Computes a least squares solution of the linear system X * Solution = Y for general matrices.
)DOC");
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OPERATOR(lstsq, ops::LstsqOp, ops::LstsqOpMaker)
REGISTER_OP_CPU_KERNEL(
lstsq, ops::LstsqCPUKernel<paddle::platform::CPUDeviceContext, float>,
ops::LstsqCPUKernel<paddle::platform::CPUDeviceContext, double>);
\ No newline at end of file
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <math.h>
#include <algorithm>
#include <complex>
#include "paddle/fluid/operators/eig_op.h"
#include "paddle/fluid/operators/math/complex_functors.h"
#include "paddle/fluid/operators/math/eigen_values_vectors.h"
#include "paddle/fluid/operators/math/lapack_function.h"
#include "paddle/fluid/operators/math/math_function.h"
#include "paddle/fluid/operators/math/matrix_solve.h"
#include "paddle/fluid/operators/svd_helper.h"
#include "paddle/fluid/operators/transpose_op.h"
#include "paddle/fluid/operators/triangular_solve_op.h"
#include "paddle/fluid/platform/for_range.h"
#define EPSILON 1e-6
namespace paddle {
namespace operators {
using paddle::framework::Tensor;
enum class LapackDriverType : int { Gels, Gelsd, Gelsy, Gelss };
using DDim = framework::DDim;
static DDim UDDim(const DDim& x_dim) {
auto x_vec = vectorize(x_dim);
return framework::make_ddim(x_vec);
}
template <typename DeviceContext, typename T>
class LstsqCPUKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& context) const override {
using ValueType = math::Real<T>;
const Tensor& x = *context.Input<Tensor>("X");
const Tensor& y = *context.Input<Tensor>("Y");
auto rcond = context.Attr<float>("rcond");
auto driver_string = context.Attr<std::string>("driver");
static auto driver_type = std::unordered_map<std::string, LapackDriverType>(
{{"gels", LapackDriverType::Gels},
{"gelsy", LapackDriverType::Gelsy},
{"gelsd", LapackDriverType::Gelsd},
{"gelss", LapackDriverType::Gelss}});
auto driver = driver_type[driver_string];
auto solution = context.Output<Tensor>("Solution");
auto* rank = context.Output<Tensor>("Rank");
auto* singular_values = context.Output<Tensor>("SingularValues");
auto dito =
math::DeviceIndependenceTensorOperations<DeviceContext, T>(context);
auto x_dims = x.dims();
auto y_dims = y.dims();
int dim_size = x_dims.size();
int x_stride = MatrixStride(x);
int y_stride = MatrixStride(y);
int batch_count = BatchCount(x);
auto ori_solution_dim = solution->dims();
int ori_solu_stride = MatrixStride(*solution);
// LAPACK uses column-major storage; transpose the inputs so their
// memory layout is contiguous in that layout
int info = 0;
int m = x_dims[dim_size - 2];
int n = x_dims[dim_size - 1];
int nrhs = y_dims[dim_size - 1];
int lda = std::max<int>(m, 1);
int ldb = std::max<int>(1, std::max(m, n));
Tensor new_x;
new_x.mutable_data<T>(context.GetPlace(),
size_t(batch_count * m * n * sizeof(T)));
solution->mutable_data<T>(
context.GetPlace(),
size_t(batch_count * std::max(m, n) * nrhs * sizeof(T)));
framework::TensorCopy(x, context.GetPlace(), &new_x);
framework::TensorCopy(y, context.GetPlace(), solution);
if (m < n) solution->Resize(UDDim(ori_solution_dim));
Tensor input_x_trans = dito.Transpose(new_x);
Tensor input_y_trans = dito.Transpose(*solution);
framework::TensorCopy(input_x_trans, new_x.place(), &new_x);
framework::TensorCopy(input_y_trans, solution->place(), solution);
auto* x_vector = new_x.data<T>();
auto* y_vector = solution->data<T>();
// "gels" divers does not need to compute rank
int rank_32 = 0;
int* rank_data = nullptr;
int* rank_working_ptr = nullptr;
if (driver != LapackDriverType::Gels) {
rank_data = rank->mutable_data<int>(context.GetPlace());
rank_working_ptr = rank_data;
}
// "gelsd" and "gelss" divers need to compute singular values
ValueType* s_data = nullptr;
ValueType* s_working_ptr = nullptr;
int s_stride = 0;
if (driver == LapackDriverType::Gelsd ||
driver == LapackDriverType::Gelss) {
s_data = singular_values->mutable_data<ValueType>(context.GetPlace());
s_working_ptr = s_data;
auto s_dims = singular_values->dims();
s_stride = s_dims[s_dims.size() - 1];
}
// "jpvt" is only used for "gelsy" driver
Tensor jpvt;
int* jpvt_data = nullptr;
if (driver == LapackDriverType::Gelsy) {
jpvt.Resize(framework::make_ddim({std::max<int>(1, n)}));
jpvt_data = jpvt.mutable_data<int>(context.GetPlace());
}
// run the driver once with lwork = -1 to query the optimal workspace size
int lwork = -1;
T wkopt;
ValueType rwkopt;
int iwkopt = 0;
if (driver == LapackDriverType::Gels) {
math::lapackGels('N', m, n, nrhs, x_vector, lda, y_vector, ldb, &wkopt,
lwork, &info);
} else if (driver == LapackDriverType::Gelsd) {
math::lapackGelsd(m, n, nrhs, x_vector, lda, y_vector, ldb, s_working_ptr,
static_cast<ValueType>(rcond), &rank_32, &wkopt, lwork,
&rwkopt, &iwkopt, &info);
} else if (driver == LapackDriverType::Gelsy) {
math::lapackGelsy(m, n, nrhs, x_vector, lda, y_vector, ldb, jpvt_data,
static_cast<ValueType>(rcond), &rank_32, &wkopt, lwork,
&rwkopt, &info);
} else if (driver == LapackDriverType::Gelss) {
math::lapackGelss(m, n, nrhs, x_vector, lda, y_vector, ldb, s_working_ptr,
static_cast<ValueType>(rcond), &rank_32, &wkopt, lwork,
&rwkopt, &info);
}
lwork = std::max<int>(1, static_cast<int>(math::Real<T>(wkopt)));
Tensor work;
work.Resize(framework::make_ddim({lwork}));
T* work_data = work.mutable_data<T>(context.GetPlace());
// "rwork" only used for complex inputs and "gelsy/gelsd/gelss" drivers
Tensor rwork;
ValueType* rwork_data = nullptr;
if (framework::IsComplexType(x.type()) &&
driver != LapackDriverType::Gels) {
int rwork_len = 0;
if (driver == LapackDriverType::Gelsy) {
rwork_len = std::max<int>(1, 2 * n);
} else if (driver == LapackDriverType::Gelss) {
rwork_len = std::max<int>(1, 5 * std::min(m, n));
} else if (driver == LapackDriverType::Gelsd) {
rwork_len = std::max<int>(1, rwkopt);
}
rwork.Resize(framework::make_ddim({rwork_len}));
rwork_data = rwork.mutable_data<ValueType>(context.GetPlace());
}
// "iwork" workspace array is relavant only for "gelsd" driver
Tensor iwork;
int* iwork_data = nullptr;
if (driver == LapackDriverType::Gelsd) {
iwork.Resize(framework::make_ddim({std::max<int>(1, iwkopt)}));
iwork_data = iwork.mutable_data<int>(context.GetPlace());
}
int solu_stride = std::max(y_stride, ori_solu_stride);
for (auto i = 0; i < batch_count; ++i) {
auto* x_input = &x_vector[i * x_stride];
auto* y_input = &y_vector[i * solu_stride];
rank_working_ptr = rank_working_ptr ? &rank_data[i] : nullptr;
s_working_ptr = s_working_ptr ? &s_data[i * s_stride] : nullptr;
if (driver == LapackDriverType::Gels) {
math::lapackGels('N', m, n, nrhs, x_input, lda, y_input, ldb, work_data,
lwork, &info);
} else if (driver == LapackDriverType::Gelsd) {
math::lapackGelsd(m, n, nrhs, x_input, lda, y_input, ldb, s_working_ptr,
static_cast<ValueType>(rcond), &rank_32, work_data,
lwork, rwork_data, iwork_data, &info);
} else if (driver == LapackDriverType::Gelsy) {
math::lapackGelsy(m, n, nrhs, x_input, lda, y_input, ldb, jpvt_data,
static_cast<ValueType>(rcond), &rank_32, work_data,
lwork, rwork_data, &info);
} else if (driver == LapackDriverType::Gelss) {
math::lapackGelss(m, n, nrhs, x_input, lda, y_input, ldb, s_working_ptr,
static_cast<ValueType>(rcond), &rank_32, work_data,
lwork, rwork_data, &info);
}
PADDLE_ENFORCE_EQ(
info, 0,
platform::errors::PreconditionNotMet(
"For batch [%d]: Lapack info is not zero but [%d]", i, info));
if (rank_working_ptr) *rank_working_ptr = static_cast<int>(rank_32);
}
Tensor tmp_s = dito.Transpose(*solution);
framework::TensorCopy(tmp_s, solution->place(), solution);
if (m >= n) solution->Resize(UDDim(ori_solution_dim));
}
};
} // namespace operators
} // namespace paddle
......@@ -125,6 +125,70 @@ void lapackEig<platform::complex<float>, float>(
reinterpret_cast<std::complex<float> *>(work), &lwork, rwork, info);
}
template <>
void lapackGels<double>(char trans, int m, int n, int nrhs, double *a, int lda,
double *b, int ldb, double *work, int lwork,
int *info) {
platform::dynload::dgels_(&trans, &m, &n, &nrhs, a, &lda, b, &ldb, work,
&lwork, info);
}
template <>
void lapackGels<float>(char trans, int m, int n, int nrhs, float *a, int lda,
float *b, int ldb, float *work, int lwork, int *info) {
platform::dynload::sgels_(&trans, &m, &n, &nrhs, a, &lda, b, &ldb, work,
&lwork, info);
}
template <>
void lapackGelsd<double>(int m, int n, int nrhs, double *a, int lda, double *b,
int ldb, double *s, double rcond, int *rank,
double *work, int lwork, double *rwork, int *iwork,
int *info) {
platform::dynload::dgelsd_(&m, &n, &nrhs, a, &lda, b, &ldb, s, &rcond, rank,
work, &lwork, iwork, info);
}
template <>
void lapackGelsd<float>(int m, int n, int nrhs, float *a, int lda, float *b,
int ldb, float *s, float rcond, int *rank, float *work,
int lwork, float *rwork, int *iwork, int *info) {
platform::dynload::sgelsd_(&m, &n, &nrhs, a, &lda, b, &ldb, s, &rcond, rank,
work, &lwork, iwork, info);
}
template <>
void lapackGelsy<double>(int m, int n, int nrhs, double *a, int lda, double *b,
int ldb, int *jpvt, double rcond, int *rank,
double *work, int lwork, double *rwork, int *info) {
platform::dynload::dgelsy_(&m, &n, &nrhs, a, &lda, b, &ldb, jpvt, &rcond,
rank, work, &lwork, info);
}
template <>
void lapackGelsy<float>(int m, int n, int nrhs, float *a, int lda, float *b,
int ldb, int *jpvt, float rcond, int *rank, float *work,
int lwork, float *rwork, int *info) {
platform::dynload::sgelsy_(&m, &n, &nrhs, a, &lda, b, &ldb, jpvt, &rcond,
rank, work, &lwork, info);
}
template <>
void lapackGelss<double>(int m, int n, int nrhs, double *a, int lda, double *b,
int ldb, double *s, double rcond, int *rank,
double *work, int lwork, double *rwork, int *info) {
platform::dynload::dgelss_(&m, &n, &nrhs, a, &lda, b, &ldb, s, &rcond, rank,
work, &lwork, info);
}
template <>
void lapackGelss<float>(int m, int n, int nrhs, float *a, int lda, float *b,
int ldb, float *s, float rcond, int *rank, float *work,
int lwork, float *rwork, int *info) {
platform::dynload::sgelss_(&m, &n, &nrhs, a, &lda, b, &ldb, s, &rcond, rank,
work, &lwork, info);
}
template <>
void lapackCholeskySolve<platform::complex<double>>(
char uplo, int n, int nrhs, platform::complex<double> *a, int lda,
......
......@@ -20,21 +20,46 @@ namespace math {
// LU (for example)
template <typename T>
void lapackLu(int m, int n, T* a, int lda, int* ipiv, int* info);
void lapackLu(int m, int n, T *a, int lda, int *ipiv, int *info);
// Eigh
template <typename T, typename ValueType = T>
void lapackEigh(char jobz, char uplo, int n, T* a, int lda, ValueType* w,
T* work, int lwork, ValueType* rwork, int lrwork, int* iwork,
int liwork, int* info);
void lapackEigh(char jobz, char uplo, int n, T *a, int lda, ValueType *w,
T *work, int lwork, ValueType *rwork, int lrwork, int *iwork,
int liwork, int *info);
// Eig
template <typename T1, typename T2 = T1>
void lapackEig(char jobvl, char jobvr, int n, T1* a, int lda, T1* w, T1* vl,
int ldvl, T1* vr, int ldvr, T1* work, int lwork, T2* rwork,
int* info);
void lapackEig(char jobvl, char jobvr, int n, T1 *a, int lda, T1 *w, T1 *vl,
int ldvl, T1 *vr, int ldvr, T1 *work, int lwork, T2 *rwork,
int *info);
// Gels
template <typename T>
void lapackCholeskySolve(char uplo, int n, int nrhs, T* a, int lda, T* b,
int ldb, int* info);
void lapackGels(char trans, int m, int n, int nrhs, T *a, int lda, T *b,
int ldb, T *work, int lwork, int *info);
// Gelsd
template <typename T1, typename T2>
void lapackGelsd(int m, int n, int nrhs, T1 *a, int lda, T1 *b, int ldb, T2 *s,
T2 rcond, int *rank, T1 *work, int lwork, T2 *rwork,
int *iwork, int *info);
// Gelsy
template <typename T1, typename T2>
void lapackGelsy(int m, int n, int nrhs, T1 *a, int lda, T1 *b, int ldb,
int *jpvt, T2 rcond, int *rank, T1 *work, int lwork, T2 *rwork,
int *info);
// Gelss
template <typename T1, typename T2>
void lapackGelss(int m, int n, int nrhs, T1 *a, int lda, T1 *b, int ldb, T2 *s,
T2 rcond, int *rank, T1 *work, int lwork, T2 *rwork,
int *info);
template <typename T>
void lapackCholeskySolve(char uplo, int n, int nrhs, T *a, int lda, T *b,
int ldb, int *info);
} // namespace math
} // namespace operators
......
......@@ -66,6 +66,39 @@ extern "C" void cgeev_(char *jobvl, char *jobvr, int *n, std::complex<float> *a,
std::complex<float> *work, int *lwork, float *rwork,
int *info);
// gels
extern "C" void dgels_(char *trans, int *m, int *n, int *nrhs, double *a,
int *lda, double *b, int *ldb, double *work, int *lwork,
int *info);
extern "C" void sgels_(char *trans, int *m, int *n, int *nrhs, float *a,
int *lda, float *b, int *ldb, float *work, int *lwork,
int *info);
// gelsd
extern "C" void dgelsd_(int *m, int *n, int *nrhs, double *a, int *lda,
double *b, int *ldb, double *s, double *rcond,
int *rank, double *work, int *lwork, int *iwork,
int *info);
extern "C" void sgelsd_(int *m, int *n, int *nrhs, float *a, int *lda, float *b,
int *ldb, float *s, float *rcond, int *rank,
float *work, int *lwork, int *iwork, int *info);
// gelsy
extern "C" void dgelsy_(int *m, int *n, int *nrhs, double *a, int *lda,
double *b, int *ldb, int *jpvt, double *rcond,
int *rank, double *work, int *lwork, int *info);
extern "C" void sgelsy_(int *m, int *n, int *nrhs, float *a, int *lda, float *b,
int *ldb, int *jpvt, float *rcond, int *rank,
float *work, int *lwork, int *info);
// gelss
extern "C" void dgelss_(int *m, int *n, int *nrhs, double *a, int *lda,
double *b, int *ldb, double *s, double *rcond,
int *rank, double *work, int *lwork, int *info);
extern "C" void sgelss_(int *m, int *n, int *nrhs, float *a, int *lda, float *b,
int *ldb, float *s, float *rcond, int *rank,
float *work, int *lwork, int *info);
extern "C" void zpotrs_(char *uplo, int *n, int *nrhs, std::complex<double> *a,
int *lda, std::complex<double> *b, int *ldb, int *info);
extern "C" void cpotrs_(char *uplo, int *n, int *nrhs, std::complex<float> *a,
......@@ -115,6 +148,14 @@ extern void *lapack_dso_handle;
__macro(sgeev_); \
__macro(zgeev_); \
__macro(cgeev_); \
__macro(dgels_); \
__macro(sgels_); \
__macro(dgelsd_); \
__macro(sgelsd_); \
__macro(dgelsy_); \
__macro(sgelsy_); \
__macro(dgelss_); \
__macro(sgelss_); \
__macro(zpotrs_); \
__macro(cpotrs_); \
__macro(dpotrs_); \
......
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import unittest
import numpy as np
import paddle
import paddle.fluid as fluid
class LinalgLstsqTestCase(unittest.TestCase):
def setUp(self):
self.init_config()
self.generate_input()
self.generate_output()
def init_config(self):
self.dtype = 'float64'
self.rcond = 1e-15
self.driver = "gelsd"
self._input_shape_1 = (5, 4)
self._input_shape_2 = (5, 3)
def generate_input(self):
self._input_data_1 = np.random.random(self._input_shape_1).astype(
self.dtype)
self._input_data_2 = np.random.random(self._input_shape_2).astype(
self.dtype)
def generate_output(self):
if len(self._input_shape_1) == 2:
out = np.linalg.lstsq(
self._input_data_1, self._input_data_2, rcond=self.rcond)
elif len(self._input_shape_1) == 3:
out = np.linalg.lstsq(
self._input_data_1[0], self._input_data_2[0], rcond=self.rcond)
self._output_solution = out[0]
self._output_residuals = out[1]
self._output_rank = out[2]
self._output_sg_values = out[3]
def test_dygraph(self):
paddle.disable_static()
paddle.device.set_device("cpu")
place = paddle.CPUPlace()
x = paddle.to_tensor(self._input_data_1, place=place, dtype=self.dtype)
y = paddle.to_tensor(self._input_data_2, place=place, dtype=self.dtype)
results = paddle.linalg.lstsq(
x, y, rcond=self.rcond, driver=self.driver)
res_solution = results[0].numpy()
res_residuals = results[1].numpy()
res_rank = results[2].numpy()
res_singular_values = results[3].numpy()
if x.shape[-2] > x.shape[-1] and self._output_rank == x.shape[-1]:
if not (np.abs(res_residuals - self._output_residuals) < 1e-6).all():
raise RuntimeError("Check LSTSQ residuals dygraph Failed")
if self.driver in ("gelsy", "gelsd", "gelss"):
if not (np.abs(res_rank - self._output_rank) < 1e-6).all():
raise RuntimeError("Check LSTSQ rank dygraph Failed")
if self.driver in ("gelsd", "gelss"):
if not (np.abs(res_singular_values - self._output_sg_values) < 1e-6).all():
raise RuntimeError("Check LSTSQ singular values dygraph Failed")
def test_static(self):
paddle.enable_static()
paddle.device.set_device("cpu")
place = fluid.CPUPlace()
with fluid.program_guard(fluid.Program(), fluid.Program()):
x = paddle.fluid.data(
name="x",
shape=self._input_shape_1,
dtype=self._input_data_1.dtype)
y = paddle.fluid.data(
name="y",
shape=self._input_shape_2,
dtype=self._input_data_2.dtype)
results = paddle.linalg.lstsq(
x, y, rcond=self.rcond, driver=self.driver)
exe = fluid.Executor(place)
fetches = exe.run(
fluid.default_main_program(),
feed={"x": self._input_data_1,
"y": self._input_data_2},
fetch_list=[results])
if x.shape[-2] > x.shape[-1] and self._output_rank == x.shape[-1]:
if not (np.abs(fetches[1] - self._output_residuals) < 1e-6).all():
raise RuntimeError("Check LSTSQ residuals static Failed")
if self.driver in ("gelsy", "gelsd", "gelss"):
if not (np.abs(fetches[2] - self._output_rank) < 1e-6).all():
raise RuntimeError("Check LSTSQ rank static Failed")
if self.driver in ("gelsd", "gelss"):
if not (np.abs(fetches[3] - self._output_sg_values) < 1e-6).all():
raise RuntimeError("Check LSTSQ singular values static Failed")
class LinalgLstsqTestCaseGels(LinalgLstsqTestCase):
def init_config(self):
self.dtype = 'float64'
self.rcond = 1e-15
self.driver = "gels"
self._input_shape_1 = (5, 10)
self._input_shape_2 = (5, 5)
class LinalgLstsqTestCaseRcond(LinalgLstsqTestCase):
def init_config(self):
self.dtype = 'float64'
self.rcond = 0.1
self.driver = "gels"
self._input_shape_1 = (3, 2)
self._input_shape_2 = (3, 3)
class LinalgLstsqTestCaseGelsFloat32(LinalgLstsqTestCase):
def init_config(self):
self.dtype = 'float32'
self.rcond = 1e-15
self.driver = "gels"
self._input_shape_1 = (10, 5)
self._input_shape_2 = (10, 2)
class LinalgLstsqTestCaseGelssFloat64(LinalgLstsqTestCase):
def init_config(self):
self.dtype = 'float64'
self.rcond = 1e-15
self.driver = "gelss"
self._input_shape_1 = (5, 5)
self._input_shape_2 = (5, 1)
class LinalgLstsqTestCaseGelsyFloat32(LinalgLstsqTestCase):
def init_config(self):
self.dtype = 'float32'
self.rcond = 1e-15
self.driver = "gelsy"
self._input_shape_1 = (8, 2)
self._input_shape_2 = (8, 10)
class LinalgLstsqTestCaseBatch1(LinalgLstsqTestCase):
def init_config(self):
self.dtype = 'float32'
self.rcond = 1e-15
self.driver = None
self._input_shape_1 = (2, 3, 10)
self._input_shape_2 = (2, 3, 4)
class LinalgLstsqTestCaseBatch2(LinalgLstsqTestCase):
def init_config(self):
self.dtype = 'float64'
self.rcond = 1e-15
self.driver = "gelss"
self._input_shape_1 = (2, 8, 6)
self._input_shape_2 = (2, 8, 2)
class LinalgLstsqTestCaseLarge1(LinalgLstsqTestCase):
def init_config(self):
self.dtype = 'float64'
self.rcond = 1e-15
self.driver = "gelsd"
self._input_shape_1 = (200, 100)
self._input_shape_2 = (200, 50)
class LinalgLstsqTestCaseLarge2(LinalgLstsqTestCase):
def init_config(self):
self.dtype = 'float32'
self.rcond = 1e-15
self.driver = "gelss"
self._input_shape_1 = (50, 600)
self._input_shape_2 = (50, 300)
if __name__ == '__main__':
unittest.main()
......@@ -32,6 +32,7 @@ from .tensor.linalg import det # noqa: F401
from .tensor.linalg import slogdet # noqa: F401
from .tensor.linalg import pinv # noqa: F401
from .tensor.linalg import triangular_solve # noqa: F401
from .tensor.linalg import lstsq  # noqa: F401
__all__ = [
'cholesky', #noqa
......@@ -54,4 +55,5 @@ __all__ = [
'solve',
'cholesky_solve',
'triangular_solve',
'lstsq'
]
......@@ -43,6 +43,7 @@ from .linalg import cov # noqa: F401
from .linalg import norm # noqa: F401
from .linalg import cond # noqa: F401
from .linalg import transpose # noqa: F401
from .linalg import lstsq # noqa: F401
from .linalg import dist # noqa: F401
from .linalg import t # noqa: F401
from .linalg import cross # noqa: F401
......@@ -268,6 +269,7 @@ tensor_method_func = [ #noqa
'norm',
'cond',
'transpose',
'lstsq',
'dist',
't',
'cross',
......
......@@ -23,7 +23,6 @@ import paddle
from paddle.common_ops_import import core
from paddle.common_ops_import import VarDesc
from paddle import _C_ops
import paddle
__all__ = []
......@@ -2616,3 +2615,107 @@ def eigvalsh(x, UPLO='L', name=None):
attrs={'UPLO': UPLO,
'is_test': is_test})
return out_value
def lstsq(x, y, rcond=1e-15, driver=None, name=None):
device = paddle.device.get_device()
if device == "cpu":
if driver not in (None, "gels", "gelss", "gelsd", "gelsy"):
raise ValueError(
"Only 'gels', 'gelss', 'gelsd', 'gelsy' or None are valid drivers for CPU inputs, but got {}".
format(driver))
driver = "gelsy" if driver is None else driver
elif "gpu" in device:
if driver not in (None, "gels"):
raise ValueError(
"Only 'gels' or None are valid drivers for CUDA inputs, but got {}".
format(driver))
driver = "gels" if driver is None else driver
else:
raise RuntimeError("Only support lstsq api for CPU or CUDA device.")
if in_dygraph_mode():
solution, rank, singular_values = _C_ops.lstsq(x, y, "rcond", rcond,
"driver", driver)
if x.shape[-2] > x.shape[-1]:
matmul_out = _varbase_creator(dtype=x.dtype)
_C_ops.matmul(x, solution, matmul_out, 'trans_x', False, 'trans_y',
False)
minus_out = _C_ops.elementwise_sub(matmul_out, y)
pow_out = _C_ops.pow(minus_out, 'factor', 2)
residuals = _C_ops.reduce_sum(pow_out, 'dim', [-2], 'keepdim',
False, 'reduce_all', False)
else:
residuals = paddle.empty(shape=[0], dtype=x.dtype)
if driver == "gels":
rank = paddle.empty(shape=[0], dtype=paddle.int32)
singular_values = paddle.empty(shape=[0], dtype=x.dtype)
elif driver == "gelsy":
singular_values = paddle.empty(shape=[0], dtype=x.dtype)
return solution, residuals, rank, singular_values
helper = LayerHelper('lstsq', **locals())
check_variable_and_dtype(x, 'dtype', ['float32', 'float64'], 'lstsq')
check_variable_and_dtype(y, 'dtype', ['float32', 'float64'], 'lstsq')
solution = helper.create_variable_for_type_inference(dtype=x.dtype)
residuals = helper.create_variable_for_type_inference(dtype=x.dtype)
rank = helper.create_variable_for_type_inference(dtype=paddle.int32)
singular_values = helper.create_variable_for_type_inference(dtype=x.dtype)
helper.append_op(
type='lstsq',
inputs={'X': x,
'Y': y},
outputs={
'Solution': solution,
'Rank': rank,
'SingularValues': singular_values
},
attrs={'rcond': rcond,
'driver': driver})
matmul_out = helper.create_variable_for_type_inference(dtype=x.dtype)
minus_out = helper.create_variable_for_type_inference(dtype=x.dtype)
pow_out = helper.create_variable_for_type_inference(dtype=x.dtype)
helper.append_op(
type='matmul_v2',
inputs={'X': x,
'Y': solution},
outputs={'Out': matmul_out},
attrs={
'trans_x': False,
'trans_y': False,
})
helper.append_op(
type='elementwise_sub',
inputs={'X': matmul_out,
'Y': y},
outputs={'Out': minus_out})
helper.append_op(
type='pow',
inputs={'X': minus_out},
outputs={'Out': pow_out},
attrs={'factor': 2})
helper.append_op(
type='reduce_sum',
inputs={'X': pow_out},
outputs={'Out': residuals},
attrs={'dim': [-2],
'keep_dim': False,
'reduce_all': False})
if driver == "gels":
rank = paddle.static.data(name='rank', shape=[0])
singular_values = paddle.static.data(name='singular_values', shape=[0])
elif driver == "gelsy":
singular_values = paddle.static.data(name='singular_values', shape=[0])
return solution, residuals, rank, singular_values
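For reference, the residuals branch above (taken only when x.shape[-2] > x.shape[-1]) computes the column-wise sums of squared errors of the fitted solution. A minimal numpy sketch of the same quantity, assuming a single unbatched float64 system with illustrative shapes:

import numpy as np

x = np.random.random((5, 4))
y = np.random.random((5, 3))
solution = np.linalg.lstsq(x, y, rcond=1e-15)[0]
# same quantity as the matmul -> subtract -> square -> reduce_sum(dim=-2) chain above
residuals = ((x @ solution - y) ** 2).sum(axis=-2)
print(residuals.shape)  # (3,)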
......@@ -1088,6 +1088,7 @@ SIXTH_PARALLEL_JOB_NEW = [
'test_fill_any_op',
'test_frame_op',
'test_linalg_pinv_op',
'test_linalg_lstsq_op',
'test_gumbel_softmax_op',
'test_matrix_power_op',
'test_multi_dot_op',
......