Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
机器未来
Paddle
提交
3a81805b
P
Paddle
项目概览
机器未来
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1
Issue
1
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
3a81805b
编写于
11月 26, 2021
作者:
zhouweiwei2014
提交者:
GitHub
11月 26, 2021
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
add new API/OP: paddle.linalg.triangular_solve (#36714) (#37551)
cherry-pick #36714
上级
4b41b8e9
变更
17
显示空白变更内容
内联
并排
Showing
17 changed file
with
1245 addition
and
113 deletion
+1245
-113
paddle/fluid/operators/math/blas.h
paddle/fluid/operators/math/blas.h
+12
-0
paddle/fluid/operators/math/blas_impl.cu.h
paddle/fluid/operators/math/blas_impl.cu.h
+88
-0
paddle/fluid/operators/math/blas_impl.h
paddle/fluid/operators/math/blas_impl.h
+40
-0
paddle/fluid/operators/math/blas_impl.hip.h
paddle/fluid/operators/math/blas_impl.hip.h
+38
-0
paddle/fluid/operators/math/matrix_solve.cc
paddle/fluid/operators/math/matrix_solve.cc
+39
-0
paddle/fluid/operators/math/matrix_solve.cu.cc
paddle/fluid/operators/math/matrix_solve.cu.cc
+62
-0
paddle/fluid/operators/math/matrix_solve.h
paddle/fluid/operators/math/matrix_solve.h
+8
-0
paddle/fluid/operators/solve_op.h
paddle/fluid/operators/solve_op.h
+54
-110
paddle/fluid/operators/triangular_solve_op.cc
paddle/fluid/operators/triangular_solve_op.cc
+187
-0
paddle/fluid/operators/triangular_solve_op.cu
paddle/fluid/operators/triangular_solve_op.cu
+64
-0
paddle/fluid/operators/triangular_solve_op.h
paddle/fluid/operators/triangular_solve_op.h
+227
-0
paddle/fluid/platform/dynload/cublas.h
paddle/fluid/platform/dynload/cublas.h
+6
-0
paddle/fluid/platform/dynload/mklml.h
paddle/fluid/platform/dynload/mklml.h
+4
-2
python/paddle/fluid/tests/unittests/test_triangular_solve_op.py
.../paddle/fluid/tests/unittests/test_triangular_solve_op.py
+339
-0
python/paddle/linalg.py
python/paddle/linalg.py
+3
-1
python/paddle/tensor/__init__.py
python/paddle/tensor/__init__.py
+1
-0
python/paddle/tensor/linalg.py
python/paddle/tensor/linalg.py
+73
-0
未找到文件。
paddle/fluid/operators/math/blas.h
浏览文件 @
3a81805b
...
...
@@ -253,6 +253,12 @@ class Blas {
void
BatchedGETRS
(
CBLAS_TRANSPOSE
trans
,
int
n
,
int
nrhs
,
const
T
**
a
,
int
lda
,
int
*
ipiv
,
T
**
b
,
int
ldb
,
int
*
info
,
int
batch_size
)
const
;
// cuBlas triangular_solve
template
<
typename
T
>
void
BatchedTRSM
(
CBLAS_SIDE
side
,
CBLAS_UPLO
uplo
,
CBLAS_TRANSPOSE
transA
,
CBLAS_DIAG
diag
,
int
M
,
int
N
,
T
alpha
,
const
T
**
a
,
int
lda
,
T
**
b
,
int
ldb
,
int
batch_size
)
const
;
#endif
private:
...
...
@@ -414,6 +420,12 @@ class BlasT : private Blas<DeviceContext> {
void
BatchedGETRS
(
ARGS
...
args
)
const
{
Base
()
->
template
BatchedGETRS
<
T
>(
args
...);
}
// triangular_solve
template
<
typename
...
ARGS
>
void
BatchedTRSM
(
ARGS
...
args
)
const
{
Base
()
->
template
BatchedTRSM
<
T
>(
args
...);
}
#endif
private:
...
...
paddle/fluid/operators/math/blas_impl.cu.h
浏览文件 @
3a81805b
...
...
@@ -120,6 +120,11 @@ struct CUBlas<float> {
PADDLE_ENFORCE_CUDA_SUCCESS
(
platform
::
dynload
::
cublasSgetrsBatched
(
args
...));
}
template
<
typename
...
ARGS
>
static
void
TRSM_BATCH
(
ARGS
...
args
)
{
PADDLE_ENFORCE_CUDA_SUCCESS
(
platform
::
dynload
::
cublasStrsmBatched
(
args
...));
}
};
template
<
>
...
...
@@ -194,6 +199,11 @@ struct CUBlas<double> {
PADDLE_ENFORCE_CUDA_SUCCESS
(
platform
::
dynload
::
cublasDgetrsBatched
(
args
...));
}
template
<
typename
...
ARGS
>
static
void
TRSM_BATCH
(
ARGS
...
args
)
{
PADDLE_ENFORCE_CUDA_SUCCESS
(
platform
::
dynload
::
cublasDtrsmBatched
(
args
...));
}
};
template
<
>
...
...
@@ -339,6 +349,19 @@ struct CUBlas<platform::complex<float>> {
reinterpret_cast
<
cuFloatComplex
*>
(
C
),
ldc
));
}
static
void
TRSM
(
cublasHandle_t
handle
,
cublasSideMode_t
side
,
cublasFillMode_t
uplo
,
cublasOperation_t
transa
,
cublasDiagType_t
diag
,
int
m
,
int
n
,
const
paddle
::
platform
::
complex
<
float
>
*
alpha
,
const
paddle
::
platform
::
complex
<
float
>
*
A
,
int
lda
,
paddle
::
platform
::
complex
<
float
>
*
B
,
int
ldb
)
{
PADDLE_ENFORCE_CUDA_SUCCESS
(
platform
::
dynload
::
cublasCtrsm
(
handle
,
side
,
uplo
,
transa
,
diag
,
m
,
n
,
reinterpret_cast
<
const
cuFloatComplex
*>
(
alpha
),
reinterpret_cast
<
const
cuFloatComplex
*>
(
A
),
lda
,
reinterpret_cast
<
cuFloatComplex
*>
(
B
),
ldb
));
}
// NOTES: GEMM_EX can use Tensor Core to accelerate matrix multiply.
// https://docs.nvidia.com/cuda/cublas/index.html#cublassetmathmode
template
<
typename
...
ARGS
>
...
...
@@ -370,6 +393,20 @@ struct CUBlas<platform::complex<float>> {
"cublasGemmEx is not supported on cuda <= 7.5"
));
#endif
}
static
void
TRSM_BATCH
(
cublasHandle_t
handle
,
cublasSideMode_t
side
,
cublasFillMode_t
uplo
,
cublasOperation_t
transa
,
cublasDiagType_t
diag
,
int
m
,
int
n
,
const
paddle
::
platform
::
complex
<
float
>
*
alpha
,
const
paddle
::
platform
::
complex
<
float
>
**
A
,
int
lda
,
paddle
::
platform
::
complex
<
float
>
**
B
,
int
ldb
,
int
batch_size
)
{
PADDLE_ENFORCE_CUDA_SUCCESS
(
platform
::
dynload
::
cublasCtrsmBatched
(
handle
,
side
,
uplo
,
transa
,
diag
,
m
,
n
,
reinterpret_cast
<
const
cuFloatComplex
*>
(
alpha
),
reinterpret_cast
<
const
cuFloatComplex
**>
(
A
),
lda
,
reinterpret_cast
<
cuFloatComplex
**>
(
B
),
ldb
,
batch_size
));
}
};
template
<
>
...
...
@@ -440,6 +477,33 @@ struct CUBlas<platform::complex<double>> {
reinterpret_cast
<
cuDoubleComplex
*>
(
C
),
ldc
));
}
static
void
TRSM
(
cublasHandle_t
handle
,
cublasSideMode_t
side
,
cublasFillMode_t
uplo
,
cublasOperation_t
transa
,
cublasDiagType_t
diag
,
int
m
,
int
n
,
const
paddle
::
platform
::
complex
<
double
>
*
alpha
,
const
paddle
::
platform
::
complex
<
double
>
*
A
,
int
lda
,
paddle
::
platform
::
complex
<
double
>
*
B
,
int
ldb
)
{
PADDLE_ENFORCE_CUDA_SUCCESS
(
platform
::
dynload
::
cublasZtrsm
(
handle
,
side
,
uplo
,
transa
,
diag
,
m
,
n
,
reinterpret_cast
<
const
cuDoubleComplex
*>
(
alpha
),
reinterpret_cast
<
const
cuDoubleComplex
*>
(
A
),
lda
,
reinterpret_cast
<
cuDoubleComplex
*>
(
B
),
ldb
));
}
static
void
TRSM_BATCH
(
cublasHandle_t
handle
,
cublasSideMode_t
side
,
cublasFillMode_t
uplo
,
cublasOperation_t
transa
,
cublasDiagType_t
diag
,
int
m
,
int
n
,
const
paddle
::
platform
::
complex
<
double
>
*
alpha
,
const
paddle
::
platform
::
complex
<
double
>
**
A
,
int
lda
,
paddle
::
platform
::
complex
<
double
>
**
B
,
int
ldb
,
int
batch_size
)
{
PADDLE_ENFORCE_CUDA_SUCCESS
(
platform
::
dynload
::
cublasZtrsmBatched
(
handle
,
side
,
uplo
,
transa
,
diag
,
m
,
n
,
reinterpret_cast
<
const
cuDoubleComplex
*>
(
alpha
),
reinterpret_cast
<
const
cuDoubleComplex
**>
(
A
),
lda
,
reinterpret_cast
<
cuDoubleComplex
**>
(
B
),
ldb
,
batch_size
));
}
// NOTES: GEMM_EX can use Tensor Core to accelerate matrix multiply.
// https://docs.nvidia.com/cuda/cublas/index.html#cublassetmathmode
template
<
typename
...
ARGS
>
...
...
@@ -897,6 +961,30 @@ void Blas<platform::CUDADeviceContext>::BatchedGETRS(
});
}
template
<
>
template
<
typename
T
>
void
Blas
<
platform
::
CUDADeviceContext
>::
BatchedTRSM
(
CBLAS_SIDE
side
,
CBLAS_UPLO
uplo
,
CBLAS_TRANSPOSE
transA
,
CBLAS_DIAG
diag
,
int
M
,
int
N
,
T
alpha
,
const
T
**
A
,
int
lda
,
T
**
B
,
int
ldb
,
int
batch_size
)
const
{
// solve row major `op ( A ) X = α B` by taking it as `X' op ( A' ) = α B'`
// where ' stands for transpose
cublasSideMode_t
cuSide
=
(
side
==
CblasLeft
)
?
CUBLAS_SIDE_RIGHT
:
CUBLAS_SIDE_LEFT
;
cublasFillMode_t
cuUplo
=
(
uplo
==
CblasLower
)
?
CUBLAS_FILL_MODE_UPPER
:
CUBLAS_FILL_MODE_LOWER
;
// use CUBLAS_OP_C (conjugate transpose) for complex
cublasOperation_t
cuTransA
=
(
transA
==
CblasNoTrans
)
?
CUBLAS_OP_N
:
CUBLAS_OP_T
;
cublasDiagType_t
cuDiag
=
(
diag
==
CblasUnit
)
?
CUBLAS_DIAG_UNIT
:
CUBLAS_DIAG_NON_UNIT
;
context_
.
CublasCall
([
&
](
cublasHandle_t
handle
)
{
CUBlas
<
T
>::
TRSM_BATCH
(
handle
,
cuSide
,
cuUplo
,
cuTransA
,
cuDiag
,
N
,
M
,
&
alpha
,
A
,
lda
,
B
,
ldb
,
batch_size
);
});
}
}
// namespace math
}
// namespace operators
}
// namespace paddle
paddle/fluid/operators/math/blas_impl.h
浏览文件 @
3a81805b
...
...
@@ -434,6 +434,17 @@ struct CBlas<platform::complex<float>> {
a_
,
lda
,
b_
,
ldb
,
&
beta
,
c_
,
ldc
);
}
static
void
TRSM
(
CBLAS_LAYOUT
layout
,
CBLAS_SIDE
side
,
CBLAS_UPLO
uplo
,
CBLAS_TRANSPOSE
trans_a
,
CBLAS_DIAG
diag
,
int
M
,
int
N
,
paddle
::
platform
::
complex
<
float
>
alpha
,
const
paddle
::
platform
::
complex
<
float
>
*
A
,
int
lda
,
paddle
::
platform
::
complex
<
float
>
*
B
,
int
ldb
)
{
const
void
*
a_
=
(
const
void
*
)(
A
);
void
*
b_
=
static_cast
<
void
*>
(
B
);
platform
::
dynload
::
cblas_ctrsm
(
layout
,
side
,
uplo
,
trans_a
,
diag
,
M
,
N
,
&
alpha
,
a_
,
lda
,
b_
,
ldb
);
}
template
<
typename
...
ARGS
>
static
void
GEMM_BATCH
(
CBLAS_LAYOUT
layout
,
CBLAS_TRANSPOSE
*
trans_a
,
CBLAS_TRANSPOSE
*
trans_b
,
int
*
M
,
int
*
N
,
int
*
K
,
...
...
@@ -562,6 +573,17 @@ struct CBlas<platform::complex<double>> {
a_
,
lda
,
b_
,
ldb
,
&
beta
,
c_
,
ldc
);
}
static
void
TRSM
(
CBLAS_LAYOUT
layout
,
CBLAS_SIDE
side
,
CBLAS_UPLO
uplo
,
CBLAS_TRANSPOSE
trans_a
,
CBLAS_DIAG
diag
,
int
M
,
int
N
,
paddle
::
platform
::
complex
<
double
>
alpha
,
const
paddle
::
platform
::
complex
<
double
>
*
A
,
int
lda
,
paddle
::
platform
::
complex
<
double
>
*
B
,
int
ldb
)
{
const
void
*
a_
=
(
const
void
*
)(
A
);
void
*
b_
=
static_cast
<
void
*>
(
B
);
platform
::
dynload
::
cblas_ztrsm
(
layout
,
side
,
uplo
,
trans_a
,
diag
,
M
,
N
,
&
alpha
,
a_
,
lda
,
b_
,
ldb
);
}
template
<
typename
...
ARGS
>
static
void
GEMM_BATCH
(
CBLAS_LAYOUT
layout
,
CBLAS_TRANSPOSE
*
trans_a
,
CBLAS_TRANSPOSE
*
trans_b
,
int
*
M
,
int
*
N
,
int
*
K
,
...
...
@@ -682,6 +704,15 @@ struct CBlas<platform::complex<float>> {
cblas_cgemm
(
layout
,
TransA
,
TransB
,
M
,
N
,
K
,
&
alpha
,
A
,
lda
,
B
,
ldb
,
&
beta
,
C
,
ldc
);
}
static
void
TRSM
(
const
CBLAS_LAYOUT
layout
,
const
CBLAS_SIDE
side
,
const
CBLAS_UPLO
uplo
,
const
CBLAS_TRANSPOSE
transA
,
const
CBLAS_DIAG
diag
,
const
int
M
,
const
int
N
,
const
paddle
::
platform
::
complex
<
float
>
alpha
,
const
paddle
::
platform
::
complex
<
float
>
*
A
,
const
int
lda
,
paddle
::
platform
::
complex
<
double
>
*
B
,
const
int
ldb
)
{
cblas_ctrsm
(
layout
,
side
,
uplo
,
transA
,
diag
,
M
,
N
,
&
alpha
,
A
,
lda
,
B
,
ldb
);
}
};
template
<
>
...
...
@@ -720,6 +751,15 @@ struct CBlas<platform::complex<double>> {
cblas_zgemm
(
layout
,
TransA
,
TransB
,
M
,
N
,
K
,
&
alpha
,
A
,
lda
,
B
,
ldb
,
&
beta
,
C
,
ldc
);
}
static
void
TRSM
(
const
CBLAS_LAYOUT
layout
,
const
CBLAS_SIDE
side
,
const
CBLAS_UPLO
uplo
,
const
CBLAS_TRANSPOSE
transA
,
const
CBLAS_DIAG
diag
,
const
int
M
,
const
int
N
,
const
paddle
::
platform
::
complex
<
double
>
alpha
,
const
paddle
::
platform
::
complex
<
double
>
*
A
,
const
int
lda
,
paddle
::
platform
::
complex
<
double
>
*
B
,
const
int
ldb
)
{
cblas_ztrsm
(
layout
,
side
,
uplo
,
transA
,
diag
,
M
,
N
,
&
alpha
,
A
,
lda
,
B
,
ldb
);
}
};
#endif
...
...
paddle/fluid/operators/math/blas_impl.hip.h
浏览文件 @
3a81805b
...
...
@@ -90,6 +90,12 @@ struct CUBlas<float> {
PADDLE_THROW
(
platform
::
errors
::
Unimplemented
(
"cublasSmatinvBatched is not supported on HIP platform."
));
}
template
<
typename
...
ARGS
>
static
void
TRSM_BATCH
(
ARGS
...
args
)
{
PADDLE_THROW
(
platform
::
errors
::
Unimplemented
(
"cublasStrsmBatched is not supported on HIP platform."
));
}
};
template
<
>
...
...
@@ -153,6 +159,12 @@ struct CUBlas<double> {
PADDLE_THROW
(
platform
::
errors
::
Unimplemented
(
"cublasDmatinvBatched is not supported on HIP platform."
));
}
template
<
typename
...
ARGS
>
static
void
TRSM_BATCH
(
ARGS
...
args
)
{
PADDLE_THROW
(
platform
::
errors
::
Unimplemented
(
"cublasDtrsmBatched is not supported on HIP platform."
));
}
};
template
<
>
...
...
@@ -730,6 +742,32 @@ void Blas<platform::CUDADeviceContext>::BatchedGETRS(
batch_size
);
});
}
template
<
>
template
<
typename
T
>
void
Blas
<
platform
::
CUDADeviceContext
>::
BatchedTRSM
(
CBLAS_SIDE
side
,
CBLAS_UPLO
uplo
,
CBLAS_TRANSPOSE
transA
,
CBLAS_DIAG
diag
,
int
M
,
int
N
,
T
alpha
,
const
T
**
A
,
int
lda
,
T
**
B
,
int
ldb
,
int
batch_size
)
const
{
// solve row major `op ( A ) X = α B` by taking it as `X' op ( A' ) = α B'`
// where ' stands for transpose
rocblas_side
cuSide
=
(
side
==
CblasLeft
)
?
rocblas_side_right
:
rocblas_side_left
;
rocblas_fill
cuUplo
=
(
uplo
==
CblasLower
)
?
rocblas_fill_upper
:
rocblas_fill_lower
;
// use CUBLAS_OP_C (conjugate transpose) for complex
rocblas_operation
cuTransA
=
(
transA
==
CblasNoTrans
)
?
rocblas_operation_none
:
rocblas_operation_transpose
;
rocblas_diagonal
cuDiag
=
(
diag
==
CblasUnit
)
?
rocblas_diagonal_unit
:
rocblas_diagonal_non_unit
;
context_
.
CublasCall
([
&
](
rocblas_handle
handle
)
{
CUBlas
<
T
>::
TRSM_BATCH
(
handle
,
cuSide
,
cuUplo
,
cuTransA
,
cuDiag
,
N
,
M
,
&
alpha
,
A
,
lda
,
B
,
ldb
,
batch_size
);
});
}
}
// namespace math
}
// namespace operators
}
// namespace paddle
paddle/fluid/operators/math/matrix_solve.cc
浏览文件 @
3a81805b
...
...
@@ -34,6 +34,45 @@ class MatrixSolveFunctor<platform::CPUDeviceContext, T> {
template
class
MatrixSolveFunctor
<
platform
::
CPUDeviceContext
,
float
>;
template
class
MatrixSolveFunctor
<
platform
::
CPUDeviceContext
,
double
>;
template
<
typename
T
>
class
TriangularSolveFunctor
<
platform
::
CPUDeviceContext
,
T
>
{
public:
void
operator
()(
const
platform
::
CPUDeviceContext
&
context
,
const
framework
::
Tensor
*
a
,
framework
::
Tensor
*
b
,
bool
left
,
bool
upper
,
bool
transpose
,
bool
unitriangular
)
{
CBLAS_SIDE
side
=
left
?
CblasLeft
:
CblasRight
;
CBLAS_UPLO
uplo
=
upper
?
CblasUpper
:
CblasLower
;
CBLAS_TRANSPOSE
transA
=
transpose
?
CblasTrans
:
CblasNoTrans
;
CBLAS_DIAG
diag
=
unitriangular
?
CblasUnit
:
CblasNonUnit
;
const
T
*
a_data
=
a
->
data
<
T
>
();
T
*
b_data
=
b
->
mutable_data
<
T
>
(
context
.
GetPlace
());
int
a_dim_size
=
a
->
dims
().
size
();
int
b_dim_size
=
b
->
dims
().
size
();
int
M
=
static_cast
<
int
>
(
b
->
dims
()[
b_dim_size
-
2
]);
int
N
=
static_cast
<
int
>
(
b
->
dims
()[
b_dim_size
-
1
]);
auto
lda
=
left
?
std
::
max
(
1
,
M
)
:
std
::
max
(
1
,
N
);
auto
ldb
=
std
::
max
(
1
,
N
);
int
batch_size
=
1
;
auto
&
a_dim
=
a
->
dims
();
for
(
int
i
=
0
;
i
<
a_dim_size
-
2
;
i
++
)
{
batch_size
*=
a_dim
[
i
];
}
auto
blas
=
math
::
GetBlas
<
platform
::
CPUDeviceContext
,
T
>
(
context
);
for
(
int
i
=
0
;
i
<
batch_size
;
i
++
)
{
blas
.
TRSM
(
side
,
uplo
,
transA
,
diag
,
M
,
N
,
T
(
1
),
a_data
+
i
*
M
*
M
,
lda
,
b_data
+
i
*
N
*
M
,
ldb
);
}
}
};
template
class
TriangularSolveFunctor
<
platform
::
CPUDeviceContext
,
float
>;
template
class
TriangularSolveFunctor
<
platform
::
CPUDeviceContext
,
double
>;
}
// namespace math
}
// namespace operators
}
// namespace paddle
paddle/fluid/operators/math/matrix_solve.cu.cc
浏览文件 @
3a81805b
...
...
@@ -163,6 +163,68 @@ class MatrixSolveFunctor<platform::CUDADeviceContext, T> {
template
class
MatrixSolveFunctor
<
platform
::
CUDADeviceContext
,
float
>;
template
class
MatrixSolveFunctor
<
platform
::
CUDADeviceContext
,
double
>;
template
<
typename
T
>
class
TriangularSolveFunctor
<
platform
::
CUDADeviceContext
,
T
>
{
public:
void
operator
()(
const
platform
::
CUDADeviceContext
&
context
,
const
Tensor
*
a
,
Tensor
*
b
,
bool
left
,
bool
upper
,
bool
transpose
,
bool
unitriangular
)
{
CBLAS_SIDE
side
=
left
?
CblasLeft
:
CblasRight
;
CBLAS_UPLO
uplo
=
upper
?
CblasUpper
:
CblasLower
;
CBLAS_TRANSPOSE
transA
=
transpose
?
CblasTrans
:
CblasNoTrans
;
CBLAS_DIAG
diag
=
unitriangular
?
CblasUnit
:
CblasNonUnit
;
const
T
*
a_data
=
a
->
data
<
T
>
();
T
*
b_data
=
b
->
mutable_data
<
T
>
(
context
.
GetPlace
());
int
a_dim_size
=
a
->
dims
().
size
();
int
b_dim_size
=
b
->
dims
().
size
();
int
M
=
static_cast
<
int
>
(
b
->
dims
()[
b_dim_size
-
2
]);
int
N
=
static_cast
<
int
>
(
b
->
dims
()[
b_dim_size
-
1
]);
auto
lda
=
left
?
std
::
max
(
1
,
M
)
:
std
::
max
(
1
,
N
);
auto
ldb
=
std
::
max
(
1
,
N
);
int
batch_size
=
1
;
auto
&
a_dim
=
a
->
dims
();
for
(
int
i
=
0
;
i
<
a_dim_size
-
2
;
i
++
)
{
batch_size
*=
a_dim
[
i
];
}
auto
blas
=
math
::
GetBlas
<
platform
::
CUDADeviceContext
,
T
>
(
context
);
if
(
batch_size
<=
8
&&
M
>=
64
)
{
for
(
auto
i
=
0
;
i
<
batch_size
;
i
++
)
{
blas
.
TRSM
(
side
,
uplo
,
transA
,
diag
,
M
,
N
,
static_cast
<
T
>
(
1.0
),
a_data
+
i
*
M
*
M
,
lda
,
b_data
+
i
*
N
*
M
,
ldb
);
}
}
else
{
std
::
vector
<
const
T
*>
cpu_ptrs
(
batch_size
*
2
);
for
(
int
i
=
0
;
i
<
batch_size
;
++
i
)
{
cpu_ptrs
[
i
]
=
a_data
+
i
*
M
*
M
;
cpu_ptrs
[
i
+
batch_size
]
=
b_data
+
i
*
M
*
N
;
}
// Copy the addresses of A and tmp_b from host to device.
memory
::
allocation
::
AllocationPtr
tmp_gpu_ptrs_data
=
memory
::
Alloc
(
context
,
cpu_ptrs
.
size
()
*
sizeof
(
T
*
));
memory
::
Copy
(
BOOST_GET_CONST
(
platform
::
CUDAPlace
,
context
.
GetPlace
()),
tmp_gpu_ptrs_data
->
ptr
(),
platform
::
CPUPlace
(),
static_cast
<
void
*>
(
cpu_ptrs
.
data
()),
cpu_ptrs
.
size
()
*
sizeof
(
T
*
),
context
.
stream
());
const
T
**
gpu_a_ptrs
=
reinterpret_cast
<
const
T
**>
(
tmp_gpu_ptrs_data
->
ptr
());
T
**
gpu_b_ptrs
=
reinterpret_cast
<
T
**>
(
tmp_gpu_ptrs_data
->
ptr
())
+
batch_size
;
blas
.
BatchedTRSM
(
side
,
uplo
,
transA
,
diag
,
M
,
N
,
static_cast
<
T
>
(
1.0
),
gpu_a_ptrs
,
lda
,
gpu_b_ptrs
,
ldb
,
batch_size
);
}
}
};
template
class
TriangularSolveFunctor
<
platform
::
CUDADeviceContext
,
float
>;
template
class
TriangularSolveFunctor
<
platform
::
CUDADeviceContext
,
double
>;
}
// namespace math
}
// namespace operators
}
// namespace paddle
paddle/fluid/operators/math/matrix_solve.h
浏览文件 @
3a81805b
...
...
@@ -117,6 +117,14 @@ class MatrixSolveFunctor {
const
framework
::
Tensor
&
b
,
framework
::
Tensor
*
out
);
};
template
<
typename
DeviceContext
,
typename
T
>
class
TriangularSolveFunctor
{
public:
void
operator
()(
const
DeviceContext
&
context
,
const
framework
::
Tensor
*
a
,
framework
::
Tensor
*
b
,
bool
left
,
bool
upper
,
bool
transpose
,
bool
unitriangular
);
};
}
// namespace math
}
// namespace operators
}
// namespace paddle
paddle/fluid/operators/solve_op.h
浏览文件 @
3a81805b
...
...
@@ -49,7 +49,7 @@ struct IdentityFunctor {
};
template
<
typename
DeviceContext
,
typename
T
>
void
ReduceSumForSolve
Grad
(
const
Tensor
*
input
,
Tensor
*
output
,
void
ReduceSumForSolve
(
const
Tensor
*
input
,
Tensor
*
output
,
const
std
::
vector
<
int
>&
reduce_dims
,
bool
keep_dim
,
const
paddle
::
framework
::
ExecutionContext
&
ctx
)
{
#if defined(__NVCC__) || defined(__HIPCC__)
...
...
@@ -185,36 +185,6 @@ static std::vector<int64_t> get_broadcast_batch_portion(
return
batchPortion
;
}
// necessary check before expand operation
static
void
expand_check
(
const
Tensor
&
arg1
,
std
::
vector
<
int64_t
>
expand_shape
)
{
auto
rank
=
arg1
.
dims
().
size
();
PADDLE_ENFORCE_GE
(
rank
,
1
,
platform
::
errors
::
InvalidArgument
(
"The rank of the input 'X' for expand must be positive, "
"but the value received is %d."
,
rank
));
PADDLE_ENFORCE_LE
(
rank
,
MAX_RANK_SUPPORTED
,
platform
::
errors
::
InvalidArgument
(
"The rank of the input 'X' for expand must be less than "
"or equal to %d, but the value received is %d."
,
MAX_RANK_SUPPORTED
,
rank
));
auto
shape_size
=
static_cast
<
int
>
(
expand_shape
.
size
());
PADDLE_ENFORCE_GE
(
shape_size
,
rank
,
platform
::
errors
::
InvalidArgument
(
"The number (%d) of elements of 'shape' for expand must be "
"greater than or equal to the rank (%d) of the input 'X'."
,
shape_size
,
rank
));
PADDLE_ENFORCE_LE
(
shape_size
,
MAX_RANK_SUPPORTED
,
platform
::
errors
::
InvalidArgument
(
"The number (%d) of elements of 'shape' for expand must be "
"less than or equal to %d."
,
shape_size
,
MAX_RANK_SUPPORTED
));
}
// broadcast the batch dimensions of tensor x and tensor y.
static
inline
std
::
tuple
<
std
::
vector
<
int64_t
>
,
std
::
vector
<
int64_t
>>
get_broadcast_dims
(
const
Tensor
&
x
,
const
Tensor
&
y
)
{
...
...
@@ -246,15 +216,13 @@ get_broadcast_dims(const Tensor& x, const Tensor& y) {
}
template
<
int
Rank
,
typename
T
,
typename
DeviceContext
>
void
tensor_expand
(
const
framework
::
ExecutionContext
&
context
,
const
Tensor
&
arg1
,
Tensor
*
out0
,
std
::
vector
<
int64_t
>
expand_size
)
{
auto
in_dims
=
arg1
.
dims
();
auto
expand_shape
=
expand_size
;
auto
vec_in_dims
=
framework
::
vectorize
<
int
>
(
in_dims
);
void
expand_impl
(
const
DeviceContext
&
context
,
const
Tensor
&
in
,
Tensor
*
out
,
const
std
::
vector
<
int64_t
>&
expand_shape
)
{
auto
vec_in_dims
=
framework
::
vectorize
<
int
>
(
in
.
dims
());
auto
diff
=
expand_shape
.
size
()
-
vec_in_dims
.
size
();
vec_in_dims
.
insert
(
vec_in_dims
.
begin
(),
diff
,
1
);
std
::
vector
<
int
>
repeat_times
(
vec_in_dims
.
size
());
for
(
size_t
i
=
0
;
i
<
vec_in_dims
.
size
();
++
i
)
{
PADDLE_ENFORCE_NE
(
expand_shape
[
i
],
0
,
...
...
@@ -301,12 +269,11 @@ void tensor_expand(const framework::ExecutionContext& context,
out_dims
[
i
]
*=
repeat_times
[
i
];
}
out0
->
Resize
(
out_dims
);
auto
x
=
EigenTensor
<
T
,
Rank
>::
From
(
arg1
,
new_in_dims
);
out0
->
mutable_data
<
T
>
(
context
.
GetPlace
());
auto
y
=
EigenTensor
<
T
,
Rank
>::
From
(
*
out0
,
out_dims
);
auto
&
place
=
*
context
.
template
device_context
<
DeviceContext
>().
eigen_device
();
out
->
Resize
(
out_dims
);
out
->
mutable_data
<
T
>
(
context
.
GetPlace
());
auto
x
=
EigenTensor
<
T
,
Rank
>::
From
(
in
,
new_in_dims
);
auto
y
=
EigenTensor
<
T
,
Rank
>::
From
(
*
out
,
out_dims
);
auto
&
place
=
*
context
.
eigen_device
();
// use 32-bit index to speed up
bool
use_32bit_index
=
y
.
size
()
<
Eigen
::
NumTraits
<
int
>::
highest
();
if
(
use_32bit_index
)
{
...
...
@@ -318,6 +285,41 @@ void tensor_expand(const framework::ExecutionContext& context,
}
}
template
<
typename
T
,
typename
DeviceContext
>
void
TensorExpand
(
const
DeviceContext
&
context
,
const
Tensor
&
in
,
Tensor
*
out
,
const
std
::
vector
<
int64_t
>&
expand_shape
)
{
// necessary check before expand operation
PADDLE_ENFORCE_GE
(
expand_shape
.
size
(),
in
.
dims
().
size
(),
platform
::
errors
::
InvalidArgument
(
"The size of 'expand_shape' (%d) should >= the input "
"Tensor's rank (%d)."
,
expand_shape
.
size
(),
in
.
dims
().
size
()));
PADDLE_ENFORCE_LE
(
expand_shape
.
size
(),
MAX_RANK_SUPPORTED
,
platform
::
errors
::
InvalidArgument
(
"The size of 'expand_shape' (%d) should be <= %d"
,
expand_shape
.
size
(),
MAX_RANK_SUPPORTED
));
switch
(
expand_shape
.
size
())
{
case
1
:
expand_impl
<
1
,
T
,
DeviceContext
>
(
context
,
in
,
out
,
expand_shape
);
break
;
case
2
:
expand_impl
<
2
,
T
,
DeviceContext
>
(
context
,
in
,
out
,
expand_shape
);
break
;
case
3
:
expand_impl
<
3
,
T
,
DeviceContext
>
(
context
,
in
,
out
,
expand_shape
);
break
;
case
4
:
expand_impl
<
4
,
T
,
DeviceContext
>
(
context
,
in
,
out
,
expand_shape
);
break
;
case
5
:
expand_impl
<
5
,
T
,
DeviceContext
>
(
context
,
in
,
out
,
expand_shape
);
break
;
case
6
:
expand_impl
<
6
,
T
,
DeviceContext
>
(
context
,
in
,
out
,
expand_shape
);
break
;
}
}
template
<
typename
DeviceContext
,
typename
T
>
static
void
linalg_solve
(
const
framework
::
ExecutionContext
&
context
,
const
framework
::
Tensor
*
x
,
const
framework
::
Tensor
*
y
,
...
...
@@ -356,69 +358,11 @@ static void linalg_solve(const framework::ExecutionContext& context,
std
::
tie
(
x_broadcast_dims
,
y_broadcast_dims
)
=
get_broadcast_dims
(
tmp_x
,
tmp_y
);
expand_check
(
tmp_x
,
x_broadcast_dims
);
expand_check
(
tmp_y
,
y_broadcast_dims
);
Tensor
tmp_x_bc
;
Tensor
tmp_y_bc
;
auto
tmp_x_rank
=
tmp_x
.
dims
().
size
();
auto
tmp_y_rank
=
tmp_y
.
dims
().
size
();
TensorExpand
<
T
,
DeviceContext
>
(
dev_ctx
,
tmp_x
,
&
tmp_x_bc
,
x_broadcast_dims
);
auto
rank_0
=
std
::
max
(
tmp_x_rank
,
static_cast
<
int
>
(
x_broadcast_dims
.
size
()));
switch
(
rank_0
)
{
case
1
:
tensor_expand
<
1
,
T
,
DeviceContext
>
(
context
,
tmp_x
,
&
tmp_x_bc
,
x_broadcast_dims
);
break
;
case
2
:
tensor_expand
<
2
,
T
,
DeviceContext
>
(
context
,
tmp_x
,
&
tmp_x_bc
,
x_broadcast_dims
);
break
;
case
3
:
tensor_expand
<
3
,
T
,
DeviceContext
>
(
context
,
tmp_x
,
&
tmp_x_bc
,
x_broadcast_dims
);
break
;
case
4
:
tensor_expand
<
4
,
T
,
DeviceContext
>
(
context
,
tmp_x
,
&
tmp_x_bc
,
x_broadcast_dims
);
break
;
case
5
:
tensor_expand
<
5
,
T
,
DeviceContext
>
(
context
,
tmp_x
,
&
tmp_x_bc
,
x_broadcast_dims
);
break
;
case
6
:
tensor_expand
<
6
,
T
,
DeviceContext
>
(
context
,
tmp_x
,
&
tmp_x_bc
,
x_broadcast_dims
);
break
;
}
auto
rank_1
=
std
::
max
(
tmp_y_rank
,
static_cast
<
int
>
(
y_broadcast_dims
.
size
()));
switch
(
rank_1
)
{
case
1
:
tensor_expand
<
1
,
T
,
DeviceContext
>
(
context
,
tmp_y
,
&
tmp_y_bc
,
y_broadcast_dims
);
break
;
case
2
:
tensor_expand
<
2
,
T
,
DeviceContext
>
(
context
,
tmp_y
,
&
tmp_y_bc
,
y_broadcast_dims
);
break
;
case
3
:
tensor_expand
<
3
,
T
,
DeviceContext
>
(
context
,
tmp_y
,
&
tmp_y_bc
,
y_broadcast_dims
);
break
;
case
4
:
tensor_expand
<
4
,
T
,
DeviceContext
>
(
context
,
tmp_y
,
&
tmp_y_bc
,
y_broadcast_dims
);
break
;
case
5
:
tensor_expand
<
5
,
T
,
DeviceContext
>
(
context
,
tmp_y
,
&
tmp_y_bc
,
y_broadcast_dims
);
break
;
case
6
:
tensor_expand
<
6
,
T
,
DeviceContext
>
(
context
,
tmp_y
,
&
tmp_y_bc
,
y_broadcast_dims
);
break
;
}
Tensor
tmp_y_bc
;
TensorExpand
<
T
,
DeviceContext
>
(
dev_ctx
,
tmp_y
,
&
tmp_y_bc
,
y_broadcast_dims
);
auto
x_dim
=
x
->
dims
();
auto
y_dim
=
y
->
dims
();
...
...
@@ -658,7 +602,7 @@ class SolveGradKernel : public framework::OpKernel<T> {
if
(
dy_help
.
dims
().
size
()
!=
dy
->
dims
().
size
())
{
keep_dim
=
false
;
}
ReduceSumForSolve
Grad
<
DeviceContext
,
T
>
(
&
dy_help
,
dy
,
dy_reduce_dims
,
ReduceSumForSolve
<
DeviceContext
,
T
>
(
&
dy_help
,
dy
,
dy_reduce_dims
,
keep_dim
,
ctx
);
}
dy
->
Resize
(
y
->
dims
());
...
...
@@ -708,7 +652,7 @@ class SolveGradKernel : public framework::OpKernel<T> {
if
(
dx_help
.
dims
().
size
()
!=
dx
->
dims
().
size
())
{
keep_dim
=
false
;
}
ReduceSumForSolve
Grad
<
DeviceContext
,
T
>
(
&
dx_help
,
dx
,
dx_reduce_dims
,
ReduceSumForSolve
<
DeviceContext
,
T
>
(
&
dx_help
,
dx
,
dx_reduce_dims
,
keep_dim
,
ctx
);
}
dx
->
Resize
(
input
->
dims
());
...
...
paddle/fluid/operators/triangular_solve_op.cc
0 → 100644
浏览文件 @
3a81805b
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/triangular_solve_op.h"
#include "paddle/fluid/operators/solve_op.h"
namespace
paddle
{
namespace
operators
{
class
TriangularSolveOp
:
public
framework
::
OperatorWithKernel
{
public:
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
OP_INOUT_CHECK
(
ctx
->
HasInput
(
"X"
),
"Input"
,
"X"
,
"TriangularSolve"
);
OP_INOUT_CHECK
(
ctx
->
HasInput
(
"Y"
),
"Input"
,
"Y"
,
"TriangularSolve"
);
OP_INOUT_CHECK
(
ctx
->
HasOutput
(
"Out"
),
"Output"
,
"Out"
,
"TriangularSolve"
);
auto
x_dims
=
ctx
->
GetInputDim
(
"X"
);
auto
y_dims
=
ctx
->
GetInputDim
(
"Y"
);
auto
x_dims_n
=
x_dims
.
size
();
auto
y_dims_n
=
y_dims
.
size
();
PADDLE_ENFORCE_GE
(
x_dims_n
,
2
,
platform
::
errors
::
InvalidArgument
(
"The input tensor X's dimensions of TriangularSolveOp "
"should be >= 2. But received X's "
"dimensions = %d, X's shape = [%s]"
,
x_dims
.
size
(),
x_dims
));
PADDLE_ENFORCE_GE
(
y_dims_n
,
2
,
platform
::
errors
::
InvalidArgument
(
"The input tensor Y's dimensions of TriangularSolveOp "
"should be >=2. But received Y's "
"dimensions = %d, Y's shape = [%s]"
,
y_dims
.
size
(),
y_dims
));
PADDLE_ENFORCE_EQ
(
x_dims
[
x_dims_n
-
2
],
x_dims
[
x_dims_n
-
1
],
platform
::
errors
::
InvalidArgument
(
"The inner-most 2 dimensions of Input(X) all should "
"be square matrices "
"But received X's shape[-2] = %d and shape[-1] = %d."
,
x_dims
[
x_dims_n
-
2
],
x_dims
[
x_dims_n
-
1
]));
std
::
vector
<
int64_t
>
x_dims_vec
=
paddle
::
framework
::
vectorize
(
x_dims
);
std
::
vector
<
int64_t
>
y_dims_vec
=
paddle
::
framework
::
vectorize
(
y_dims
);
std
::
vector
<
int64_t
>
x_dims_vec_cut
(
x_dims_vec
.
begin
(),
x_dims_vec
.
end
()
-
2
);
std
::
vector
<
int64_t
>
y_dims_vec_cut
(
y_dims_vec
.
begin
(),
y_dims_vec
.
end
()
-
2
);
std
::
vector
<
int64_t
>
expand_batch_portion
=
get_broadcast_batch_portion
(
x_dims_vec_cut
,
y_dims_vec_cut
);
std
::
vector
<
int64_t
>
y_broadcast_dims
({
expand_batch_portion
});
y_broadcast_dims
.
insert
(
y_broadcast_dims
.
end
(),
{
y_dims_vec
[
y_dims_n
-
2
],
y_dims_vec
[
y_dims_n
-
1
]});
// dim of 'Out' is the same with 'Y' after broadcast
ctx
->
SetOutputDim
(
"Out"
,
framework
::
make_ddim
(
y_broadcast_dims
));
ctx
->
ShareLoD
(
"X"
,
/*->*/
"Out"
);
}
framework
::
OpKernelType
GetExpectedKernelType
(
const
framework
::
ExecutionContext
&
ctx
)
const
{
return
framework
::
OpKernelType
(
OperatorWithKernel
::
IndicateVarDataType
(
ctx
,
"X"
),
ctx
.
GetPlace
());
}
};
class
TriangularSolveOpMaker
:
public
framework
::
OpProtoAndCheckerMaker
{
public:
void
Make
()
override
{
AddInput
(
"X"
,
"(Tensor), The first input tensor of triangular solve op, which "
"is the triangular coefficient matrix."
);
AddInput
(
"Y"
,
"(Tensor), The second input tensor of triangular solve op, which "
"is multiple right-hand."
);
AddOutput
(
"Out"
,
"(Tensor), The solution tensor of triangular solve op."
);
AddAttr
<
bool
>
(
"upper"
,
"whether to solve the upper-triangular or the "
"lower-triangular system of equations"
)
.
SetDefault
(
true
);
AddAttr
<
bool
>
(
"transpose"
,
"whether X should be transposed firstly."
)
.
SetDefault
(
false
);
AddAttr
<
bool
>
(
"unitriangular"
,
"whether X is unit triangular."
)
.
SetDefault
(
false
);
AddComment
(
R"DOC(
Triangular Solve Operator.
This operator is used to computes the solution of equations with a triangular coefficient matrix.
The equation is:
$$Out = X^-1 * Y$$
)DOC"
);
}
};
class
TriangularSolveOpInferVarType
:
public
framework
::
PassInDtypeAndVarTypeToOutput
{
protected:
std
::
unordered_map
<
std
::
string
,
std
::
string
>&
GetInputOutputWithSameType
()
const
override
{
static
std
::
unordered_map
<
std
::
string
,
std
::
string
>
m
{{
"X"
,
/*->*/
"Out"
}};
return
m
;
}
};
class
TriangularSolveGradOp
:
public
framework
::
OperatorWithKernel
{
public:
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
OP_INOUT_CHECK
(
ctx
->
HasInput
(
"X"
),
"Input"
,
"X"
,
"triangular_solve"
);
OP_INOUT_CHECK
(
ctx
->
HasInput
(
"Y"
),
"Input"
,
"Y"
,
"triangular_solve"
);
OP_INOUT_CHECK
(
ctx
->
HasInput
(
"Out"
),
"Input"
,
"Out"
,
"triangular_solve"
);
OP_INOUT_CHECK
(
ctx
->
HasInput
(
framework
::
GradVarName
(
"Out"
)),
"Input"
,
"Out@GRAD"
,
"triangular_solve"
);
auto
x_dims
=
ctx
->
GetInputDim
(
"X"
);
auto
y_dims
=
ctx
->
GetInputDim
(
"Y"
);
auto
x_grad_name
=
framework
::
GradVarName
(
"X"
);
auto
y_grad_name
=
framework
::
GradVarName
(
"Y"
);
if
(
ctx
->
HasOutput
(
x_grad_name
))
{
ctx
->
SetOutputDim
(
x_grad_name
,
x_dims
);
}
if
(
ctx
->
HasOutput
(
y_grad_name
))
{
ctx
->
SetOutputDim
(
y_grad_name
,
y_dims
);
}
}
};
template
<
typename
T
>
class
TriangularSolveOpGradMaker
:
public
framework
::
SingleGradOpMaker
<
T
>
{
public:
using
framework
::
SingleGradOpMaker
<
T
>::
SingleGradOpMaker
;
protected:
void
Apply
(
GradOpPtr
<
T
>
retv
)
const
override
{
retv
->
SetType
(
"triangular_solve_grad"
);
retv
->
SetInput
(
"X"
,
this
->
Input
(
"X"
));
retv
->
SetInput
(
"Y"
,
this
->
Input
(
"Y"
));
retv
->
SetInput
(
"Out"
,
this
->
Output
(
"Out"
));
retv
->
SetInput
(
framework
::
GradVarName
(
"Out"
),
this
->
OutputGrad
(
"Out"
));
retv
->
SetOutput
(
framework
::
GradVarName
(
"X"
),
this
->
InputGrad
(
"X"
));
retv
->
SetOutput
(
framework
::
GradVarName
(
"Y"
),
this
->
InputGrad
(
"Y"
));
retv
->
SetAttrMap
(
this
->
Attrs
());
}
};
}
// namespace operators
}
// namespace paddle
namespace
ops
=
paddle
::
operators
;
REGISTER_OPERATOR
(
triangular_solve
,
ops
::
TriangularSolveOp
,
ops
::
TriangularSolveOpMaker
,
ops
::
TriangularSolveOpInferVarType
,
ops
::
TriangularSolveOpGradMaker
<
paddle
::
framework
::
OpDesc
>
,
ops
::
TriangularSolveOpGradMaker
<
paddle
::
imperative
::
OpBase
>
);
REGISTER_OPERATOR
(
triangular_solve_grad
,
ops
::
TriangularSolveGradOp
);
REGISTER_OP_CPU_KERNEL
(
triangular_solve
,
ops
::
TriangularSolveKernel
<
paddle
::
platform
::
CPUDeviceContext
,
float
>
,
ops
::
TriangularSolveKernel
<
paddle
::
platform
::
CPUDeviceContext
,
double
>
);
REGISTER_OP_CPU_KERNEL
(
triangular_solve_grad
,
ops
::
TriangularSolveGradKernel
<
paddle
::
platform
::
CPUDeviceContext
,
float
>
,
ops
::
TriangularSolveGradKernel
<
paddle
::
platform
::
CPUDeviceContext
,
double
>
);
paddle/fluid/operators/triangular_solve_op.cu
0 → 100644
浏览文件 @
3a81805b
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/reduce_ops/reduce_functor_op.h"
#include "paddle/fluid/operators/reduce_ops/reduce_op.h"
#include "paddle/fluid/operators/triangular_solve_op.h"
namespace
paddle
{
namespace
operators
{
template
<
typename
T
>
struct
MatrixReduceSumFunctor
<
platform
::
CUDADeviceContext
,
T
>
{
void
operator
()(
const
Tensor
&
in
,
Tensor
*
out
,
const
framework
::
ExecutionContext
&
ctx
)
{
// For example: in's dim = [5, 3, 2, 7, 3] ; out's dim = [3, 1, 7, 3]
// out_reduce_dim should be [0, 2]
const
std
::
vector
<
std
::
int64_t
>
in_dims
=
framework
::
vectorize
(
in
.
dims
());
auto
in_size
=
in_dims
.
size
();
const
std
::
vector
<
std
::
int64_t
>
out_dims
=
framework
::
vectorize
(
out
->
dims
());
auto
out_size
=
out_dims
.
size
();
std
::
vector
<
std
::
int64_t
>
out_bst_dims
(
in_size
);
std
::
fill
(
out_bst_dims
.
data
(),
out_bst_dims
.
data
()
+
in_size
-
out_size
,
1
);
std
::
copy
(
out_dims
.
data
(),
out_dims
.
data
()
+
out_size
,
out_bst_dims
.
data
()
+
in_size
-
out_size
);
std
::
vector
<
int
>
out_reduce_dims
;
for
(
size_t
idx
=
0
;
idx
<=
in_size
-
3
;
idx
++
)
{
if
(
in_dims
[
idx
]
!=
1
&&
out_bst_dims
[
idx
]
==
1
)
{
out_reduce_dims
.
push_back
(
idx
);
}
}
gpuStream_t
stream
=
ctx
.
cuda_device_context
().
stream
();
TensorReduceFunctorImpl
<
T
,
T
,
CustomSum
>
(
in
,
out
,
out_reduce_dims
,
stream
);
}
};
}
// namespace operators
}
// namespace paddle
namespace
ops
=
paddle
::
operators
;
REGISTER_OP_CUDA_KERNEL
(
triangular_solve
,
ops
::
TriangularSolveKernel
<
paddle
::
platform
::
CUDADeviceContext
,
float
>
,
ops
::
TriangularSolveKernel
<
paddle
::
platform
::
CUDADeviceContext
,
double
>
);
REGISTER_OP_CUDA_KERNEL
(
triangular_solve_grad
,
ops
::
TriangularSolveGradKernel
<
paddle
::
platform
::
CUDADeviceContext
,
float
>
,
ops
::
TriangularSolveGradKernel
<
paddle
::
platform
::
CUDADeviceContext
,
double
>
);
paddle/fluid/operators/triangular_solve_op.h
0 → 100644
浏览文件 @
3a81805b
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "glog/logging.h"
#include "paddle/fluid/framework/ddim.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/framework/tensor_util.h"
#include "paddle/fluid/operators/math/blas.h"
#include "paddle/fluid/operators/math/complex_functors.h"
#include "paddle/fluid/operators/reduce_ops/reduce_op.h"
#include "paddle/fluid/operators/solve_op.h"
#include "paddle/fluid/operators/tril_triu_op.h"
namespace
paddle
{
namespace
operators
{
using
Tensor
=
framework
::
Tensor
;
template
<
typename
DeviceContext
,
typename
T
>
static
void
triangular_solve
(
const
DeviceContext
&
context
,
const
Tensor
&
x
,
const
Tensor
&
y
,
Tensor
*
out
,
bool
upper
,
bool
transpose
,
bool
unitriangular
)
{
// Tensor broadcast use eigen
std
::
vector
<
int64_t
>
x_bst_dims_vec
;
std
::
vector
<
int64_t
>
y_bst_dims_vec
;
std
::
tie
(
x_bst_dims_vec
,
y_bst_dims_vec
)
=
get_broadcast_dims
(
x
,
y
);
Tensor
x_bst
(
x
.
type
());
TensorExpand
<
T
,
DeviceContext
>
(
context
,
x
,
&
x_bst
,
x_bst_dims_vec
);
Tensor
y_bst
(
y
.
type
());
TensorExpand
<
T
,
DeviceContext
>
(
context
,
y
,
&
y_bst
,
y_bst_dims_vec
);
// TriangularSolveFunctor performs calculations in-place
// x_clone should be a copy of 'x' after broadcast
// out should be a copy of 'y' after broadcast
Tensor
x_clone
(
x
.
type
());
x_clone
.
Resize
(
framework
::
make_ddim
(
x_bst_dims_vec
));
x_clone
.
mutable_data
<
T
>
(
context
.
GetPlace
());
framework
::
TensorCopy
(
x_bst
,
context
.
GetPlace
(),
context
,
&
x_clone
);
out
->
Resize
(
framework
::
make_ddim
(
y_bst_dims_vec
));
out
->
mutable_data
<
T
>
(
context
.
GetPlace
());
framework
::
TensorCopy
(
y_bst
,
context
.
GetPlace
(),
context
,
out
);
math
::
TriangularSolveFunctor
<
DeviceContext
,
T
>
functor
;
functor
(
context
,
&
x_clone
,
out
,
/*left=*/
true
,
upper
,
transpose
,
unitriangular
);
}
template
<
typename
DeviceContext
,
typename
T
>
class
MatrixReduceSumFunctor
{
public:
void
operator
()(
const
Tensor
&
input
,
Tensor
*
output
,
const
framework
::
ExecutionContext
&
ctx
);
};
template
<
typename
T
>
class
MatrixReduceSumFunctor
<
platform
::
CPUDeviceContext
,
T
>
{
public:
void
operator
()(
const
Tensor
&
in
,
Tensor
*
out
,
const
framework
::
ExecutionContext
&
ctx
)
{
// For example: in's dim = [5, 3, 2, 7, 3] ; out's dim = [3, 1, 7, 3]
// out_reduce_dim should be [0, 2]
const
std
::
vector
<
std
::
int64_t
>
in_dims
=
framework
::
vectorize
(
in
.
dims
());
auto
in_size
=
in_dims
.
size
();
const
std
::
vector
<
std
::
int64_t
>
out_dims
=
framework
::
vectorize
(
out
->
dims
());
auto
out_size
=
out_dims
.
size
();
std
::
vector
<
std
::
int64_t
>
out_bst_dims
(
in_size
);
std
::
fill
(
out_bst_dims
.
data
(),
out_bst_dims
.
data
()
+
in_size
-
out_size
,
1
);
std
::
copy
(
out_dims
.
data
(),
out_dims
.
data
()
+
out_size
,
out_bst_dims
.
data
()
+
in_size
-
out_size
);
out
->
Resize
(
framework
::
make_ddim
(
out_bst_dims
));
std
::
vector
<
int
>
out_reduce_dims
;
for
(
size_t
idx
=
0
;
idx
<=
in_size
-
3
;
idx
++
)
{
if
(
in_dims
[
idx
]
!=
1
&&
out_bst_dims
[
idx
]
==
1
)
{
out_reduce_dims
.
push_back
(
idx
);
}
}
ReduceKernelFunctor
<
platform
::
CPUDeviceContext
,
T
,
SumFunctor
>
(
&
in
,
out
,
out_reduce_dims
,
true
,
false
,
ctx
)
.
template
apply
<
T
>();
out
->
Resize
(
framework
::
make_ddim
(
out_dims
));
}
};
template
<
typename
DeviceContext
,
typename
T
>
class
TriangularSolveKernel
:
public
framework
::
OpKernel
<
T
>
{
public:
void
Compute
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
const
auto
*
x
=
ctx
.
Input
<
framework
::
Tensor
>
(
"X"
);
const
auto
*
y
=
ctx
.
Input
<
framework
::
Tensor
>
(
"Y"
);
auto
*
out
=
ctx
.
Output
<
framework
::
Tensor
>
(
"Out"
);
bool
upper
=
ctx
.
template
Attr
<
bool
>(
"upper"
);
bool
transpose
=
ctx
.
template
Attr
<
bool
>(
"transpose"
);
bool
unitriangular
=
ctx
.
template
Attr
<
bool
>(
"unitriangular"
);
const
auto
&
dev_ctx
=
ctx
.
template
device_context
<
DeviceContext
>();
triangular_solve
<
DeviceContext
,
T
>
(
dev_ctx
,
*
x
,
*
y
,
out
,
upper
,
transpose
,
unitriangular
);
}
};
template
<
typename
DeviceContext
,
typename
T
>
class
TriangularSolveGradKernel
:
public
framework
::
OpKernel
<
T
>
{
public:
void
Compute
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
const
auto
*
x
=
ctx
.
Input
<
framework
::
Tensor
>
(
"X"
);
const
auto
*
y
=
ctx
.
Input
<
framework
::
Tensor
>
(
"Y"
);
const
auto
*
out
=
ctx
.
Input
<
framework
::
Tensor
>
(
"Out"
);
const
auto
*
dout
=
ctx
.
Input
<
framework
::
Tensor
>
(
framework
::
GradVarName
(
"Out"
));
auto
*
dx
=
ctx
.
Output
<
framework
::
Tensor
>
(
framework
::
GradVarName
(
"X"
));
auto
*
dy
=
ctx
.
Output
<
framework
::
Tensor
>
(
framework
::
GradVarName
(
"Y"
));
bool
upper
=
ctx
.
template
Attr
<
bool
>(
"upper"
);
bool
transpose
=
ctx
.
template
Attr
<
bool
>(
"transpose"
);
bool
unitriangular
=
ctx
.
template
Attr
<
bool
>(
"unitriangular"
);
auto
&
dev_ctx
=
ctx
.
template
device_context
<
DeviceContext
>();
std
::
vector
<
int64_t
>
x_bst_dims_vec
;
std
::
vector
<
int64_t
>
y_bst_dims_vec
;
std
::
tie
(
x_bst_dims_vec
,
y_bst_dims_vec
)
=
get_broadcast_dims
(
*
x
,
*
y
);
Tensor
dy_bst
(
y
->
type
());
if
(
dy
)
{
dy
->
mutable_data
<
T
>
(
y
->
dims
(),
dev_ctx
.
GetPlace
());
dy_bst
.
Resize
(
framework
::
make_ddim
(
y_bst_dims_vec
));
dy_bst
.
mutable_data
<
T
>
(
dev_ctx
.
GetPlace
());
// calculate x's conjugate for complex
Tensor
x_conj
(
x
->
type
());
platform
::
ForRange
<
DeviceContext
>
x_for_range
(
dev_ctx
,
x
->
numel
());
math
::
ConjFunctor
<
T
>
x_functor
(
x
->
data
<
T
>
(),
x
->
numel
(),
x_conj
.
mutable_data
<
T
>
(
x
->
dims
(),
dev_ctx
.
GetPlace
()));
x_for_range
(
x_functor
);
// reuse forward to get dy_bst, and the result has been broadcated.
triangular_solve
<
DeviceContext
,
T
>
(
dev_ctx
,
x_conj
,
*
dout
,
&
dy_bst
,
upper
,
!
transpose
,
unitriangular
);
if
(
dy_bst
.
dims
()
==
dy
->
dims
())
{
framework
::
TensorCopy
(
dy_bst
,
dev_ctx
.
GetPlace
(),
dev_ctx
,
dy
);
}
else
{
MatrixReduceSumFunctor
<
DeviceContext
,
T
>
functor
;
functor
(
dy_bst
,
dy
,
ctx
);
dy
->
Resize
(
y
->
dims
());
}
}
Tensor
dx_bst
(
x
->
type
());
if
(
dx
)
{
dx
->
mutable_data
<
T
>
(
x
->
dims
(),
dev_ctx
.
GetPlace
());
dx_bst
.
Resize
(
framework
::
make_ddim
(
x_bst_dims_vec
));
dx_bst
.
mutable_data
<
T
>
(
dev_ctx
.
GetPlace
());
// calculate out's conjugate for complex
Tensor
out_conj
(
out
->
type
());
platform
::
ForRange
<
DeviceContext
>
out_for_range
(
dev_ctx
,
out
->
numel
());
math
::
ConjFunctor
<
T
>
out_functor
(
out
->
data
<
T
>
(),
out
->
numel
(),
out_conj
.
mutable_data
<
T
>
(
out
->
dims
(),
dev_ctx
.
GetPlace
()));
out_for_range
(
out_functor
);
auto
blas
=
math
::
GetBlas
<
DeviceContext
,
T
>
(
ctx
);
if
(
transpose
)
{
auto
mat_dim_a
=
math
::
CreateMatrixDescriptor
(
out_conj
.
dims
(),
0
,
false
);
auto
mat_dim_b
=
math
::
CreateMatrixDescriptor
(
dy_bst
.
dims
(),
0
,
true
);
blas
.
MatMul
(
out_conj
,
mat_dim_a
,
dy_bst
,
mat_dim_b
,
static_cast
<
T
>
(
-
1
),
&
dx_bst
,
static_cast
<
T
>
(
0
));
}
else
{
auto
mat_dim_a
=
math
::
CreateMatrixDescriptor
(
dy_bst
.
dims
(),
0
,
false
);
auto
mat_dim_b
=
math
::
CreateMatrixDescriptor
(
out_conj
.
dims
(),
0
,
true
);
blas
.
MatMul
(
dy_bst
,
mat_dim_a
,
out_conj
,
mat_dim_b
,
static_cast
<
T
>
(
-
1
),
&
dx_bst
,
static_cast
<
T
>
(
0
));
}
Tensor
dx_bst_upper
(
x
->
type
());
// get upper or lower triangular
dx_bst_upper
.
Resize
(
dx_bst
.
dims
());
dx_bst_upper
.
mutable_data
<
T
>
(
dev_ctx
.
GetPlace
());
const
auto
&
dims
=
dx_bst
.
dims
();
const
auto
H
=
dims
[
dims
.
size
()
-
2
];
const
auto
W
=
dims
[
dims
.
size
()
-
1
];
platform
::
ForRange
<
DeviceContext
>
x_for_range
(
dev_ctx
,
dx_bst
.
numel
());
TrilTriuCompute
<
T
>
tril_triu_computer
(
dx_bst
.
data
<
T
>
(),
unitriangular
,
!
upper
,
H
,
W
,
dx_bst_upper
.
data
<
T
>
());
x_for_range
(
tril_triu_computer
);
if
(
dx_bst_upper
.
dims
()
==
dx
->
dims
())
{
framework
::
TensorCopy
(
dx_bst_upper
,
dev_ctx
.
GetPlace
(),
dev_ctx
,
dx
);
}
else
{
MatrixReduceSumFunctor
<
DeviceContext
,
T
>
functor
;
functor
(
dx_bst_upper
,
dx
,
ctx
);
dx
->
Resize
(
x
->
dims
());
}
}
}
};
}
// namespace operators
}
// namespace paddle
paddle/fluid/platform/dynload/cublas.h
浏览文件 @
3a81805b
...
...
@@ -75,6 +75,8 @@ extern void *cublas_dso_handle;
__macro(cublasDgeam); \
__macro(cublasStrsm_v2); \
__macro(cublasDtrsm_v2); \
__macro(cublasCtrsm_v2); \
__macro(cublasZtrsm_v2); \
__macro(cublasCreate_v2); \
__macro(cublasDestroy_v2); \
__macro(cublasSetStream_v2); \
...
...
@@ -84,6 +86,10 @@ extern void *cublas_dso_handle;
__macro(cublasDgemmBatched); \
__macro(cublasCgemmBatched); \
__macro(cublasZgemmBatched); \
__macro(cublasStrsmBatched); \
__macro(cublasDtrsmBatched); \
__macro(cublasCtrsmBatched); \
__macro(cublasZtrsmBatched); \
__macro(cublasSgetrfBatched); \
__macro(cublasSgetriBatched); \
__macro(cublasDgetrfBatched); \
...
...
paddle/fluid/platform/dynload/mklml.h
浏览文件 @
3a81805b
...
...
@@ -25,7 +25,7 @@ namespace platform {
namespace
dynload
{
extern
std
::
once_flag
mklml_dso_flag
;
extern
void
*
mklml_dso_handle
;
extern
void
*
mklml_dso_handle
;
/**
* The following macro definition can generate structs
...
...
@@ -40,7 +40,7 @@ extern void* mklml_dso_handle;
std::call_once(mklml_dso_flag, []() { \
mklml_dso_handle = paddle::platform::dynload::GetMKLMLDsoHandle(); \
}); \
static void
*
p_##_name = dlsym(mklml_dso_handle, #__name); \
static void
*
p_##_name = dlsym(mklml_dso_handle, #__name); \
return reinterpret_cast<mklmlFunc>(p_##_name)(args...); \
} \
}; \
...
...
@@ -67,6 +67,8 @@ extern void* mklml_dso_handle;
__macro(cblas_zgemv); \
__macro(cblas_strsm); \
__macro(cblas_dtrsm); \
__macro(cblas_ctrsm); \
__macro(cblas_ztrsm); \
__macro(cblas_sgemm_alloc); \
__macro(cblas_dgemm_alloc); \
__macro(cblas_sgemm_pack); \
...
...
python/paddle/fluid/tests/unittests/test_triangular_solve_op.py
0 → 100644
浏览文件 @
3a81805b
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.w
from
__future__
import
print_function
import
unittest
import
numpy
as
np
import
sys
sys
.
path
.
append
(
".."
)
import
paddle
from
op_test
import
OpTest
import
paddle.fluid
as
fluid
from
paddle.fluid
import
Program
,
program_guard
,
core
paddle
.
enable_static
()
# 2D + 2D , test 'upper'
class
TestTriangularSolveOp
(
OpTest
):
"""
case 1
"""
def
config
(
self
):
self
.
x_shape
=
[
12
,
12
]
self
.
y_shape
=
[
12
,
10
]
self
.
upper
=
True
self
.
transpose
=
False
self
.
unitriangular
=
False
self
.
dtype
=
"float64"
def
set_output
(
self
):
self
.
output
=
np
.
linalg
.
solve
(
np
.
triu
(
self
.
inputs
[
'X'
]),
self
.
inputs
[
'Y'
])
def
setUp
(
self
):
self
.
op_type
=
"triangular_solve"
self
.
config
()
self
.
inputs
=
{
'X'
:
np
.
random
.
random
(
self
.
x_shape
).
astype
(
self
.
dtype
),
'Y'
:
np
.
random
.
random
(
self
.
y_shape
).
astype
(
self
.
dtype
)
}
self
.
attrs
=
{
'upper'
:
self
.
upper
,
'transpose'
:
self
.
transpose
,
'unitriangular'
:
self
.
unitriangular
,
}
self
.
set_output
()
self
.
outputs
=
{
'Out'
:
self
.
output
}
def
test_check_output
(
self
):
self
.
check_output
()
def
test_check_grad_normal
(
self
):
self
.
check_grad
([
'X'
,
'Y'
],
'Out'
)
# 2D(broadcast) + 3D, test 'transpose'
class
TestTriangularSolveOp2
(
TestTriangularSolveOp
):
"""
case 2
"""
def
config
(
self
):
self
.
x_shape
=
[
10
,
10
]
self
.
y_shape
=
[
3
,
10
,
8
]
self
.
upper
=
False
self
.
transpose
=
True
self
.
unitriangular
=
False
self
.
dtype
=
"float64"
def
set_output
(
self
):
x
=
np
.
tril
(
self
.
inputs
[
'X'
]).
transpose
(
1
,
0
)
y
=
self
.
inputs
[
'Y'
]
self
.
output
=
np
.
linalg
.
solve
(
x
,
y
)
# 3D(broadcast) + 3D
class
TestTriangularSolveOp3
(
TestTriangularSolveOp
):
"""
case 3
"""
def
config
(
self
):
self
.
x_shape
=
[
1
,
10
,
10
]
self
.
y_shape
=
[
6
,
10
,
12
]
self
.
upper
=
False
self
.
transpose
=
False
self
.
unitriangular
=
False
self
.
dtype
=
"float64"
def
set_output
(
self
):
x
=
np
.
tril
(
self
.
inputs
[
'X'
])
y
=
self
.
inputs
[
'Y'
]
self
.
output
=
np
.
linalg
.
solve
(
x
,
y
)
# 3D + 3D(broadcast), test 'transpose'
class
TestTriangularSolveOp4
(
TestTriangularSolveOp
):
"""
case 4
"""
def
config
(
self
):
self
.
x_shape
=
[
3
,
10
,
10
]
self
.
y_shape
=
[
1
,
10
,
12
]
self
.
upper
=
True
self
.
transpose
=
True
self
.
unitriangular
=
False
self
.
dtype
=
"float64"
def
set_output
(
self
):
x
=
np
.
triu
(
self
.
inputs
[
'X'
]).
transpose
(
0
,
2
,
1
)
y
=
self
.
inputs
[
'Y'
]
self
.
output
=
np
.
linalg
.
solve
(
x
,
y
)
# 2D + 2D , test 'unitriangular' specially
class
TestTriangularSolveOp5
(
TestTriangularSolveOp
):
"""
case 5
"""
def
config
(
self
):
self
.
x_shape
=
[
10
,
10
]
self
.
y_shape
=
[
10
,
10
]
self
.
upper
=
True
self
.
transpose
=
False
self
.
unitriangular
=
True
self
.
dtype
=
"float64"
def
set_output
(
self
):
x
=
np
.
triu
(
self
.
inputs
[
'X'
])
np
.
fill_diagonal
(
x
,
1.
)
y
=
self
.
inputs
[
'Y'
]
self
.
output
=
np
.
linalg
.
solve
(
x
,
y
)
def
test_check_grad_normal
(
self
):
x
=
np
.
triu
(
self
.
inputs
[
'X'
])
np
.
fill_diagonal
(
x
,
1.
)
grad_out
=
np
.
ones
([
10
,
10
]).
astype
(
'float64'
)
grad_y
=
np
.
linalg
.
solve
(
x
.
transpose
(
1
,
0
),
grad_out
)
grad_x
=
-
np
.
matmul
(
grad_y
,
self
.
output
.
transpose
(
1
,
0
))
grad_x
=
np
.
triu
(
grad_x
)
np
.
fill_diagonal
(
grad_x
,
0.
)
self
.
check_grad
(
[
'X'
,
'Y'
],
'Out'
,
user_defined_grads
=
[
grad_x
,
grad_y
],
user_defined_grad_outputs
=
[
grad_out
])
# 4D(broadcast) + 4D(broadcast)
class
TestTriangularSolveOp6
(
TestTriangularSolveOp
):
"""
case 6
"""
def
config
(
self
):
self
.
x_shape
=
[
1
,
3
,
10
,
10
]
self
.
y_shape
=
[
2
,
1
,
10
,
5
]
self
.
upper
=
False
self
.
transpose
=
False
self
.
unitriangular
=
False
self
.
dtype
=
"float64"
def
set_output
(
self
):
x
=
np
.
tril
(
self
.
inputs
[
'X'
])
y
=
self
.
inputs
[
'Y'
]
self
.
output
=
np
.
linalg
.
solve
(
x
,
y
)
# 3D(broadcast) + 4D(broadcast), test 'upper'
class
TestTriangularSolveOp7
(
TestTriangularSolveOp
):
"""
case 7
"""
def
config
(
self
):
self
.
x_shape
=
[
2
,
10
,
10
]
self
.
y_shape
=
[
5
,
1
,
10
,
2
]
self
.
upper
=
True
self
.
transpose
=
True
self
.
unitriangular
=
False
self
.
dtype
=
"float64"
def
set_output
(
self
):
x
=
np
.
triu
(
self
.
inputs
[
'X'
]).
transpose
(
0
,
2
,
1
)
y
=
self
.
inputs
[
'Y'
]
self
.
output
=
np
.
linalg
.
solve
(
x
,
y
)
# 3D(broadcast) + 5D
class
TestTriangularSolveOp8
(
TestTriangularSolveOp
):
"""
case 8
"""
def
config
(
self
):
self
.
x_shape
=
[
12
,
3
,
3
]
self
.
y_shape
=
[
2
,
3
,
12
,
3
,
2
]
self
.
upper
=
False
self
.
transpose
=
False
self
.
unitriangular
=
False
self
.
dtype
=
"float64"
def
set_output
(
self
):
x
=
np
.
tril
(
self
.
inputs
[
'X'
])
y
=
self
.
inputs
[
'Y'
]
self
.
output
=
np
.
linalg
.
solve
(
x
,
y
)
# 5D + 4D(broadcast)
class
TestTriangularSolveOp9
(
TestTriangularSolveOp
):
"""
case 9
"""
def
config
(
self
):
self
.
x_shape
=
[
2
,
4
,
2
,
3
,
3
]
self
.
y_shape
=
[
4
,
1
,
3
,
10
]
self
.
upper
=
False
self
.
transpose
=
False
self
.
unitriangular
=
False
self
.
dtype
=
"float64"
def
set_output
(
self
):
x
=
np
.
tril
(
self
.
inputs
[
'X'
])
y
=
self
.
inputs
[
'Y'
]
self
.
output
=
np
.
matmul
(
np
.
linalg
.
inv
(
x
),
y
)
class
TestTriangularSolveAPI
(
unittest
.
TestCase
):
def
setUp
(
self
):
np
.
random
.
seed
(
2021
)
self
.
place
=
[
paddle
.
CPUPlace
()]
self
.
dtype
=
"float64"
if
core
.
is_compiled_with_cuda
():
self
.
place
.
append
(
paddle
.
CUDAPlace
(
0
))
def
check_static_result
(
self
,
place
):
with
fluid
.
program_guard
(
fluid
.
Program
(),
fluid
.
Program
()):
x
=
fluid
.
data
(
name
=
"x"
,
shape
=
[
3
,
3
],
dtype
=
self
.
dtype
)
y
=
fluid
.
data
(
name
=
"y"
,
shape
=
[
3
,
2
],
dtype
=
self
.
dtype
)
z
=
paddle
.
linalg
.
triangular_solve
(
x
,
y
)
x_np
=
np
.
random
.
random
([
3
,
3
]).
astype
(
self
.
dtype
)
y_np
=
np
.
random
.
random
([
3
,
2
]).
astype
(
self
.
dtype
)
z_np
=
np
.
linalg
.
solve
(
np
.
triu
(
x_np
),
y_np
)
exe
=
fluid
.
Executor
(
place
)
fetches
=
exe
.
run
(
fluid
.
default_main_program
(),
feed
=
{
"x"
:
x_np
,
"y"
:
y_np
},
fetch_list
=
[
z
])
self
.
assertTrue
(
np
.
allclose
(
fetches
[
0
],
z_np
))
def
test_static
(
self
):
for
place
in
self
.
place
:
self
.
check_static_result
(
place
=
place
)
def
test_dygraph
(
self
):
def
run
(
place
):
paddle
.
disable_static
(
place
)
x_np
=
np
.
random
.
random
([
3
,
3
]).
astype
(
self
.
dtype
)
y_np
=
np
.
random
.
random
([
3
,
2
]).
astype
(
self
.
dtype
)
z_np
=
np
.
linalg
.
solve
(
np
.
tril
(
x_np
),
y_np
)
x
=
paddle
.
to_tensor
(
x_np
)
y
=
paddle
.
to_tensor
(
y_np
)
z
=
paddle
.
linalg
.
triangular_solve
(
x
,
y
,
upper
=
False
)
self
.
assertTrue
(
np
.
allclose
(
z_np
,
z
.
numpy
()))
self
.
assertEqual
(
z_np
.
shape
,
z
.
numpy
().
shape
)
paddle
.
enable_static
()
for
place
in
self
.
place
:
run
(
place
)
class
TestTriangularSolveOpError
(
unittest
.
TestCase
):
def
test_errors
(
self
):
with
program_guard
(
Program
(),
Program
()):
# The input type of solve_op must be Variable.
x1
=
fluid
.
create_lod_tensor
(
np
.
array
([[
-
1
]]),
[[
1
]],
fluid
.
CPUPlace
())
y1
=
fluid
.
create_lod_tensor
(
np
.
array
([[
-
1
]]),
[[
1
]],
fluid
.
CPUPlace
())
self
.
assertRaises
(
TypeError
,
paddle
.
linalg
.
triangular_solve
,
x1
,
y1
)
# The data type of input must be float32 or float64.
x2
=
fluid
.
data
(
name
=
"x2"
,
shape
=
[
30
,
30
],
dtype
=
"bool"
)
y2
=
fluid
.
data
(
name
=
"y2"
,
shape
=
[
30
,
10
],
dtype
=
"bool"
)
self
.
assertRaises
(
TypeError
,
paddle
.
linalg
.
triangular_solve
,
x2
,
y2
)
x3
=
fluid
.
data
(
name
=
"x3"
,
shape
=
[
30
,
30
],
dtype
=
"int32"
)
y3
=
fluid
.
data
(
name
=
"y3"
,
shape
=
[
30
,
10
],
dtype
=
"int32"
)
self
.
assertRaises
(
TypeError
,
paddle
.
linalg
.
triangular_solve
,
x3
,
y3
)
x4
=
fluid
.
data
(
name
=
"x4"
,
shape
=
[
30
,
30
],
dtype
=
"float16"
)
y4
=
fluid
.
data
(
name
=
"y4"
,
shape
=
[
30
,
10
],
dtype
=
"float16"
)
self
.
assertRaises
(
TypeError
,
paddle
.
linalg
.
triangular_solve
,
x4
,
y4
)
# The number of dimensions of input'X must be >= 2.
x5
=
fluid
.
data
(
name
=
"x5"
,
shape
=
[
30
],
dtype
=
"float64"
)
y5
=
fluid
.
data
(
name
=
"y5"
,
shape
=
[
30
,
30
],
dtype
=
"float64"
)
self
.
assertRaises
(
ValueError
,
paddle
.
linalg
.
triangular_solve
,
x5
,
y5
)
# The number of dimensions of input'Y must be >= 2.
x6
=
fluid
.
data
(
name
=
"x6"
,
shape
=
[
30
,
30
],
dtype
=
"float64"
)
y6
=
fluid
.
data
(
name
=
"y6"
,
shape
=
[
30
],
dtype
=
"float64"
)
self
.
assertRaises
(
ValueError
,
paddle
.
linalg
.
triangular_solve
,
x6
,
y6
)
# The inner-most 2 dimensions of input'X should be equal to each other
x7
=
fluid
.
data
(
name
=
"x7"
,
shape
=
[
2
,
3
,
4
],
dtype
=
"float64"
)
y7
=
fluid
.
data
(
name
=
"y7"
,
shape
=
[
2
,
4
,
3
],
dtype
=
"float64"
)
self
.
assertRaises
(
ValueError
,
paddle
.
linalg
.
triangular_solve
,
x7
,
y7
)
if
__name__
==
"__main__"
:
unittest
.
main
()
python/paddle/linalg.py
浏览文件 @
3a81805b
...
...
@@ -29,6 +29,7 @@ from .tensor.linalg import eigvalsh
from
.tensor.linalg
import
det
from
.tensor.linalg
import
slogdet
from
.tensor.linalg
import
pinv
from
.tensor.linalg
import
triangular_solve
__all__
=
[
'cholesky'
,
#noqa
...
...
@@ -47,5 +48,6 @@ __all__ = [
'eigh'
,
'eigvalsh'
,
'pinv'
,
'solve'
'solve'
,
'triangular_solve'
,
]
python/paddle/tensor/__init__.py
浏览文件 @
3a81805b
...
...
@@ -397,6 +397,7 @@ tensor_method_func = [ #noqa
'uniform_'
,
'multi_dot'
,
'solve'
,
'triangular_solve'
]
#this list used in math_op_patch.py for magic_method bind
...
...
python/paddle/tensor/linalg.py
浏览文件 @
3a81805b
...
...
@@ -2315,6 +2315,79 @@ def solve(x, y, name=None):
return
out
def
triangular_solve
(
x
,
y
,
upper
=
True
,
transpose
=
False
,
unitriangular
=
False
,
name
=
None
):
r
"""
Computes the solution of a system of equations with a triangular coefficient matrix `x` and
multiple right-hand sides `y` .
Input `x` and `y` is 2D matrices or batches of 2D matrices. If the inputs are batches, the outputs
is also batches.
Args:
x (Tensor): The input triangular coefficient matrix. Its shape should be `[*, M, M]`, where `*` is zero or
more batch dimensions. Its data type should be float32 or float64.
y (Tensor): Multiple right-hand sides of system of equations. Its shape should be `[*, M, K]`, where `*` is
zero or more batch dimensions. Its data type should be float32 or float64.
upper (bool, optional): Whether to solve the upper-triangular system of equations (default) or the lower-triangular
system of equations. Default: True.
transpose (bool, optional): whether `x` should be transposed before calculation. Default: False.
unitriangular (bool, optional): whether `x` is unit triangular. If True, the diagonal elements of `x` are assumed
to be 1 and not referenced from `x` . Default: False.
name(str, optional): Name for the operation (optional, default is None).
For more information, please refer to :ref:`api_guide_Name`.
Returns:
Tensor: The solution of the system of equations. Its data type should be the same as that of `x`.
Examples:
.. code-block:: python
# a square system of linear equations:
# x1 + x2 + x3 = 0
# 2*x2 + x3 = -9
# -x3 = 5
import paddle
import numpy as np
x = paddle.to_tensor([[1, 1, 1],
[0, 2, 1],
[0, 0,-1]], dtype="float64")
y = paddle.to_tensor([[0], [-9], [5]], dtype="float64")
out = paddle.linalg.triangular_solve(x, y, upper=True)
print(out)
# [7, -2, -5]
"""
if
in_dygraph_mode
():
return
_C_ops
.
triangular_solve
(
x
,
y
,
'upper'
,
upper
,
'transpose'
,
transpose
,
'unitriangular'
,
unitriangular
)
inputs
=
{
"X"
:
[
x
],
"Y"
:
[
y
]}
helper
=
LayerHelper
(
"triangular_solve"
,
**
locals
())
check_variable_and_dtype
(
x
,
'x'
,
[
'float32'
,
'float64'
],
'triangular_solve'
)
check_variable_and_dtype
(
y
,
'y'
,
[
'float32'
,
'float64'
],
'triangular_solve'
)
out
=
helper
.
create_variable_for_type_inference
(
dtype
=
x
.
dtype
)
helper
.
append_op
(
type
=
'triangular_solve'
,
inputs
=
{
'X'
:
x
,
'Y'
:
y
},
outputs
=
{
'Out'
:
out
},
attrs
=
{
'upper'
:
upper
,
'transpose'
:
transpose
,
'unitriangular'
:
unitriangular
})
return
out
def
eigvalsh
(
x
,
UPLO
=
'L'
,
name
=
None
):
"""
Computes the eigenvalues of a
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录