Unverified commit 6b4f2fbf, authored by Weilong Wu, committed by GitHub

[Cherry-Pick]Add paddle.linalg.solve OP (#35715) (#36056)

This PR adds linalg.solve to Paddle's linear algebra module: it computes the solution of a square system of linear equations. Call paddle.linalg.solve to use it.
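A minimal usage sketch (a hedged example, assuming a build that includes this change):

```python
import paddle

# Solve the square linear system A @ x = b for x.
a = paddle.to_tensor([[3., 1.], [1., 2.]], dtype="float64")
b = paddle.to_tensor([9., 8.], dtype="float64")
x = paddle.linalg.solve(a, b)
print(x)  # [2., 3.]
```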
Parent df81915a
@@ -123,7 +123,7 @@ lod_tensor maxouting unpooling pooling lod_rank_table context_project
sequence_pooling segment_pooling executor device_memory_aligment generator)
set(COMMON_OP_DEPS ${COMMON_OP_DEPS} dynload_warpctc)
set(COMMON_OP_DEPS ${COMMON_OP_DEPS} sequence_padding sequence_scale cos_sim_functor memory jit_kernel_helper concat_and_split cross_entropy softmax vol2col im2col sampler sample_prob tree2col)
set(COMMON_OP_DEPS ${COMMON_OP_DEPS} sequence2batch lstm_compute matrix_bit_code gru_compute activation_functions beam_search fc matrix_inverse matrix_solve)
set(COMMON_OP_DEPS ${COMMON_OP_DEPS} box_wrapper boost ps_gpu_wrapper)
set(COMMON_OP_DEPS ${COMMON_OP_DEPS} common_infer_shape_functions)
set(COMMON_OP_DEPS ${COMMON_OP_DEPS} eigen_function)
...
@@ -89,6 +89,7 @@ math_library(bert_encoder_functor)
math_library(tree2col DEPS math_function)
math_library(matrix_inverse)
math_library(segment_pooling)
math_library(matrix_solve)
cc_test(math_function_test SRCS math_function_test.cc DEPS math_function)
cc_test(selected_rows_functor_test SRCS selected_rows_functor_test.cc DEPS selected_rows_functor)
...
@@ -247,6 +247,12 @@ class Blas {
  template <typename T>
  void BatchedMatInv(int n, const T** a, T** a_inv, int* info,
                     int batch_size) const;

  // cuBlas solve
  template <typename T>
  void BatchedGETRS(CBLAS_TRANSPOSE trans, int n, int nrhs, const T** a,
                    int lda, int* ipiv, T** b, int ldb, int* info,
                    int batch_size) const;
#endif

 private:
@@ -402,6 +408,12 @@ class BlasT : private Blas<DeviceContext> {
  void BatchedMatInv(ARGS... args) const {
    Base()->template BatchedMatInv<T>(args...);
  }

  // solve
  template <typename... ARGS>
  void BatchedGETRS(ARGS... args) const {
    Base()->template BatchedGETRS<T>(args...);
  }
#endif

 private:
...
@@ -114,6 +114,12 @@ struct CUBlas<float> {
    PADDLE_ENFORCE_CUDA_SUCCESS(
        platform::dynload::cublasSmatinvBatched(args...));
  }

  template <typename... ARGS>
  static void GETRS_BATCH(ARGS... args) {
    PADDLE_ENFORCE_CUDA_SUCCESS(
        platform::dynload::cublasSgetrsBatched(args...));
  }
};

template <>
@@ -182,6 +188,12 @@ struct CUBlas<double> {
    PADDLE_ENFORCE_CUDA_SUCCESS(
        platform::dynload::cublasDmatinvBatched(args...));
  }

  template <typename... ARGS>
  static void GETRS_BATCH(ARGS... args) {
    PADDLE_ENFORCE_CUDA_SUCCESS(
        platform::dynload::cublasDgetrsBatched(args...));
  }
};

template <>
@@ -871,6 +883,20 @@ void Blas<platform::CUDADeviceContext>::BatchedMatInv(int n, const T **a,
  });
}

template <>
template <typename T>
void Blas<platform::CUDADeviceContext>::BatchedGETRS(
    CBLAS_TRANSPOSE trans, int n, int nrhs, const T **a, int lda, int *ipiv,
    T **b, int ldb, int *info, int batch_size) const {
  // only N/T are mapped here; complex types would additionally need
  // CUBLAS_OP_C (conjugate transpose)
  cublasOperation_t cuTrans =
      (trans == CblasNoTrans) ? CUBLAS_OP_N : CUBLAS_OP_T;
  context_.CublasCall([&](cublasHandle_t handle) {
    CUBlas<T>::GETRS_BATCH(handle, cuTrans, n, nrhs, a, lda, ipiv, b, ldb, info,
                           batch_size);
  });
}

}  // namespace math
}  // namespace operators
}  // namespace paddle
@@ -717,6 +717,19 @@ void Blas<platform::CUDADeviceContext>::BatchedMatInv(int n, const T **a,
  });
}

template <>
template <typename T>
void Blas<platform::CUDADeviceContext>::BatchedGETRS(
    CBLAS_TRANSPOSE trans, int n, int nrhs, const T **a, int lda, int *ipiv,
    T **b, int ldb, int *info, int batch_size) const {
  rocblas_operation cuTrans = (trans == CblasNoTrans)
                                  ? rocblas_operation_none
                                  : rocblas_operation_transpose;
  context_.CublasCall([&](rocblas_handle handle) {
    CUBlas<T>::GETRS_BATCH(handle, cuTrans, n, nrhs, a, lda, ipiv, b, ldb, info,
                           batch_size);
  });
}

}  // namespace math
}  // namespace operators
}  // namespace paddle
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/math/matrix_solve.h"
#include "Eigen/Core"
#include "Eigen/LU"
#include "paddle/fluid/operators/math/blas.h"
namespace paddle {
namespace operators {
namespace math {
template <typename T>
class MatrixSolveFunctor<platform::CPUDeviceContext, T> {
public:
void operator()(const platform::CPUDeviceContext& dev_ctx,
const framework::Tensor& a, const framework::Tensor& b,
framework::Tensor* out) {
compute_solve_eigen<platform::CPUDeviceContext, T>(dev_ctx, a, b, out);
}
};
template class MatrixSolveFunctor<platform::CPUDeviceContext, float>;
template class MatrixSolveFunctor<platform::CPUDeviceContext, double>;
} // namespace math
} // namespace operators
} // namespace paddle
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/math/matrix_solve.h"
#include "paddle/fluid/framework/tensor_util.h"
#include "paddle/fluid/operators/math/blas.h"
#include "paddle/fluid/operators/math/math_function.h"
#include "paddle/fluid/operators/solve_op.h"
#include "paddle/fluid/platform/device_context.h"
namespace paddle {
namespace platform {
class CUDADeviceContext;
} // namespace platform
} // namespace paddle
namespace paddle {
namespace operators {
namespace math {
template <typename DeviceContext, typename T>
class MatrixSolveFunctor;
template <typename T>
class MatrixSolveFunctor<platform::CUDADeviceContext, T> {
public:
void operator()(const platform::CUDADeviceContext& context,
const framework::Tensor& a, const framework::Tensor& b,
framework::Tensor* out) {
#ifndef PADDLE_WITH_HIP
// solve the equation Ax = B:
// first use the cuBlas cublas<S/D>getrfBatched function to perform the LU
// factorization of each matrix A, then use the cuBlas
// cublas<S/D>getrsBatched function to solve the equation with the LU
// factors.
// ref:
// https://docs.nvidia.com/cuda/cublas/index.html#cublas-lt-t-gt-getrfbatched
const auto& a_dims = a.dims();
const int a_rank = a_dims.size();
int n = a_dims[a_rank - 1];
int lda = n;
int batch_size = a_rank > 2 ? a.numel() / (n * n) : 1;
const auto& b_dims = b.dims();
const int b_rank = b_dims.size();
int nrhs = b_dims[b_rank - 1];
int ldb = b_dims[b_rank - 2];
// make sure the out dims is right
out->Resize(b_dims);
out->mutable_data<T>(context.GetPlace());
// copy input A to a temporary tensor tmp_a: the LU factorization is
// written back into the input matrix, so A itself must be preserved.
Tensor tmp_a(a.type());
tmp_a.Resize(a.dims());
tmp_a.mutable_data<T>(context.GetPlace());
TensorCopy(a, context.GetPlace(), &tmp_a);
// copy input B to a temporary tensor tmp_b and transpose it, because
// cuBlas assumes column-major storage while Paddle uses row-major.
Tensor tmp_b(b.type());
const auto& new_dims_vec = getNewDimsVec(b_dims);
tmp_b.Resize(framework::make_ddim(new_dims_vec));
tmp_b.mutable_data<T>(context.GetPlace());
math::TransposeNormal<platform::CUDADeviceContext, T> trans;
std::vector<int> new_axis = getNewAxis(b_rank);
trans(context, b, &tmp_b, new_axis);
const T* a_data_in_gpu = tmp_a.data<T>();
const T* b_data_in_gpu = tmp_b.data<T>();
std::vector<const T*> cpu_ptrs(batch_size * 2);
for (int i = 0; i < batch_size; ++i) {
cpu_ptrs[i] = a_data_in_gpu + i * n * n;
cpu_ptrs[i + batch_size] = b_data_in_gpu + i * n * nrhs;
}
// Copy the addresses of A and tmp_b from host to device.
memory::allocation::AllocationPtr tmp_gpu_ptrs_data =
memory::Alloc(context, cpu_ptrs.size() * sizeof(T*));
memory::Copy(BOOST_GET_CONST(platform::CUDAPlace, context.GetPlace()),
tmp_gpu_ptrs_data->ptr(), platform::CPUPlace(),
static_cast<void*>(cpu_ptrs.data()),
cpu_ptrs.size() * sizeof(T*), context.stream());
T** gpu_tmp_b_ptrs =
reinterpret_cast<T**>(tmp_gpu_ptrs_data->ptr()) + batch_size;
// Allocate device memory for BatchedGETRF's info and pivots.
int num_ints = n < 32 ? batch_size : batch_size * (n + 1);
memory::allocation::AllocationPtr tmp_gpu_info_data =
memory::Alloc(context, num_ints * sizeof(int));
int* gpu_info_ptr = reinterpret_cast<int*>(tmp_gpu_info_data->ptr());
auto blas = math::GetBlas<platform::CUDADeviceContext, T>(context);
// only for singular checking
std::vector<int> info;
info.resize(batch_size);
int* gpu_pivot_ptr =
reinterpret_cast<int*>(tmp_gpu_info_data->ptr()) + batch_size;
// This function performs the LU factorization of each matrix A by the
// equation A = L * U. L and U are written back to original matrix A,
// and diagonal elements of L are discarded.
blas.BatchedGETRF(n, reinterpret_cast<T**>(tmp_gpu_ptrs_data->ptr()),
gpu_pivot_ptr, gpu_info_ptr, batch_size);
// check whether BatchedGETRF is executed successfully or not
memory::Copy(platform::CPUPlace(), info.data(),
BOOST_GET_CONST(platform::CUDAPlace, context.GetPlace()),
gpu_info_ptr, sizeof(int) * batch_size, context.stream());
for (int i = 0; i < batch_size; ++i) {
PADDLE_ENFORCE_EQ(info[i], 0,
platform::errors::PreconditionNotMet(
"For batch [%d]: U(%d, %d) is zero, singular U. "
"Please check the matrix value and change it to a "
"non-singular matrix",
i, info[i], info[i]));
}
// hold the result code from BatchedGETRS
int host_info = 0;
// solve the equation with the LU factors; CblasTrans compensates for the
// row-major (Paddle) vs. column-major (cuBlas) storage of A
CBLAS_TRANSPOSE transA = CblasTrans;
blas.BatchedGETRS(
transA, n, nrhs, reinterpret_cast<const T**>(tmp_gpu_ptrs_data->ptr()),
lda, gpu_pivot_ptr, gpu_tmp_b_ptrs, ldb, &host_info, batch_size);
// check whether BatchedGETRS is executed successfully or not
PADDLE_ENFORCE_EQ(host_info, 0,
platform::errors::InvalidArgument(
"The [%d]'th argument to cublas*getrsBatched had "
"an illegal value.",
-host_info));
// transpose tmp_b to get the final result in row-major form.
math::TransposeNormal<platform::CUDADeviceContext, T> trans2;
trans2(context, tmp_b, out, new_axis);
#else
compute_solve_eigen<platform::CUDADeviceContext, T>(context, a, b, out);
#endif
}
};
template class MatrixSolveFunctor<platform::CUDADeviceContext, float>;
template class MatrixSolveFunctor<platform::CUDADeviceContext, double>;
} // namespace math
} // namespace operators
} // namespace paddle
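For intuition, the batched two-step LU solve implemented above (factorize with cublas<S/D>getrfBatched, then back-substitute with cublas<S/D>getrsBatched) can be sketched in Python with SciPy. This is an illustration of the algorithm only, not Paddle code:

```python
import numpy as np
from scipy.linalg import lu_factor, lu_solve

def batched_lu_solve(a, b):
    """Solve a[i] @ x[i] = b[i] for every matrix in the batch."""
    out = np.empty_like(b)
    for i in range(a.shape[0]):
        lu, piv = lu_factor(a[i])  # analogue of cublas<S/D>getrfBatched
        # a zero pivot on the LU diagonal means a[i] is singular,
        # mirroring the info-array check after BatchedGETRF above
        if np.min(np.abs(np.diag(lu))) == 0.0:
            raise ValueError(f"batch {i}: singular matrix")
        out[i] = lu_solve((lu, piv), b[i])  # analogue of cublas<S/D>getrsBatched
    return out

a = np.random.rand(4, 3, 3) + 3 * np.eye(3)  # diagonally dominant, invertible
b = np.random.rand(4, 3, 2)
x = batched_lu_solve(a, b)
print(np.allclose(a @ x, b))  # True
```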
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <string>
#include "Eigen/Core"
#include "Eigen/LU"
#include "paddle/fluid/framework/tensor.h"
#include "paddle/fluid/platform/device_context.h"
namespace paddle {
namespace operators {
namespace math {
template <typename DeviceContext, typename T>
void compute_solve_eigen(const DeviceContext& context,
const framework::Tensor& a, const framework::Tensor& b,
framework::Tensor* out) {
using Matrix =
Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>;
using EigenMatrixMap = Eigen::Map<Matrix>;
using ConstEigenMatrixMap = Eigen::Map<const Matrix>;
// prepare for a
const auto& a_mat_dims = a.dims();
const int a_rank = a_mat_dims.size();
int n = a_mat_dims[a_rank - 1];
int a_batch_size = a_rank > 2 ? a.numel() / (n * n) : 1;
// prepare for b
const auto& b_mat_dims = b.dims();
const int b_rank = b_mat_dims.size();
int b_h = n;
int b_w = b_mat_dims[b_rank - 1];
int b_batch_size = b_rank > 2 ? b.numel() / (b_h * b_w) : 1;
const T* a_ptr = a.data<T>();
const T* b_ptr = b.data<T>();
out->Resize(b_mat_dims); // make sure the out dims is right
T* out_ptr = out->mutable_data<T>(context.GetPlace());
if (a_batch_size == b_batch_size) {
for (int i = 0; i < a_batch_size; ++i) {
ConstEigenMatrixMap a_mat(a_ptr + i * n * n, n, n);
ConstEigenMatrixMap b_mat(b_ptr + i * b_h * b_w, b_h, b_w);
EigenMatrixMap out_mat(out_ptr + i * b_h * b_w, b_h, b_w);
Eigen::PartialPivLU<Matrix> lu;
lu.compute(a_mat);
const T min_abs_pivot = lu.matrixLU().diagonal().cwiseAbs().minCoeff();
PADDLE_ENFORCE_GT(
min_abs_pivot, static_cast<T>(0),
platform::errors::InvalidArgument("Input is not invertible."));
out_mat.noalias() = lu.solve(b_mat);
}
} else {
PADDLE_ENFORCE_EQ(a_batch_size, b_batch_size,
                  platform::errors::InvalidArgument(
                      "All input tensors must have the same batch size."));
}
}
template <typename DeviceContext, typename T>
class MatrixSolveFunctor {
public:
void operator()(const DeviceContext& context, const framework::Tensor& a,
const framework::Tensor& b, framework::Tensor* out);
};
} // namespace math
} // namespace operators
} // namespace paddle
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/solve_op.h"
#include <memory>
#include <string>
#include <unordered_map>
#include <vector>
#include "paddle/fluid/framework/ddim.h"
namespace paddle {
namespace operators {
using framework::OpKernelType;
using framework::Tensor;
class SolveOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override {
OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "Solve");
OP_INOUT_CHECK(ctx->HasInput("Y"), "Input", "Y", "Solve");
OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out", "Solve");
auto x_dims = ctx->GetInputDim("X");
auto y_dims = ctx->GetInputDim("Y");
std::vector<int64_t> x_dims_vec =
paddle::framework::vectorize(ctx->GetInputDim("X"));
std::vector<int64_t> y_dims_vec =
paddle::framework::vectorize(ctx->GetInputDim("Y"));
auto x_dims_n = x_dims_vec.size();
auto y_dims_n = y_dims_vec.size();
PADDLE_ENFORCE_GT(x_dims_n, 1,
                  platform::errors::InvalidArgument(
                      "The number of dimensions of input tensor X of "
                      "SolveOp should be larger than 1. But received X's "
                      "dimensions = %d, X's shape = [%s]",
                      x_dims_n, x_dims));
PADDLE_ENFORCE_GE(y_dims_n, 1,
                  platform::errors::InvalidArgument(
                      "The number of dimensions of input tensor Y of "
                      "SolveOp should be greater than or equal to 1. But "
                      "received Y's dimensions = %d, Y's shape = [%s]",
                      y_dims_n, y_dims));
PADDLE_ENFORCE_EQ(x_dims[x_dims_n - 2], x_dims[x_dims_n - 1],
                  platform::errors::InvalidArgument(
                      "The inner-most 2 dimensions of Input(X) should form "
                      "square matrices. But received X's shape[-2] = %d and "
                      "shape[-1] = %d.",
                      x_dims[x_dims_n - 2], x_dims[x_dims_n - 1]));
bool x_broadcasted = false, y_broadcasted = false;
bool trans_x = false, trans_y = false;
if (x_dims_n == 1) {
x_dims_vec.insert(x_dims_vec.begin(), 1);
x_dims_n = 2;
x_broadcasted = true;
}
if (y_dims_n == 1) {
y_dims_vec.push_back(1);
y_dims_n = 2;
y_broadcasted = true;
}
size_t M, N;
if (trans_x) {
M = x_dims_vec[x_dims_n - 1];
} else {
M = x_dims_vec[x_dims_n - 2];
}
if (trans_y) {
N = y_dims_vec[y_dims_n - 2];
} else {
N = y_dims_vec[y_dims_n - 1];
}
std::vector<int64_t> new_dims;
if (x_dims_n >= y_dims_n) {
new_dims.assign(x_dims_vec.begin(), x_dims_vec.end() - 2);
} else {
new_dims.assign(y_dims_vec.begin(), y_dims_vec.end() - 2);
}
if (!x_broadcasted) {
new_dims.push_back(M);
}
if (!y_broadcasted) {
new_dims.push_back(N);
}
if (x_broadcasted && y_broadcasted) {
new_dims.push_back(1);
}
auto out_dims = framework::make_ddim(new_dims);
ctx->SetOutputDim("Out", out_dims);
ctx->ShareLoD("X", /*->*/ "Out");
}
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
framework::LibraryType library = framework::LibraryType::kPlain;
framework::DataLayout layout = framework::DataLayout::kAnyLayout;
int customized_type_value =
framework::OpKernelType::kDefaultCustomizedTypeValue;
auto input_data_type = OperatorWithKernel::IndicateVarDataType(ctx, "X");
return framework::OpKernelType(input_data_type, ctx.GetPlace(), layout,
library, customized_type_value);
}
};
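The broadcasting in InferShape above can be summarized by a small Python sketch (a hypothetical helper mirroring the C++ logic, not part of this PR):

```python
def infer_solve_out_shape(x_shape, y_shape):
    """Mirror SolveOp::InferShape: X is [*, M, M]; Y is [*, M, K] or [M]."""
    x, y = list(x_shape), list(y_shape)
    x_bc = y_bc = False
    if len(x) == 1:      # promote 1-D X: [M] -> [1, M]
        x, x_bc = [1] + x, True
    if len(y) == 1:      # promote 1-D Y: [M] -> [M, 1]
        y, y_bc = y + [1], True
    assert x[-2] == x[-1], "X must be square in its last two dims"
    M, N = x[-2], y[-1]
    # batch dims come from whichever input has more dimensions
    out = x[:-2] if len(x) >= len(y) else y[:-2]
    if not x_bc:
        out.append(M)
    if not y_bc:
        out.append(N)
    if x_bc and y_bc:
        out.append(1)
    return out

print(infer_solve_out_shape([2, 3, 3], [2, 3, 4]))  # [2, 3, 4]
print(infer_solve_out_shape([3, 3], [3]))           # [3]
```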
class SolveOpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() override {
AddInput("X", "(Tensor), The first input tensor of solve op.");
AddInput("Y", "(Tensor), The second input tensor of solve op.");
AddOutput("Out", "(Tensor), The output tensor of solve op.");
AddComment(R"DOC(
Solve Operator.
This operator is used to computes the solution of a square system of
linear equations with a unique solution for input $X$ and $Y$.
The equation is:
$$Out = X^-1 * Y$$
)DOC");
}
};
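Since Out = X^{-1} * Y, the operator's result can be cross-checked against an explicit inverse. A hedged sketch (paddle.inverse and paddle.matmul are pre-existing Paddle APIs; solve is preferred numerically over forming the inverse):

```python
import paddle

x = paddle.to_tensor([[3., 1.], [1., 2.]], dtype="float64")
y = paddle.to_tensor([[9.], [8.]], dtype="float64")

out_solve = paddle.linalg.solve(x, y)
out_inverse = paddle.matmul(paddle.inverse(x), y)  # same math, less stable
print(paddle.allclose(out_solve, out_inverse))  # True
```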
class SolveOpInferVarType : public framework::PassInDtypeAndVarTypeToOutput {
protected:
std::unordered_map<std::string, std::string>& GetInputOutputWithSameType()
const override {
static std::unordered_map<std::string, std::string> m{{"X", /*->*/ "Out"}};
return m;
}
};
class SolveGradOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override {
OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "solve");
OP_INOUT_CHECK(ctx->HasInput("Y"), "Input", "Y", "solve");
OP_INOUT_CHECK(ctx->HasInput(framework::GradVarName("Out")), "Input",
"Out@GRAD", "solve");
// reuse the linalg.solve forward output
OP_INOUT_CHECK(ctx->HasInput("Out"), "Input", "Out", "solve");
auto x_dims = ctx->GetInputDim("X");
auto y_dims = ctx->GetInputDim("Y");
auto x_grad_name = framework::GradVarName("X");
auto y_grad_name = framework::GradVarName("Y");
if (ctx->HasOutput(x_grad_name)) {
ctx->SetOutputDim(x_grad_name, x_dims);
}
if (ctx->HasOutput(y_grad_name)) {
ctx->SetOutputDim(y_grad_name, y_dims);
}
}
};
template <typename T>
class SolveOpGradMaker : public framework::SingleGradOpMaker<T> {
public:
using framework::SingleGradOpMaker<T>::SingleGradOpMaker;
protected:
void Apply(GradOpPtr<T> retv) const override {
retv->SetType("solve_grad");
retv->SetInput("X", this->Input("X"));
retv->SetInput("Y", this->Input("Y"));
retv->SetInput(framework::GradVarName("Out"), this->OutputGrad("Out"));
// reuse the linalg.solve forward output
retv->SetInput("Out", this->Output("Out"));
retv->SetOutput(framework::GradVarName("X"), this->InputGrad("X"));
retv->SetOutput(framework::GradVarName("Y"), this->InputGrad("Y"));
retv->SetAttrMap(this->Attrs());
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OPERATOR(solve, ops::SolveOp, ops::SolveOpMaker,
ops::SolveOpInferVarType,
ops::SolveOpGradMaker<paddle::framework::OpDesc>,
ops::SolveOpGradMaker<paddle::imperative::OpBase>);
REGISTER_OPERATOR(solve_grad, ops::SolveGradOp);
REGISTER_OP_CPU_KERNEL(
solve, ops::SolveKernel<paddle::platform::CPUDeviceContext, float>,
ops::SolveKernel<paddle::platform::CPUDeviceContext, double>);
REGISTER_OP_CPU_KERNEL(
solve_grad, ops::SolveGradKernel<paddle::platform::CPUDeviceContext, float>,
ops::SolveGradKernel<paddle::platform::CPUDeviceContext, double>);
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/solve_op.h"
namespace ops = paddle::operators;
namespace plat = paddle::platform;
REGISTER_OP_CUDA_KERNEL(solve, ops::SolveKernel<plat::CUDADeviceContext, float>,
ops::SolveKernel<plat::CUDADeviceContext, double>);
REGISTER_OP_CUDA_KERNEL(solve_grad,
ops::SolveGradKernel<plat::CUDADeviceContext, float>,
ops::SolveGradKernel<plat::CUDADeviceContext, double>);
This diff is collapsed.
@@ -89,7 +89,9 @@ extern void *cublas_dso_handle;
  __macro(cublasDgetrfBatched);  \
  __macro(cublasDgetriBatched);  \
  __macro(cublasSmatinvBatched); \
  __macro(cublasDmatinvBatched); \
  __macro(cublasSgetrsBatched);  \
  __macro(cublasDgetrsBatched);

CUBLAS_BLAS_ROUTINE_EACH(DECLARE_DYNAMIC_LOAD_CUBLAS_WRAP)
...
@@ -106,6 +106,7 @@ from .tensor.linalg import matrix_power  # noqa: F401
from .tensor.linalg import svd  # noqa: F401
from .tensor.linalg import eigh  # noqa: F401
from .tensor.linalg import pinv  # noqa: F401
from .tensor.linalg import solve  # noqa: F401
from .tensor.logic import equal  # noqa: F401
from .tensor.logic import greater_equal  # noqa: F401
from .tensor.logic import greater_than  # noqa: F401
...
@@ -964,6 +964,7 @@ set_tests_properties(test_dataloader_unkeep_order PROPERTIES TIMEOUT 120)
set_tests_properties(test_reader_reset PROPERTIES TIMEOUT 120)
set_tests_properties(test_pool3d_api PROPERTIES TIMEOUT 120)
set_tests_properties(test_cumprod_op PROPERTIES TIMEOUT 120)
set_tests_properties(test_solve_op PROPERTIES TIMEOUT 120)
if(WITH_DISTRIBUTE AND WITH_GPU AND WITH_NCCL)
  set_tests_properties(test_parallel_dygraph_dataparallel PROPERTIES TIMEOUT 120)
  set_tests_properties(test_parallel_dygraph_unused_variables PROPERTIES TIMEOUT 120)
...
This diff is collapsed.
@@ -48,6 +48,7 @@ NEED_FIX_FP64_CHECK_GRAD_THRESHOLD_OP_LIST = [
    'lgamma', \
    'svd', \
    'matrix_power', \
    'solve', \
]

NEED_FIX_FP64_CHECK_OUTPUT_THRESHOLD_OP_LIST = ['bilinear_interp',\
...
@@ -16,6 +16,7 @@ from .tensor.linalg import cholesky  # noqa: F401
from .tensor.linalg import norm  # noqa: F401
from .tensor.linalg import cond  # noqa: F401
from .tensor.linalg import matrix_power  # noqa: F401
from .tensor.linalg import solve  # noqa: F401
from .tensor import inverse as inv  # noqa: F401
from .tensor.linalg import eigvals  # noqa: F401
from .tensor.linalg import multi_dot  # noqa: F401
@@ -35,5 +36,6 @@ __all__ = [
    'svd',
    'matrix_power',
    'eigh',
    'pinv',
    'solve'
]
@@ -51,6 +51,7 @@ from .linalg import multi_dot  # noqa: F401
from .linalg import svd  # noqa: F401
from .linalg import eigh  # noqa: F401
from .linalg import pinv  # noqa: F401
from .linalg import solve  # noqa: F401
from .logic import equal  # noqa: F401
from .logic import greater_equal  # noqa: F401
from .logic import greater_than  # noqa: F401
@@ -386,6 +387,7 @@ tensor_method_func = [  # noqa
    'bitwise_not',
    'broadcast_tensors',
    'uniform_',
    'solve',
]

#this list used in math_op_patch.py for magic_method bind
...
@@ -1969,3 +1969,60 @@ def pinv(x, rcond=1e-15, hermitian=False, name=None):
        attrs={'trans_x': False,
               'trans_y': True}, )
    return out_2

def solve(x, y, name=None):
    r"""
    Computes the solution of a square system of linear equations with a
    unique solution for inputs 'X' and 'Y'.
    Let :math:`X` be a square matrix or a batch of square matrices, and
    :math:`Y` be a vector/matrix or a batch of vectors/matrices. The
    equation is:

    .. math::
        Out = X^{-1} * Y

    Specifically, this system of linear equations has one solution if and
    only if input 'X' is invertible.

    Args:
        x (Tensor): A square matrix or a batch of square matrices. Its shape should be `[*, M, M]`, where `*` is zero or
            more batch dimensions. Its data type should be float32 or float64.
        y (Tensor): A vector/matrix or a batch of vectors/matrices. Its shape should be `[*, M, K]`, where `*` is zero or
            more batch dimensions. Its data type should be float32 or float64.
        name(str, optional): Name for the operation (optional, default is None).
            For more information, please refer to :ref:`api_guide_Name`.

    Returns:
        Tensor: The solution of a square system of linear equations with a unique solution for inputs 'x' and 'y'.
            Its data type is the same as that of `x`.

    Examples:
        .. code-block:: python

            # a square system of linear equations:
            # 3*X0 + X1 = 9
            #   X0 + 2*X1 = 8
            import paddle
            import numpy as np

            np_x = np.array([[3, 1], [1, 2]])
            np_y = np.array([9, 8])
            x = paddle.to_tensor(np_x, dtype="float64")
            y = paddle.to_tensor(np_y, dtype="float64")
            out = paddle.linalg.solve(x, y)

            print(out)
            # [2., 3.]
    """
    if in_dygraph_mode():
        return _C_ops.solve(x, y)

    inputs = {"X": [x], "Y": [y]}
    helper = LayerHelper("solve", **locals())
    check_variable_and_dtype(x, 'x', ['float32', 'float64'], 'solve')
    check_variable_and_dtype(y, 'y', ['float32', 'float64'], 'solve')
    out = helper.create_variable_for_type_inference(dtype=x.dtype)

    helper.append_op(
        type="solve", inputs={"X": x,
                              "Y": y}, outputs={"Out": out})
    return out
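Because both inputs accept leading batch dimensions, batched calls follow the same pattern; a sketch assuming equal batch sizes, as the kernels above require:

```python
import numpy as np
import paddle

a = paddle.to_tensor(np.random.rand(4, 3, 3) + 3 * np.eye(3))  # invertible batch
b = paddle.to_tensor(np.random.rand(4, 3, 2))
x = paddle.linalg.solve(a, b)  # shape: [4, 3, 2]
print(paddle.allclose(paddle.matmul(a, x), b))  # True
```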