Unverified commit 3a81805b authored by zhouweiwei2014, committed by GitHub

add new API/OP: paddle.linalg.triangular_solve (#36714) (#37551)

cherry-pick #36714
Parent 4b41b8e9
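For orientation, a minimal usage sketch of the new API (mirroring the docstring example added in this diff; the numpy cross-check is illustrative only):

import numpy as np
import paddle

# upper-triangular system:  x1 + x2 + x3 = 0;  2*x2 + x3 = -9;  -x3 = 5
x = paddle.to_tensor([[1., 1., 1.],
                      [0., 2., 1.],
                      [0., 0., -1.]], dtype="float64")
y = paddle.to_tensor([[0.], [-9.], [5.]], dtype="float64")
out = paddle.linalg.triangular_solve(x, y, upper=True)  # [[7.], [-2.], [-5.]]

# cross-check against numpy's general solver applied to the triangular part
np.testing.assert_allclose(out.numpy(),
                           np.linalg.solve(np.triu(x.numpy()), y.numpy()))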
......@@ -253,6 +253,12 @@ class Blas {
void BatchedGETRS(CBLAS_TRANSPOSE trans, int n, int nrhs, const T** a,
int lda, int* ipiv, T** b, int ldb, int* info,
int batch_size) const;
// cuBlas triangular_solve
template <typename T>
void BatchedTRSM(CBLAS_SIDE side, CBLAS_UPLO uplo, CBLAS_TRANSPOSE transA,
CBLAS_DIAG diag, int M, int N, T alpha, const T** a, int lda,
T** b, int ldb, int batch_size) const;
#endif
private:
......@@ -414,6 +420,12 @@ class BlasT : private Blas<DeviceContext> {
void BatchedGETRS(ARGS... args) const {
Base()->template BatchedGETRS<T>(args...);
}
// triangular_solve
template <typename... ARGS>
void BatchedTRSM(ARGS... args) const {
Base()->template BatchedTRSM<T>(args...);
}
#endif
private:
......
......@@ -120,6 +120,11 @@ struct CUBlas<float> {
PADDLE_ENFORCE_CUDA_SUCCESS(
platform::dynload::cublasSgetrsBatched(args...));
}
template <typename... ARGS>
static void TRSM_BATCH(ARGS... args) {
PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::cublasStrsmBatched(args...));
}
};
template <>
......@@ -194,6 +199,11 @@ struct CUBlas<double> {
PADDLE_ENFORCE_CUDA_SUCCESS(
platform::dynload::cublasDgetrsBatched(args...));
}
template <typename... ARGS>
static void TRSM_BATCH(ARGS... args) {
PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::cublasDtrsmBatched(args...));
}
};
template <>
......@@ -339,6 +349,19 @@ struct CUBlas<platform::complex<float>> {
reinterpret_cast<cuFloatComplex *>(C), ldc));
}
static void TRSM(cublasHandle_t handle, cublasSideMode_t side,
cublasFillMode_t uplo, cublasOperation_t transa,
cublasDiagType_t diag, int m, int n,
const paddle::platform::complex<float> *alpha,
const paddle::platform::complex<float> *A, int lda,
paddle::platform::complex<float> *B, int ldb) {
PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::cublasCtrsm(
handle, side, uplo, transa, diag, m, n,
reinterpret_cast<const cuFloatComplex *>(alpha),
reinterpret_cast<const cuFloatComplex *>(A), lda,
reinterpret_cast<cuFloatComplex *>(B), ldb));
}
// NOTES: GEMM_EX can use Tensor Core to accelerate matrix multiply.
// https://docs.nvidia.com/cuda/cublas/index.html#cublassetmathmode
template <typename... ARGS>
......@@ -370,6 +393,20 @@ struct CUBlas<platform::complex<float>> {
"cublasGemmEx is not supported on cuda <= 7.5"));
#endif
}
static void TRSM_BATCH(cublasHandle_t handle, cublasSideMode_t side,
cublasFillMode_t uplo, cublasOperation_t transa,
cublasDiagType_t diag, int m, int n,
const paddle::platform::complex<float> *alpha,
const paddle::platform::complex<float> **A, int lda,
paddle::platform::complex<float> **B, int ldb,
int batch_size) {
PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::cublasCtrsmBatched(
handle, side, uplo, transa, diag, m, n,
reinterpret_cast<const cuFloatComplex *>(alpha),
reinterpret_cast<const cuFloatComplex **>(A), lda,
reinterpret_cast<cuFloatComplex **>(B), ldb, batch_size));
}
};
template <>
......@@ -440,6 +477,33 @@ struct CUBlas<platform::complex<double>> {
reinterpret_cast<cuDoubleComplex *>(C), ldc));
}
static void TRSM(cublasHandle_t handle, cublasSideMode_t side,
cublasFillMode_t uplo, cublasOperation_t transa,
cublasDiagType_t diag, int m, int n,
const paddle::platform::complex<double> *alpha,
const paddle::platform::complex<double> *A, int lda,
paddle::platform::complex<double> *B, int ldb) {
PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::cublasZtrsm(
handle, side, uplo, transa, diag, m, n,
reinterpret_cast<const cuDoubleComplex *>(alpha),
reinterpret_cast<const cuDoubleComplex *>(A), lda,
reinterpret_cast<cuDoubleComplex *>(B), ldb));
}
static void TRSM_BATCH(cublasHandle_t handle, cublasSideMode_t side,
cublasFillMode_t uplo, cublasOperation_t transa,
cublasDiagType_t diag, int m, int n,
const paddle::platform::complex<double> *alpha,
const paddle::platform::complex<double> **A, int lda,
paddle::platform::complex<double> **B, int ldb,
int batch_size) {
PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::cublasZtrsmBatched(
handle, side, uplo, transa, diag, m, n,
reinterpret_cast<const cuDoubleComplex *>(alpha),
reinterpret_cast<const cuDoubleComplex **>(A), lda,
reinterpret_cast<cuDoubleComplex **>(B), ldb, batch_size));
}
// NOTES: GEMM_EX can use Tensor Core to accelerate matrix multiply.
// https://docs.nvidia.com/cuda/cublas/index.html#cublassetmathmode
template <typename... ARGS>
......@@ -897,6 +961,30 @@ void Blas<platform::CUDADeviceContext>::BatchedGETRS(
});
}
template <>
template <typename T>
void Blas<platform::CUDADeviceContext>::BatchedTRSM(
CBLAS_SIDE side, CBLAS_UPLO uplo, CBLAS_TRANSPOSE transA, CBLAS_DIAG diag,
int M, int N, T alpha, const T **A, int lda, T **B, int ldb,
int batch_size) const {
// solve row major `op ( A ) X = α B` by taking it as `X' op ( A' ) = α B'`
// where ' stands for transpose
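  // i.e. a row-major matrix is, bit for bit, the transpose of the same
  // buffer read column-major, so `side` is flipped (left <-> right),
  // `uplo` is flipped (lower <-> upper), and M/N are swapped in the call
  // below, while op(A) itself stays unchanged.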
cublasSideMode_t cuSide =
(side == CblasLeft) ? CUBLAS_SIDE_RIGHT : CUBLAS_SIDE_LEFT;
cublasFillMode_t cuUplo =
(uplo == CblasLower) ? CUBLAS_FILL_MODE_UPPER : CUBLAS_FILL_MODE_LOWER;
  // note: complex types would need CUBLAS_OP_C (conjugate transpose); only
  // CUBLAS_OP_N / CUBLAS_OP_T are mapped here.
cublasOperation_t cuTransA =
(transA == CblasNoTrans) ? CUBLAS_OP_N : CUBLAS_OP_T;
cublasDiagType_t cuDiag =
(diag == CblasUnit) ? CUBLAS_DIAG_UNIT : CUBLAS_DIAG_NON_UNIT;
context_.CublasCall([&](cublasHandle_t handle) {
CUBlas<T>::TRSM_BATCH(handle, cuSide, cuUplo, cuTransA, cuDiag, N, M,
&alpha, A, lda, B, ldb, batch_size);
});
}
} // namespace math
} // namespace operators
} // namespace paddle
......@@ -434,6 +434,17 @@ struct CBlas<platform::complex<float>> {
a_, lda, b_, ldb, &beta, c_, ldc);
}
static void TRSM(CBLAS_LAYOUT layout, CBLAS_SIDE side, CBLAS_UPLO uplo,
CBLAS_TRANSPOSE trans_a, CBLAS_DIAG diag, int M, int N,
paddle::platform::complex<float> alpha,
const paddle::platform::complex<float> *A, int lda,
paddle::platform::complex<float> *B, int ldb) {
    const void *a_ = static_cast<const void *>(A);
    void *b_ = static_cast<void *>(B);
platform::dynload::cblas_ctrsm(layout, side, uplo, trans_a, diag, M, N,
&alpha, a_, lda, b_, ldb);
}
template <typename... ARGS>
static void GEMM_BATCH(CBLAS_LAYOUT layout, CBLAS_TRANSPOSE *trans_a,
CBLAS_TRANSPOSE *trans_b, int *M, int *N, int *K,
......@@ -562,6 +573,17 @@ struct CBlas<platform::complex<double>> {
a_, lda, b_, ldb, &beta, c_, ldc);
}
static void TRSM(CBLAS_LAYOUT layout, CBLAS_SIDE side, CBLAS_UPLO uplo,
CBLAS_TRANSPOSE trans_a, CBLAS_DIAG diag, int M, int N,
paddle::platform::complex<double> alpha,
const paddle::platform::complex<double> *A, int lda,
paddle::platform::complex<double> *B, int ldb) {
    const void *a_ = static_cast<const void *>(A);
    void *b_ = static_cast<void *>(B);
platform::dynload::cblas_ztrsm(layout, side, uplo, trans_a, diag, M, N,
&alpha, a_, lda, b_, ldb);
}
template <typename... ARGS>
static void GEMM_BATCH(CBLAS_LAYOUT layout, CBLAS_TRANSPOSE *trans_a,
CBLAS_TRANSPOSE *trans_b, int *M, int *N, int *K,
......@@ -682,6 +704,15 @@ struct CBlas<platform::complex<float>> {
cblas_cgemm(layout, TransA, TransB, M, N, K, &alpha, A, lda, B, ldb, &beta,
C, ldc);
}
static void TRSM(const CBLAS_LAYOUT layout, const CBLAS_SIDE side,
const CBLAS_UPLO uplo, const CBLAS_TRANSPOSE transA,
const CBLAS_DIAG diag, const int M, const int N,
const paddle::platform::complex<float> alpha,
const paddle::platform::complex<float> *A, const int lda,
                   paddle::platform::complex<float> *B, const int ldb) {
cblas_ctrsm(layout, side, uplo, transA, diag, M, N, &alpha, A, lda, B, ldb);
}
};
template <>
......@@ -720,6 +751,15 @@ struct CBlas<platform::complex<double>> {
cblas_zgemm(layout, TransA, TransB, M, N, K, &alpha, A, lda, B, ldb, &beta,
C, ldc);
}
static void TRSM(const CBLAS_LAYOUT layout, const CBLAS_SIDE side,
const CBLAS_UPLO uplo, const CBLAS_TRANSPOSE transA,
const CBLAS_DIAG diag, const int M, const int N,
const paddle::platform::complex<double> alpha,
const paddle::platform::complex<double> *A, const int lda,
paddle::platform::complex<double> *B, const int ldb) {
cblas_ztrsm(layout, side, uplo, transA, diag, M, N, &alpha, A, lda, B, ldb);
}
};
#endif
......
......@@ -90,6 +90,12 @@ struct CUBlas<float> {
PADDLE_THROW(platform::errors::Unimplemented(
"cublasSmatinvBatched is not supported on HIP platform."));
}
template <typename... ARGS>
static void TRSM_BATCH(ARGS... args) {
PADDLE_THROW(platform::errors::Unimplemented(
"cublasStrsmBatched is not supported on HIP platform."));
}
};
template <>
......@@ -153,6 +159,12 @@ struct CUBlas<double> {
PADDLE_THROW(platform::errors::Unimplemented(
"cublasDmatinvBatched is not supported on HIP platform."));
}
template <typename... ARGS>
static void TRSM_BATCH(ARGS... args) {
PADDLE_THROW(platform::errors::Unimplemented(
"cublasDtrsmBatched is not supported on HIP platform."));
}
};
template <>
......@@ -730,6 +742,32 @@ void Blas<platform::CUDADeviceContext>::BatchedGETRS(
batch_size);
});
}
template <>
template <typename T>
void Blas<platform::CUDADeviceContext>::BatchedTRSM(
CBLAS_SIDE side, CBLAS_UPLO uplo, CBLAS_TRANSPOSE transA, CBLAS_DIAG diag,
int M, int N, T alpha, const T **A, int lda, T **B, int ldb,
int batch_size) const {
// solve row major `op ( A ) X = α B` by taking it as `X' op ( A' ) = α B'`
// where ' stands for transpose
rocblas_side cuSide =
(side == CblasLeft) ? rocblas_side_right : rocblas_side_left;
rocblas_fill cuUplo =
(uplo == CblasLower) ? rocblas_fill_upper : rocblas_fill_lower;
  // note: complex types would need rocblas_operation_conjugate_transpose;
  // only none / transpose are mapped here.
rocblas_operation cuTransA = (transA == CblasNoTrans)
? rocblas_operation_none
: rocblas_operation_transpose;
rocblas_diagonal cuDiag =
(diag == CblasUnit) ? rocblas_diagonal_unit : rocblas_diagonal_non_unit;
context_.CublasCall([&](rocblas_handle handle) {
CUBlas<T>::TRSM_BATCH(handle, cuSide, cuUplo, cuTransA, cuDiag, N, M,
&alpha, A, lda, B, ldb, batch_size);
});
}
} // namespace math
} // namespace operators
} // namespace paddle
......@@ -34,6 +34,45 @@ class MatrixSolveFunctor<platform::CPUDeviceContext, T> {
template class MatrixSolveFunctor<platform::CPUDeviceContext, float>;
template class MatrixSolveFunctor<platform::CPUDeviceContext, double>;
template <typename T>
class TriangularSolveFunctor<platform::CPUDeviceContext, T> {
public:
void operator()(const platform::CPUDeviceContext& context,
const framework::Tensor* a, framework::Tensor* b, bool left,
bool upper, bool transpose, bool unitriangular) {
CBLAS_SIDE side = left ? CblasLeft : CblasRight;
CBLAS_UPLO uplo = upper ? CblasUpper : CblasLower;
CBLAS_TRANSPOSE transA = transpose ? CblasTrans : CblasNoTrans;
CBLAS_DIAG diag = unitriangular ? CblasUnit : CblasNonUnit;
const T* a_data = a->data<T>();
T* b_data = b->mutable_data<T>(context.GetPlace());
int a_dim_size = a->dims().size();
int b_dim_size = b->dims().size();
int M = static_cast<int>(b->dims()[b_dim_size - 2]);
int N = static_cast<int>(b->dims()[b_dim_size - 1]);
auto lda = left ? std::max(1, M) : std::max(1, N);
auto ldb = std::max(1, N);
int batch_size = 1;
auto& a_dim = a->dims();
for (int i = 0; i < a_dim_size - 2; i++) {
batch_size *= a_dim[i];
}
auto blas = math::GetBlas<platform::CPUDeviceContext, T>(context);
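    // NOTE: the `i * M * M` / `i * N * M` strides assume a left-side solve
    // (A is M x M and B is M x N); callers in this PR always pass left=true.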
for (int i = 0; i < batch_size; i++) {
blas.TRSM(side, uplo, transA, diag, M, N, T(1), a_data + i * M * M, lda,
b_data + i * N * M, ldb);
}
}
};
template class TriangularSolveFunctor<platform::CPUDeviceContext, float>;
template class TriangularSolveFunctor<platform::CPUDeviceContext, double>;
} // namespace math
} // namespace operators
} // namespace paddle
......@@ -163,6 +163,68 @@ class MatrixSolveFunctor<platform::CUDADeviceContext, T> {
template class MatrixSolveFunctor<platform::CUDADeviceContext, float>;
template class MatrixSolveFunctor<platform::CUDADeviceContext, double>;
template <typename T>
class TriangularSolveFunctor<platform::CUDADeviceContext, T> {
public:
void operator()(const platform::CUDADeviceContext& context, const Tensor* a,
Tensor* b, bool left, bool upper, bool transpose,
bool unitriangular) {
CBLAS_SIDE side = left ? CblasLeft : CblasRight;
CBLAS_UPLO uplo = upper ? CblasUpper : CblasLower;
CBLAS_TRANSPOSE transA = transpose ? CblasTrans : CblasNoTrans;
CBLAS_DIAG diag = unitriangular ? CblasUnit : CblasNonUnit;
const T* a_data = a->data<T>();
T* b_data = b->mutable_data<T>(context.GetPlace());
int a_dim_size = a->dims().size();
int b_dim_size = b->dims().size();
int M = static_cast<int>(b->dims()[b_dim_size - 2]);
int N = static_cast<int>(b->dims()[b_dim_size - 1]);
auto lda = left ? std::max(1, M) : std::max(1, N);
auto ldb = std::max(1, N);
int batch_size = 1;
auto& a_dim = a->dims();
for (int i = 0; i < a_dim_size - 2; i++) {
batch_size *= a_dim[i];
}
auto blas = math::GetBlas<platform::CUDADeviceContext, T>(context);
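    // Heuristic: for a small batch of large matrices, per-matrix TRSM is
    // used; otherwise device pointer arrays are gathered and the batched
    // kernel is called. As in the CPU path, the pointer strides assume a
    // left-side solve (A is M x M and B is M x N).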
if (batch_size <= 8 && M >= 64) {
for (auto i = 0; i < batch_size; i++) {
blas.TRSM(side, uplo, transA, diag, M, N, static_cast<T>(1.0),
a_data + i * M * M, lda, b_data + i * N * M, ldb);
}
} else {
std::vector<const T*> cpu_ptrs(batch_size * 2);
for (int i = 0; i < batch_size; ++i) {
cpu_ptrs[i] = a_data + i * M * M;
cpu_ptrs[i + batch_size] = b_data + i * M * N;
}
// Copy the addresses of A and tmp_b from host to device.
memory::allocation::AllocationPtr tmp_gpu_ptrs_data =
memory::Alloc(context, cpu_ptrs.size() * sizeof(T*));
memory::Copy(BOOST_GET_CONST(platform::CUDAPlace, context.GetPlace()),
tmp_gpu_ptrs_data->ptr(), platform::CPUPlace(),
static_cast<void*>(cpu_ptrs.data()),
cpu_ptrs.size() * sizeof(T*), context.stream());
const T** gpu_a_ptrs =
reinterpret_cast<const T**>(tmp_gpu_ptrs_data->ptr());
T** gpu_b_ptrs =
reinterpret_cast<T**>(tmp_gpu_ptrs_data->ptr()) + batch_size;
blas.BatchedTRSM(side, uplo, transA, diag, M, N, static_cast<T>(1.0),
gpu_a_ptrs, lda, gpu_b_ptrs, ldb, batch_size);
}
}
};
template class TriangularSolveFunctor<platform::CUDADeviceContext, float>;
template class TriangularSolveFunctor<platform::CUDADeviceContext, double>;
} // namespace math
} // namespace operators
} // namespace paddle
......@@ -117,6 +117,14 @@ class MatrixSolveFunctor {
const framework::Tensor& b, framework::Tensor* out);
};
template <typename DeviceContext, typename T>
class TriangularSolveFunctor {
public:
void operator()(const DeviceContext& context, const framework::Tensor* a,
framework::Tensor* b, bool left, bool upper, bool transpose,
bool unitriangular);
};
} // namespace math
} // namespace operators
} // namespace paddle
......@@ -49,9 +49,9 @@ struct IdentityFunctor {
};
template <typename DeviceContext, typename T>
void ReduceSumForSolveGrad(const Tensor* input, Tensor* output,
const std::vector<int>& reduce_dims, bool keep_dim,
const paddle::framework::ExecutionContext& ctx) {
void ReduceSumForSolve(const Tensor* input, Tensor* output,
const std::vector<int>& reduce_dims, bool keep_dim,
const paddle::framework::ExecutionContext& ctx) {
#if defined(__NVCC__) || defined(__HIPCC__)
auto stream = ctx.cuda_device_context().stream();
TensorReduce<T, T, cub::Sum, IdentityFunctor>(*input, output, reduce_dims,
......@@ -185,36 +185,6 @@ static std::vector<int64_t> get_broadcast_batch_portion(
return batchPortion;
}
// necessary check before expand operation
static void expand_check(const Tensor& arg1,
std::vector<int64_t> expand_shape) {
auto rank = arg1.dims().size();
PADDLE_ENFORCE_GE(
rank, 1, platform::errors::InvalidArgument(
"The rank of the input 'X' for expand must be positive, "
"but the value received is %d.",
rank));
PADDLE_ENFORCE_LE(
rank, MAX_RANK_SUPPORTED,
platform::errors::InvalidArgument(
"The rank of the input 'X' for expand must be less than "
"or equal to %d, but the value received is %d.",
MAX_RANK_SUPPORTED, rank));
auto shape_size = static_cast<int>(expand_shape.size());
PADDLE_ENFORCE_GE(
shape_size, rank,
platform::errors::InvalidArgument(
"The number (%d) of elements of 'shape' for expand must be "
"greater than or equal to the rank (%d) of the input 'X'.",
shape_size, rank));
PADDLE_ENFORCE_LE(
shape_size, MAX_RANK_SUPPORTED,
platform::errors::InvalidArgument(
"The number (%d) of elements of 'shape' for expand must be "
"less than or equal to %d.",
shape_size, MAX_RANK_SUPPORTED));
}
// broadcast the batch dimensions of tensor x and tensor y.
static inline std::tuple<std::vector<int64_t>, std::vector<int64_t>>
get_broadcast_dims(const Tensor& x, const Tensor& y) {
......@@ -246,15 +216,13 @@ get_broadcast_dims(const Tensor& x, const Tensor& y) {
}
template <int Rank, typename T, typename DeviceContext>
void tensor_expand(const framework::ExecutionContext& context,
const Tensor& arg1, Tensor* out0,
std::vector<int64_t> expand_size) {
auto in_dims = arg1.dims();
auto expand_shape = expand_size;
auto vec_in_dims = framework::vectorize<int>(in_dims);
void expand_impl(const DeviceContext& context, const Tensor& in, Tensor* out,
const std::vector<int64_t>& expand_shape) {
auto vec_in_dims = framework::vectorize<int>(in.dims());
auto diff = expand_shape.size() - vec_in_dims.size();
vec_in_dims.insert(vec_in_dims.begin(), diff, 1);
std::vector<int> repeat_times(vec_in_dims.size());
for (size_t i = 0; i < vec_in_dims.size(); ++i) {
PADDLE_ENFORCE_NE(
expand_shape[i], 0,
......@@ -301,12 +269,11 @@ void tensor_expand(const framework::ExecutionContext& context,
out_dims[i] *= repeat_times[i];
}
out0->Resize(out_dims);
auto x = EigenTensor<T, Rank>::From(arg1, new_in_dims);
out0->mutable_data<T>(context.GetPlace());
auto y = EigenTensor<T, Rank>::From(*out0, out_dims);
auto& place =
*context.template device_context<DeviceContext>().eigen_device();
out->Resize(out_dims);
out->mutable_data<T>(context.GetPlace());
auto x = EigenTensor<T, Rank>::From(in, new_in_dims);
auto y = EigenTensor<T, Rank>::From(*out, out_dims);
auto& place = *context.eigen_device();
// use 32-bit index to speed up
bool use_32bit_index = y.size() < Eigen::NumTraits<int>::highest();
if (use_32bit_index) {
......@@ -318,6 +285,41 @@ void tensor_expand(const framework::ExecutionContext& context,
}
}
template <typename T, typename DeviceContext>
void TensorExpand(const DeviceContext& context, const Tensor& in, Tensor* out,
const std::vector<int64_t>& expand_shape) {
// necessary check before expand operation
PADDLE_ENFORCE_GE(expand_shape.size(), in.dims().size(),
platform::errors::InvalidArgument(
"The size of 'expand_shape' (%d) should >= the input "
"Tensor's rank (%d).",
expand_shape.size(), in.dims().size()));
PADDLE_ENFORCE_LE(expand_shape.size(), MAX_RANK_SUPPORTED,
platform::errors::InvalidArgument(
"The size of 'expand_shape' (%d) should be <= %d",
expand_shape.size(), MAX_RANK_SUPPORTED));
switch (expand_shape.size()) {
case 1:
expand_impl<1, T, DeviceContext>(context, in, out, expand_shape);
break;
case 2:
expand_impl<2, T, DeviceContext>(context, in, out, expand_shape);
break;
case 3:
expand_impl<3, T, DeviceContext>(context, in, out, expand_shape);
break;
case 4:
expand_impl<4, T, DeviceContext>(context, in, out, expand_shape);
break;
case 5:
expand_impl<5, T, DeviceContext>(context, in, out, expand_shape);
break;
case 6:
expand_impl<6, T, DeviceContext>(context, in, out, expand_shape);
break;
}
}
template <typename DeviceContext, typename T>
static void linalg_solve(const framework::ExecutionContext& context,
const framework::Tensor* x, const framework::Tensor* y,
......@@ -356,69 +358,11 @@ static void linalg_solve(const framework::ExecutionContext& context,
std::tie(x_broadcast_dims, y_broadcast_dims) =
get_broadcast_dims(tmp_x, tmp_y);
expand_check(tmp_x, x_broadcast_dims);
expand_check(tmp_y, y_broadcast_dims);
Tensor tmp_x_bc;
Tensor tmp_y_bc;
auto tmp_x_rank = tmp_x.dims().size();
auto tmp_y_rank = tmp_y.dims().size();
auto rank_0 = std::max(tmp_x_rank, static_cast<int>(x_broadcast_dims.size()));
switch (rank_0) {
case 1:
tensor_expand<1, T, DeviceContext>(context, tmp_x, &tmp_x_bc,
x_broadcast_dims);
break;
case 2:
tensor_expand<2, T, DeviceContext>(context, tmp_x, &tmp_x_bc,
x_broadcast_dims);
break;
case 3:
tensor_expand<3, T, DeviceContext>(context, tmp_x, &tmp_x_bc,
x_broadcast_dims);
break;
case 4:
tensor_expand<4, T, DeviceContext>(context, tmp_x, &tmp_x_bc,
x_broadcast_dims);
break;
case 5:
tensor_expand<5, T, DeviceContext>(context, tmp_x, &tmp_x_bc,
x_broadcast_dims);
break;
case 6:
tensor_expand<6, T, DeviceContext>(context, tmp_x, &tmp_x_bc,
x_broadcast_dims);
break;
}
TensorExpand<T, DeviceContext>(dev_ctx, tmp_x, &tmp_x_bc, x_broadcast_dims);
auto rank_1 = std::max(tmp_y_rank, static_cast<int>(y_broadcast_dims.size()));
switch (rank_1) {
case 1:
tensor_expand<1, T, DeviceContext>(context, tmp_y, &tmp_y_bc,
y_broadcast_dims);
break;
case 2:
tensor_expand<2, T, DeviceContext>(context, tmp_y, &tmp_y_bc,
y_broadcast_dims);
break;
case 3:
tensor_expand<3, T, DeviceContext>(context, tmp_y, &tmp_y_bc,
y_broadcast_dims);
break;
case 4:
tensor_expand<4, T, DeviceContext>(context, tmp_y, &tmp_y_bc,
y_broadcast_dims);
break;
case 5:
tensor_expand<5, T, DeviceContext>(context, tmp_y, &tmp_y_bc,
y_broadcast_dims);
break;
case 6:
tensor_expand<6, T, DeviceContext>(context, tmp_y, &tmp_y_bc,
y_broadcast_dims);
break;
}
Tensor tmp_y_bc;
TensorExpand<T, DeviceContext>(dev_ctx, tmp_y, &tmp_y_bc, y_broadcast_dims);
auto x_dim = x->dims();
auto y_dim = y->dims();
......@@ -658,8 +602,8 @@ class SolveGradKernel : public framework::OpKernel<T> {
if (dy_help.dims().size() != dy->dims().size()) {
keep_dim = false;
}
ReduceSumForSolveGrad<DeviceContext, T>(&dy_help, dy, dy_reduce_dims,
keep_dim, ctx);
ReduceSumForSolve<DeviceContext, T>(&dy_help, dy, dy_reduce_dims,
keep_dim, ctx);
}
dy->Resize(y->dims());
}
......@@ -708,8 +652,8 @@ class SolveGradKernel : public framework::OpKernel<T> {
if (dx_help.dims().size() != dx->dims().size()) {
keep_dim = false;
}
ReduceSumForSolveGrad<DeviceContext, T>(&dx_help, dx, dx_reduce_dims,
keep_dim, ctx);
ReduceSumForSolve<DeviceContext, T>(&dx_help, dx, dx_reduce_dims,
keep_dim, ctx);
}
dx->Resize(input->dims());
}
......
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/triangular_solve_op.h"
#include "paddle/fluid/operators/solve_op.h"
namespace paddle {
namespace operators {
class TriangularSolveOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override {
OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "TriangularSolve");
OP_INOUT_CHECK(ctx->HasInput("Y"), "Input", "Y", "TriangularSolve");
OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out", "TriangularSolve");
auto x_dims = ctx->GetInputDim("X");
auto y_dims = ctx->GetInputDim("Y");
auto x_dims_n = x_dims.size();
auto y_dims_n = y_dims.size();
PADDLE_ENFORCE_GE(
x_dims_n, 2, platform::errors::InvalidArgument(
"The input tensor X's dimensions of TriangularSolveOp "
"should be >= 2. But received X's "
"dimensions = %d, X's shape = [%s]",
x_dims.size(), x_dims));
PADDLE_ENFORCE_GE(
y_dims_n, 2, platform::errors::InvalidArgument(
"The input tensor Y's dimensions of TriangularSolveOp "
"should be >=2. But received Y's "
"dimensions = %d, Y's shape = [%s]",
y_dims.size(), y_dims));
PADDLE_ENFORCE_EQ(x_dims[x_dims_n - 2], x_dims[x_dims_n - 1],
platform::errors::InvalidArgument(
"The inner-most 2 dimensions of Input(X) all should "
"be square matrices "
"But received X's shape[-2] = %d and shape[-1] = %d.",
x_dims[x_dims_n - 2], x_dims[x_dims_n - 1]));
std::vector<int64_t> x_dims_vec = paddle::framework::vectorize(x_dims);
std::vector<int64_t> y_dims_vec = paddle::framework::vectorize(y_dims);
std::vector<int64_t> x_dims_vec_cut(x_dims_vec.begin(),
x_dims_vec.end() - 2);
std::vector<int64_t> y_dims_vec_cut(y_dims_vec.begin(),
y_dims_vec.end() - 2);
std::vector<int64_t> expand_batch_portion =
get_broadcast_batch_portion(x_dims_vec_cut, y_dims_vec_cut);
std::vector<int64_t> y_broadcast_dims({expand_batch_portion});
y_broadcast_dims.insert(y_broadcast_dims.end(), {y_dims_vec[y_dims_n - 2],
y_dims_vec[y_dims_n - 1]});
// dim of 'Out' is the same with 'Y' after broadcast
ctx->SetOutputDim("Out", framework::make_ddim(y_broadcast_dims));
ctx->ShareLoD("X", /*->*/ "Out");
}
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const {
return framework::OpKernelType(
OperatorWithKernel::IndicateVarDataType(ctx, "X"), ctx.GetPlace());
}
};
class TriangularSolveOpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() override {
AddInput("X",
"(Tensor), The first input tensor of triangular solve op, which "
"is the triangular coefficient matrix.");
AddInput("Y",
"(Tensor), The second input tensor of triangular solve op, which "
"is multiple right-hand.");
AddOutput("Out", "(Tensor), The solution tensor of triangular solve op.");
AddAttr<bool>("upper",
"whether to solve the upper-triangular or the "
"lower-triangular system of equations")
.SetDefault(true);
AddAttr<bool>("transpose", "whether X should be transposed firstly.")
.SetDefault(false);
AddAttr<bool>("unitriangular", "whether X is unit triangular.")
.SetDefault(false);
AddComment(R"DOC(
Triangular Solve Operator.
This operator computes the solution of linear equations with a triangular coefficient matrix.
The equation is:
$$Out = X^{-1} * Y$$
)DOC");
}
};
class TriangularSolveOpInferVarType
: public framework::PassInDtypeAndVarTypeToOutput {
protected:
std::unordered_map<std::string, std::string>& GetInputOutputWithSameType()
const override {
static std::unordered_map<std::string, std::string> m{{"X", /*->*/ "Out"}};
return m;
}
};
class TriangularSolveGradOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override {
OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "triangular_solve");
OP_INOUT_CHECK(ctx->HasInput("Y"), "Input", "Y", "triangular_solve");
OP_INOUT_CHECK(ctx->HasInput("Out"), "Input", "Out", "triangular_solve");
OP_INOUT_CHECK(ctx->HasInput(framework::GradVarName("Out")), "Input",
"Out@GRAD", "triangular_solve");
auto x_dims = ctx->GetInputDim("X");
auto y_dims = ctx->GetInputDim("Y");
auto x_grad_name = framework::GradVarName("X");
auto y_grad_name = framework::GradVarName("Y");
if (ctx->HasOutput(x_grad_name)) {
ctx->SetOutputDim(x_grad_name, x_dims);
}
if (ctx->HasOutput(y_grad_name)) {
ctx->SetOutputDim(y_grad_name, y_dims);
}
}
};
template <typename T>
class TriangularSolveOpGradMaker : public framework::SingleGradOpMaker<T> {
public:
using framework::SingleGradOpMaker<T>::SingleGradOpMaker;
protected:
void Apply(GradOpPtr<T> retv) const override {
retv->SetType("triangular_solve_grad");
retv->SetInput("X", this->Input("X"));
retv->SetInput("Y", this->Input("Y"));
retv->SetInput("Out", this->Output("Out"));
retv->SetInput(framework::GradVarName("Out"), this->OutputGrad("Out"));
retv->SetOutput(framework::GradVarName("X"), this->InputGrad("X"));
retv->SetOutput(framework::GradVarName("Y"), this->InputGrad("Y"));
retv->SetAttrMap(this->Attrs());
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OPERATOR(triangular_solve, ops::TriangularSolveOp,
ops::TriangularSolveOpMaker,
ops::TriangularSolveOpInferVarType,
ops::TriangularSolveOpGradMaker<paddle::framework::OpDesc>,
ops::TriangularSolveOpGradMaker<paddle::imperative::OpBase>);
REGISTER_OPERATOR(triangular_solve_grad, ops::TriangularSolveGradOp);
REGISTER_OP_CPU_KERNEL(
triangular_solve,
ops::TriangularSolveKernel<paddle::platform::CPUDeviceContext, float>,
ops::TriangularSolveKernel<paddle::platform::CPUDeviceContext, double>);
REGISTER_OP_CPU_KERNEL(
triangular_solve_grad,
ops::TriangularSolveGradKernel<paddle::platform::CPUDeviceContext, float>,
ops::TriangularSolveGradKernel<paddle::platform::CPUDeviceContext, double>);
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/reduce_ops/reduce_functor_op.h"
#include "paddle/fluid/operators/reduce_ops/reduce_op.h"
#include "paddle/fluid/operators/triangular_solve_op.h"
namespace paddle {
namespace operators {
template <typename T>
struct MatrixReduceSumFunctor<platform::CUDADeviceContext, T> {
void operator()(const Tensor& in, Tensor* out,
const framework::ExecutionContext& ctx) {
// For example: in's dim = [5, 3, 2, 7, 3] ; out's dim = [3, 1, 7, 3]
// out_reduce_dim should be [0, 2]
const std::vector<std::int64_t> in_dims = framework::vectorize(in.dims());
auto in_size = in_dims.size();
const std::vector<std::int64_t> out_dims =
framework::vectorize(out->dims());
auto out_size = out_dims.size();
std::vector<std::int64_t> out_bst_dims(in_size);
std::fill(out_bst_dims.data(), out_bst_dims.data() + in_size - out_size, 1);
std::copy(out_dims.data(), out_dims.data() + out_size,
out_bst_dims.data() + in_size - out_size);
std::vector<int> out_reduce_dims;
for (size_t idx = 0; idx <= in_size - 3; idx++) {
if (in_dims[idx] != 1 && out_bst_dims[idx] == 1) {
out_reduce_dims.push_back(idx);
}
}
gpuStream_t stream = ctx.cuda_device_context().stream();
TensorReduceFunctorImpl<T, T, CustomSum>(in, out, out_reduce_dims, stream);
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OP_CUDA_KERNEL(
triangular_solve,
ops::TriangularSolveKernel<paddle::platform::CUDADeviceContext, float>,
ops::TriangularSolveKernel<paddle::platform::CUDADeviceContext, double>);
REGISTER_OP_CUDA_KERNEL(
triangular_solve_grad,
ops::TriangularSolveGradKernel<paddle::platform::CUDADeviceContext, float>,
ops::TriangularSolveGradKernel<paddle::platform::CUDADeviceContext,
double>);
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "glog/logging.h"
#include "paddle/fluid/framework/ddim.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/framework/tensor_util.h"
#include "paddle/fluid/operators/math/blas.h"
#include "paddle/fluid/operators/math/complex_functors.h"
#include "paddle/fluid/operators/reduce_ops/reduce_op.h"
#include "paddle/fluid/operators/solve_op.h"
#include "paddle/fluid/operators/tril_triu_op.h"
namespace paddle {
namespace operators {
using Tensor = framework::Tensor;
template <typename DeviceContext, typename T>
static void triangular_solve(const DeviceContext& context, const Tensor& x,
const Tensor& y, Tensor* out, bool upper,
bool transpose, bool unitriangular) {
  // Tensor broadcast is done with Eigen
std::vector<int64_t> x_bst_dims_vec;
std::vector<int64_t> y_bst_dims_vec;
std::tie(x_bst_dims_vec, y_bst_dims_vec) = get_broadcast_dims(x, y);
Tensor x_bst(x.type());
TensorExpand<T, DeviceContext>(context, x, &x_bst, x_bst_dims_vec);
Tensor y_bst(y.type());
TensorExpand<T, DeviceContext>(context, y, &y_bst, y_bst_dims_vec);
// TriangularSolveFunctor performs calculations in-place
// x_clone should be a copy of 'x' after broadcast
// out should be a copy of 'y' after broadcast
Tensor x_clone(x.type());
x_clone.Resize(framework::make_ddim(x_bst_dims_vec));
x_clone.mutable_data<T>(context.GetPlace());
framework::TensorCopy(x_bst, context.GetPlace(), context, &x_clone);
out->Resize(framework::make_ddim(y_bst_dims_vec));
out->mutable_data<T>(context.GetPlace());
framework::TensorCopy(y_bst, context.GetPlace(), context, out);
math::TriangularSolveFunctor<DeviceContext, T> functor;
functor(context, &x_clone, out, /*left=*/true, upper, transpose,
unitriangular);
}
template <typename DeviceContext, typename T>
class MatrixReduceSumFunctor {
public:
void operator()(const Tensor& input, Tensor* output,
const framework::ExecutionContext& ctx);
};
template <typename T>
class MatrixReduceSumFunctor<platform::CPUDeviceContext, T> {
public:
void operator()(const Tensor& in, Tensor* out,
const framework::ExecutionContext& ctx) {
// For example: in's dim = [5, 3, 2, 7, 3] ; out's dim = [3, 1, 7, 3]
// out_reduce_dim should be [0, 2]
const std::vector<std::int64_t> in_dims = framework::vectorize(in.dims());
auto in_size = in_dims.size();
const std::vector<std::int64_t> out_dims =
framework::vectorize(out->dims());
auto out_size = out_dims.size();
std::vector<std::int64_t> out_bst_dims(in_size);
std::fill(out_bst_dims.data(), out_bst_dims.data() + in_size - out_size, 1);
std::copy(out_dims.data(), out_dims.data() + out_size,
out_bst_dims.data() + in_size - out_size);
out->Resize(framework::make_ddim(out_bst_dims));
std::vector<int> out_reduce_dims;
for (size_t idx = 0; idx <= in_size - 3; idx++) {
if (in_dims[idx] != 1 && out_bst_dims[idx] == 1) {
out_reduce_dims.push_back(idx);
}
}
ReduceKernelFunctor<platform::CPUDeviceContext, T, SumFunctor>(
&in, out, out_reduce_dims, true, false, ctx)
.template apply<T>();
out->Resize(framework::make_ddim(out_dims));
}
};
template <typename DeviceContext, typename T>
class TriangularSolveKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
const auto* x = ctx.Input<framework::Tensor>("X");
const auto* y = ctx.Input<framework::Tensor>("Y");
auto* out = ctx.Output<framework::Tensor>("Out");
bool upper = ctx.template Attr<bool>("upper");
bool transpose = ctx.template Attr<bool>("transpose");
bool unitriangular = ctx.template Attr<bool>("unitriangular");
const auto& dev_ctx = ctx.template device_context<DeviceContext>();
triangular_solve<DeviceContext, T>(dev_ctx, *x, *y, out, upper, transpose,
unitriangular);
}
};
template <typename DeviceContext, typename T>
class TriangularSolveGradKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
const auto* x = ctx.Input<framework::Tensor>("X");
const auto* y = ctx.Input<framework::Tensor>("Y");
const auto* out = ctx.Input<framework::Tensor>("Out");
const auto* dout =
ctx.Input<framework::Tensor>(framework::GradVarName("Out"));
auto* dx = ctx.Output<framework::Tensor>(framework::GradVarName("X"));
auto* dy = ctx.Output<framework::Tensor>(framework::GradVarName("Y"));
bool upper = ctx.template Attr<bool>("upper");
bool transpose = ctx.template Attr<bool>("transpose");
bool unitriangular = ctx.template Attr<bool>("unitriangular");
auto& dev_ctx = ctx.template device_context<DeviceContext>();
std::vector<int64_t> x_bst_dims_vec;
std::vector<int64_t> y_bst_dims_vec;
std::tie(x_bst_dims_vec, y_bst_dims_vec) = get_broadcast_dims(*x, *y);
Tensor dy_bst(y->type());
if (dy) {
dy->mutable_data<T>(y->dims(), dev_ctx.GetPlace());
dy_bst.Resize(framework::make_ddim(y_bst_dims_vec));
dy_bst.mutable_data<T>(dev_ctx.GetPlace());
// calculate x's conjugate for complex
Tensor x_conj(x->type());
platform::ForRange<DeviceContext> x_for_range(dev_ctx, x->numel());
math::ConjFunctor<T> x_functor(
x->data<T>(), x->numel(),
x_conj.mutable_data<T>(x->dims(), dev_ctx.GetPlace()));
x_for_range(x_functor);
      // reuse the forward pass to get dy_bst; the result is already broadcast.
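      // math: with Out = op(X)^{-1} * Y, dY = op(X)^{-H} * dOut, i.e. another
      // triangular solve with the same `upper` and a flipped `transpose`
      // (the conjugate above supplies the Hermitian part for complex types).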
triangular_solve<DeviceContext, T>(dev_ctx, x_conj, *dout, &dy_bst, upper,
!transpose, unitriangular);
if (dy_bst.dims() == dy->dims()) {
framework::TensorCopy(dy_bst, dev_ctx.GetPlace(), dev_ctx, dy);
} else {
MatrixReduceSumFunctor<DeviceContext, T> functor;
functor(dy_bst, dy, ctx);
dy->Resize(y->dims());
}
}
Tensor dx_bst(x->type());
if (dx) {
dx->mutable_data<T>(x->dims(), dev_ctx.GetPlace());
dx_bst.Resize(framework::make_ddim(x_bst_dims_vec));
dx_bst.mutable_data<T>(dev_ctx.GetPlace());
// calculate out's conjugate for complex
Tensor out_conj(out->type());
platform::ForRange<DeviceContext> out_for_range(dev_ctx, out->numel());
math::ConjFunctor<T> out_functor(
out->data<T>(), out->numel(),
out_conj.mutable_data<T>(out->dims(), dev_ctx.GetPlace()));
out_for_range(out_functor);
auto blas = math::GetBlas<DeviceContext, T>(ctx);
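      // math: dX = -op(X)^{-H} * dOut * Out^H = -dY_bst * Out^H (transposed
      // accordingly when `transpose` is set); only the triangular half of X
      // that the solve actually referenced is kept afterwards.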
if (transpose) {
auto mat_dim_a =
math::CreateMatrixDescriptor(out_conj.dims(), 0, false);
auto mat_dim_b = math::CreateMatrixDescriptor(dy_bst.dims(), 0, true);
blas.MatMul(out_conj, mat_dim_a, dy_bst, mat_dim_b, static_cast<T>(-1),
&dx_bst, static_cast<T>(0));
} else {
auto mat_dim_a = math::CreateMatrixDescriptor(dy_bst.dims(), 0, false);
auto mat_dim_b = math::CreateMatrixDescriptor(out_conj.dims(), 0, true);
blas.MatMul(dy_bst, mat_dim_a, out_conj, mat_dim_b, static_cast<T>(-1),
&dx_bst, static_cast<T>(0));
}
Tensor dx_bst_upper(x->type());
// get upper or lower triangular
dx_bst_upper.Resize(dx_bst.dims());
dx_bst_upper.mutable_data<T>(dev_ctx.GetPlace());
const auto& dims = dx_bst.dims();
const auto H = dims[dims.size() - 2];
const auto W = dims[dims.size() - 1];
platform::ForRange<DeviceContext> x_for_range(dev_ctx, dx_bst.numel());
TrilTriuCompute<T> tril_triu_computer(dx_bst.data<T>(), unitriangular,
!upper, H, W,
dx_bst_upper.data<T>());
x_for_range(tril_triu_computer);
if (dx_bst_upper.dims() == dx->dims()) {
framework::TensorCopy(dx_bst_upper, dev_ctx.GetPlace(), dev_ctx, dx);
} else {
MatrixReduceSumFunctor<DeviceContext, T> functor;
functor(dx_bst_upper, dx, ctx);
dx->Resize(x->dims());
}
}
}
};
} // namespace operators
} // namespace paddle
......@@ -75,6 +75,8 @@ extern void *cublas_dso_handle;
__macro(cublasDgeam); \
__macro(cublasStrsm_v2); \
__macro(cublasDtrsm_v2); \
__macro(cublasCtrsm_v2); \
__macro(cublasZtrsm_v2); \
__macro(cublasCreate_v2); \
__macro(cublasDestroy_v2); \
__macro(cublasSetStream_v2); \
......@@ -84,6 +86,10 @@ extern void *cublas_dso_handle;
__macro(cublasDgemmBatched); \
__macro(cublasCgemmBatched); \
__macro(cublasZgemmBatched); \
__macro(cublasStrsmBatched); \
__macro(cublasDtrsmBatched); \
__macro(cublasCtrsmBatched); \
__macro(cublasZtrsmBatched); \
__macro(cublasSgetrfBatched); \
__macro(cublasSgetriBatched); \
__macro(cublasDgetrfBatched); \
......
......@@ -25,7 +25,7 @@ namespace platform {
namespace dynload {
extern std::once_flag mklml_dso_flag;
extern void* mklml_dso_handle;
extern void *mklml_dso_handle;
/**
* The following macro definition can generate structs
......@@ -40,7 +40,7 @@ extern void* mklml_dso_handle;
std::call_once(mklml_dso_flag, []() { \
mklml_dso_handle = paddle::platform::dynload::GetMKLMLDsoHandle(); \
}); \
static void* p_##_name = dlsym(mklml_dso_handle, #__name); \
static void *p_##_name = dlsym(mklml_dso_handle, #__name); \
return reinterpret_cast<mklmlFunc>(p_##_name)(args...); \
} \
}; \
......@@ -67,6 +67,8 @@ extern void* mklml_dso_handle;
__macro(cblas_zgemv); \
__macro(cblas_strsm); \
__macro(cblas_dtrsm); \
__macro(cblas_ctrsm); \
__macro(cblas_ztrsm); \
__macro(cblas_sgemm_alloc); \
__macro(cblas_dgemm_alloc); \
__macro(cblas_sgemm_pack); \
......
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import unittest
import numpy as np
import sys
sys.path.append("..")
import paddle
from op_test import OpTest
import paddle.fluid as fluid
from paddle.fluid import Program, program_guard, core
paddle.enable_static()
# 2D + 2D, test 'upper'
class TestTriangularSolveOp(OpTest):
"""
case 1
"""
def config(self):
self.x_shape = [12, 12]
self.y_shape = [12, 10]
self.upper = True
self.transpose = False
self.unitriangular = False
self.dtype = "float64"
def set_output(self):
self.output = np.linalg.solve(
np.triu(self.inputs['X']), self.inputs['Y'])
def setUp(self):
self.op_type = "triangular_solve"
self.config()
self.inputs = {
'X': np.random.random(self.x_shape).astype(self.dtype),
'Y': np.random.random(self.y_shape).astype(self.dtype)
}
self.attrs = {
'upper': self.upper,
'transpose': self.transpose,
'unitriangular': self.unitriangular,
}
self.set_output()
self.outputs = {'Out': self.output}
def test_check_output(self):
self.check_output()
def test_check_grad_normal(self):
self.check_grad(['X', 'Y'], 'Out')
# 2D(broadcast) + 3D, test 'transpose'
class TestTriangularSolveOp2(TestTriangularSolveOp):
"""
case 2
"""
def config(self):
self.x_shape = [10, 10]
self.y_shape = [3, 10, 8]
self.upper = False
self.transpose = True
self.unitriangular = False
self.dtype = "float64"
def set_output(self):
x = np.tril(self.inputs['X']).transpose(1, 0)
y = self.inputs['Y']
self.output = np.linalg.solve(x, y)
# 3D(broadcast) + 3D
class TestTriangularSolveOp3(TestTriangularSolveOp):
"""
case 3
"""
def config(self):
self.x_shape = [1, 10, 10]
self.y_shape = [6, 10, 12]
self.upper = False
self.transpose = False
self.unitriangular = False
self.dtype = "float64"
def set_output(self):
x = np.tril(self.inputs['X'])
y = self.inputs['Y']
self.output = np.linalg.solve(x, y)
# 3D + 3D(broadcast), test 'transpose'
class TestTriangularSolveOp4(TestTriangularSolveOp):
"""
case 4
"""
def config(self):
self.x_shape = [3, 10, 10]
self.y_shape = [1, 10, 12]
self.upper = True
self.transpose = True
self.unitriangular = False
self.dtype = "float64"
def set_output(self):
x = np.triu(self.inputs['X']).transpose(0, 2, 1)
y = self.inputs['Y']
self.output = np.linalg.solve(x, y)
# 2D + 2D, specifically test 'unitriangular'
class TestTriangularSolveOp5(TestTriangularSolveOp):
"""
case 5
"""
def config(self):
self.x_shape = [10, 10]
self.y_shape = [10, 10]
self.upper = True
self.transpose = False
self.unitriangular = True
self.dtype = "float64"
def set_output(self):
x = np.triu(self.inputs['X'])
np.fill_diagonal(x, 1.)
y = self.inputs['Y']
self.output = np.linalg.solve(x, y)
def test_check_grad_normal(self):
x = np.triu(self.inputs['X'])
np.fill_diagonal(x, 1.)
grad_out = np.ones([10, 10]).astype('float64')
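        # chain rule for out = solve(x, y): grad_y = solve(x^T, grad_out) and
        # grad_x = -grad_y @ out^T, masked to the referenced triangle (strictly
        # upper here, since unitriangular never reads the diagonal).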
grad_y = np.linalg.solve(x.transpose(1, 0), grad_out)
grad_x = -np.matmul(grad_y, self.output.transpose(1, 0))
grad_x = np.triu(grad_x)
np.fill_diagonal(grad_x, 0.)
self.check_grad(
['X', 'Y'],
'Out',
user_defined_grads=[grad_x, grad_y],
user_defined_grad_outputs=[grad_out])
# 4D(broadcast) + 4D(broadcast)
class TestTriangularSolveOp6(TestTriangularSolveOp):
"""
case 6
"""
def config(self):
self.x_shape = [1, 3, 10, 10]
self.y_shape = [2, 1, 10, 5]
self.upper = False
self.transpose = False
self.unitriangular = False
self.dtype = "float64"
def set_output(self):
x = np.tril(self.inputs['X'])
y = self.inputs['Y']
self.output = np.linalg.solve(x, y)
# 3D(broadcast) + 4D(broadcast), test 'upper'
class TestTriangularSolveOp7(TestTriangularSolveOp):
"""
case 7
"""
def config(self):
self.x_shape = [2, 10, 10]
self.y_shape = [5, 1, 10, 2]
self.upper = True
self.transpose = True
self.unitriangular = False
self.dtype = "float64"
def set_output(self):
x = np.triu(self.inputs['X']).transpose(0, 2, 1)
y = self.inputs['Y']
self.output = np.linalg.solve(x, y)
# 3D(broadcast) + 5D
class TestTriangularSolveOp8(TestTriangularSolveOp):
"""
case 8
"""
def config(self):
self.x_shape = [12, 3, 3]
self.y_shape = [2, 3, 12, 3, 2]
self.upper = False
self.transpose = False
self.unitriangular = False
self.dtype = "float64"
def set_output(self):
x = np.tril(self.inputs['X'])
y = self.inputs['Y']
self.output = np.linalg.solve(x, y)
# 5D + 4D(broadcast)
class TestTriangularSolveOp9(TestTriangularSolveOp):
"""
case 9
"""
def config(self):
self.x_shape = [2, 4, 2, 3, 3]
self.y_shape = [4, 1, 3, 10]
self.upper = False
self.transpose = False
self.unitriangular = False
self.dtype = "float64"
def set_output(self):
x = np.tril(self.inputs['X'])
y = self.inputs['Y']
self.output = np.matmul(np.linalg.inv(x), y)
class TestTriangularSolveAPI(unittest.TestCase):
def setUp(self):
np.random.seed(2021)
self.place = [paddle.CPUPlace()]
self.dtype = "float64"
if core.is_compiled_with_cuda():
self.place.append(paddle.CUDAPlace(0))
def check_static_result(self, place):
with fluid.program_guard(fluid.Program(), fluid.Program()):
x = fluid.data(name="x", shape=[3, 3], dtype=self.dtype)
y = fluid.data(name="y", shape=[3, 2], dtype=self.dtype)
z = paddle.linalg.triangular_solve(x, y)
x_np = np.random.random([3, 3]).astype(self.dtype)
y_np = np.random.random([3, 2]).astype(self.dtype)
z_np = np.linalg.solve(np.triu(x_np), y_np)
exe = fluid.Executor(place)
fetches = exe.run(fluid.default_main_program(),
feed={"x": x_np,
"y": y_np},
fetch_list=[z])
self.assertTrue(np.allclose(fetches[0], z_np))
def test_static(self):
for place in self.place:
self.check_static_result(place=place)
def test_dygraph(self):
def run(place):
paddle.disable_static(place)
x_np = np.random.random([3, 3]).astype(self.dtype)
y_np = np.random.random([3, 2]).astype(self.dtype)
z_np = np.linalg.solve(np.tril(x_np), y_np)
x = paddle.to_tensor(x_np)
y = paddle.to_tensor(y_np)
z = paddle.linalg.triangular_solve(x, y, upper=False)
self.assertTrue(np.allclose(z_np, z.numpy()))
self.assertEqual(z_np.shape, z.numpy().shape)
paddle.enable_static()
for place in self.place:
run(place)
class TestTriangularSolveOpError(unittest.TestCase):
def test_errors(self):
with program_guard(Program(), Program()):
            # The input type of triangular_solve op must be Variable.
x1 = fluid.create_lod_tensor(
np.array([[-1]]), [[1]], fluid.CPUPlace())
y1 = fluid.create_lod_tensor(
np.array([[-1]]), [[1]], fluid.CPUPlace())
self.assertRaises(TypeError, paddle.linalg.triangular_solve, x1, y1)
# The data type of input must be float32 or float64.
x2 = fluid.data(name="x2", shape=[30, 30], dtype="bool")
y2 = fluid.data(name="y2", shape=[30, 10], dtype="bool")
self.assertRaises(TypeError, paddle.linalg.triangular_solve, x2, y2)
x3 = fluid.data(name="x3", shape=[30, 30], dtype="int32")
y3 = fluid.data(name="y3", shape=[30, 10], dtype="int32")
self.assertRaises(TypeError, paddle.linalg.triangular_solve, x3, y3)
x4 = fluid.data(name="x4", shape=[30, 30], dtype="float16")
y4 = fluid.data(name="y4", shape=[30, 10], dtype="float16")
self.assertRaises(TypeError, paddle.linalg.triangular_solve, x4, y4)
            # The number of dimensions of input 'X' must be >= 2.
x5 = fluid.data(name="x5", shape=[30], dtype="float64")
y5 = fluid.data(name="y5", shape=[30, 30], dtype="float64")
self.assertRaises(ValueError, paddle.linalg.triangular_solve, x5,
y5)
            # The number of dimensions of input 'Y' must be >= 2.
x6 = fluid.data(name="x6", shape=[30, 30], dtype="float64")
y6 = fluid.data(name="y6", shape=[30], dtype="float64")
self.assertRaises(ValueError, paddle.linalg.triangular_solve, x6,
y6)
            # The inner-most 2 dimensions of input 'X' should be equal to each other
x7 = fluid.data(name="x7", shape=[2, 3, 4], dtype="float64")
y7 = fluid.data(name="y7", shape=[2, 4, 3], dtype="float64")
self.assertRaises(ValueError, paddle.linalg.triangular_solve, x7,
y7)
if __name__ == "__main__":
unittest.main()
......@@ -29,6 +29,7 @@ from .tensor.linalg import eigvalsh
from .tensor.linalg import det
from .tensor.linalg import slogdet
from .tensor.linalg import pinv
from .tensor.linalg import triangular_solve
__all__ = [
'cholesky', #noqa
......@@ -47,5 +48,6 @@ __all__ = [
'eigh',
'eigvalsh',
'pinv',
'solve'
'solve',
'triangular_solve',
]
......@@ -397,6 +397,7 @@ tensor_method_func = [ #noqa
'uniform_',
'multi_dot',
'solve',
'triangular_solve'
]
#this list used in math_op_patch.py for magic_method bind
......
......@@ -2315,6 +2315,79 @@ def solve(x, y, name=None):
return out
def triangular_solve(x,
y,
upper=True,
transpose=False,
unitriangular=False,
name=None):
r"""
Computes the solution of a system of equations with a triangular coefficient matrix `x` and
    multiple right-hand sides `y`.
    Inputs `x` and `y` are 2D matrices or batches of 2D matrices. If the inputs are batches, the output
    is also a batch.
Args:
x (Tensor): The input triangular coefficient matrix. Its shape should be `[*, M, M]`, where `*` is zero or
more batch dimensions. Its data type should be float32 or float64.
        y (Tensor): Multiple right-hand sides of the system of equations. Its shape should be `[*, M, K]`, where `*` is
zero or more batch dimensions. Its data type should be float32 or float64.
upper (bool, optional): Whether to solve the upper-triangular system of equations (default) or the lower-triangular
system of equations. Default: True.
        transpose (bool, optional): Whether `x` should be transposed before the calculation. Default: False.
        unitriangular (bool, optional): Whether `x` is unit triangular. If True, the diagonal elements of `x` are assumed
            to be 1 and not referenced from `x`. Default: False.
name(str, optional): Name for the operation (optional, default is None).
For more information, please refer to :ref:`api_guide_Name`.
Returns:
Tensor: The solution of the system of equations. Its data type should be the same as that of `x`.
Examples:
.. code-block:: python
# a square system of linear equations:
# x1 + x2 + x3 = 0
# 2*x2 + x3 = -9
# -x3 = 5
import paddle
x = paddle.to_tensor([[1, 1, 1],
[0, 2, 1],
[0, 0,-1]], dtype="float64")
y = paddle.to_tensor([[0], [-9], [5]], dtype="float64")
out = paddle.linalg.triangular_solve(x, y, upper=True)
print(out)
            # [[7.], [-2.], [-5.]]  (shape [3, 1])
"""
if in_dygraph_mode():
return _C_ops.triangular_solve(x, y, 'upper', upper, 'transpose',
transpose, 'unitriangular',
unitriangular)
inputs = {"X": [x], "Y": [y]}
helper = LayerHelper("triangular_solve", **locals())
check_variable_and_dtype(x, 'x', ['float32', 'float64'], 'triangular_solve')
check_variable_and_dtype(y, 'y', ['float32', 'float64'], 'triangular_solve')
out = helper.create_variable_for_type_inference(dtype=x.dtype)
helper.append_op(
type='triangular_solve',
inputs={'X': x,
'Y': y},
outputs={'Out': out},
attrs={
'upper': upper,
'transpose': transpose,
'unitriangular': unitriangular
})
return out
def eigvalsh(x, UPLO='L', name=None):
"""
Computes the eigenvalues of a
......