Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
Paddle
提交
ae8ca764
P
Paddle
项目概览
PaddlePaddle
/
Paddle
大约 2 年 前同步成功
通知
2325
Star
20933
Fork
5424
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1423
列表
看板
标记
里程碑
合并请求
543
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1,423
Issue
1,423
列表
看板
标记
里程碑
合并请求
543
合并请求
543
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
ae8ca764
编写于
7月 13, 2022
作者:
W
Weilong Wu
提交者:
GitHub
7月 13, 2022
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
[Phi] Migrate matrix_solve to phi (#44298)
* [Phi] Migrate matrix_solve to phi * replace mutable_data with Alloc
上级
988abd6a
变更
9
显示空白变更内容
内联
并排
Showing
9 changed file
with
302 addition
and
282 deletion
+302
-282
paddle/fluid/operators/eig_op.h
paddle/fluid/operators/eig_op.h
+2
-2
paddle/fluid/operators/lstsq_op.h
paddle/fluid/operators/lstsq_op.h
+1
-1
paddle/fluid/operators/math/CMakeLists.txt
paddle/fluid/operators/math/CMakeLists.txt
+0
-1
paddle/fluid/operators/math/matrix_solve.cu.cc
paddle/fluid/operators/math/matrix_solve.cu.cc
+0
-189
paddle/fluid/operators/solve_op.h
paddle/fluid/operators/solve_op.h
+4
-65
paddle/phi/kernels/funcs/CMakeLists.txt
paddle/phi/kernels/funcs/CMakeLists.txt
+1
-0
paddle/phi/kernels/funcs/matrix_solve.cc
paddle/phi/kernels/funcs/matrix_solve.cc
+32
-0
paddle/phi/kernels/funcs/matrix_solve.cu
paddle/phi/kernels/funcs/matrix_solve.cu
+178
-0
paddle/phi/kernels/funcs/matrix_solve.h
paddle/phi/kernels/funcs/matrix_solve.h
+84
-24
未找到文件。
paddle/fluid/operators/eig_op.h
浏览文件 @
ae8ca764
...
@@ -19,7 +19,6 @@
...
@@ -19,7 +19,6 @@
#include <algorithm>
#include <algorithm>
#include <complex>
#include <complex>
#include "paddle/fluid/operators/math/matrix_solve.h"
#include "paddle/fluid/operators/transpose_op.h"
#include "paddle/fluid/operators/transpose_op.h"
#include "paddle/fluid/platform/for_range.h"
#include "paddle/fluid/platform/for_range.h"
#include "paddle/phi/kernels/complex_kernel.h"
#include "paddle/phi/kernels/complex_kernel.h"
...
@@ -30,6 +29,7 @@
...
@@ -30,6 +29,7 @@
#include "paddle/phi/kernels/funcs/diag_functor.h"
#include "paddle/phi/kernels/funcs/diag_functor.h"
#include "paddle/phi/kernels/funcs/lapack/lapack_function.h"
#include "paddle/phi/kernels/funcs/lapack/lapack_function.h"
#include "paddle/phi/kernels/funcs/math_function.h"
#include "paddle/phi/kernels/funcs/math_function.h"
#include "paddle/phi/kernels/funcs/matrix_solve.h"
#include "paddle/phi/kernels/funcs/slice.h"
#include "paddle/phi/kernels/funcs/slice.h"
#include "paddle/phi/kernels/funcs/unsqueeze.h"
#include "paddle/phi/kernels/funcs/unsqueeze.h"
#include "paddle/phi/kernels/matmul_kernel.h"
#include "paddle/phi/kernels/matmul_kernel.h"
...
@@ -366,7 +366,7 @@ void ComputeBackwardForComplexInput(
...
@@ -366,7 +366,7 @@ void ComputeBackwardForComplexInput(
int
k
=
rhs
.
dims
()[
rhs
.
dims
().
size
()
-
1
];
int
k
=
rhs
.
dims
()[
rhs
.
dims
().
size
()
-
1
];
auto
*
matrix_data
=
Vh
.
data
<
T
>
();
auto
*
matrix_data
=
Vh
.
data
<
T
>
();
auto
*
rhs_data
=
rhs
.
data
<
T
>
();
auto
*
rhs_data
=
rhs
.
data
<
T
>
();
math
::
SolveLinearSystem
<
T
>
(
phi
::
funcs
::
SolveLinearSystem
<
T
>
(
matrix_data
,
rhs_data
,
x_grad_data
,
m
,
k
,
batch_count
);
matrix_data
,
rhs_data
,
x_grad_data
,
m
,
k
,
batch_count
);
}
}
...
...
paddle/fluid/operators/lstsq_op.h
浏览文件 @
ae8ca764
...
@@ -21,13 +21,13 @@
...
@@ -21,13 +21,13 @@
#include "paddle/fluid/operators/eig_op.h"
#include "paddle/fluid/operators/eig_op.h"
#include "paddle/fluid/operators/math/eigen_values_vectors.h"
#include "paddle/fluid/operators/math/eigen_values_vectors.h"
#include "paddle/fluid/operators/math/matrix_solve.h"
#include "paddle/fluid/operators/svd_helper.h"
#include "paddle/fluid/operators/svd_helper.h"
#include "paddle/fluid/operators/transpose_op.h"
#include "paddle/fluid/operators/transpose_op.h"
#include "paddle/fluid/platform/for_range.h"
#include "paddle/fluid/platform/for_range.h"
#include "paddle/phi/kernels/funcs/complex_functors.h"
#include "paddle/phi/kernels/funcs/complex_functors.h"
#include "paddle/phi/kernels/funcs/lapack/lapack_function.h"
#include "paddle/phi/kernels/funcs/lapack/lapack_function.h"
#include "paddle/phi/kernels/funcs/math_function.h"
#include "paddle/phi/kernels/funcs/math_function.h"
#include "paddle/phi/kernels/funcs/matrix_solve.h"
#define EPSILON 1e-6
#define EPSILON 1e-6
...
...
paddle/fluid/operators/math/CMakeLists.txt
浏览文件 @
ae8ca764
...
@@ -54,7 +54,6 @@ math_library(vol2col)
...
@@ -54,7 +54,6 @@ math_library(vol2col)
math_library
(
prelu
)
math_library
(
prelu
)
math_library
(
bert_encoder_functor
)
math_library
(
bert_encoder_functor
)
math_library
(
tree2col DEPS math_function
)
math_library
(
tree2col DEPS math_function
)
math_library
(
matrix_solve
)
cc_test
(
cc_test
(
selected_rows_functor_test
selected_rows_functor_test
...
...
paddle/fluid/operators/math/matrix_solve.cu.cc
已删除
100644 → 0
浏览文件 @
988abd6a
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/math/matrix_solve.h"
#include "paddle/fluid/framework/tensor_util.h"
#include "paddle/fluid/operators/solve_op.h"
#include "paddle/fluid/platform/device_context.h"
#include "paddle/phi/kernels/funcs/blas/blas.h"
#include "paddle/phi/kernels/funcs/math_function.h"
namespace
paddle
{
namespace
platform
{
class
CUDADeviceContext
;
}
// namespace platform
}
// namespace paddle
namespace
paddle
{
namespace
operators
{
namespace
math
{
template
<
typename
DeviceContext
,
typename
T
>
class
MatrixSolveFunctor
;
template
<
typename
T
>
class
MatrixSolveFunctor
<
platform
::
CUDADeviceContext
,
T
>
{
public:
void
operator
()(
const
platform
::
CUDADeviceContext
&
context
,
const
framework
::
Tensor
&
a
,
const
framework
::
Tensor
&
b
,
framework
::
Tensor
*
out
)
{
#ifndef PADDLE_WITH_HIP
// solve the equation: Ax = B,
// use cuBlas cublas<S/D>getrfBatched funcion to performs the LU
// factorization of each matrix A,
// and then use cuBlas cublas<S/D>getriBatched function to solve the
// equation after LU factorization.
// ref:
// https://docs.nvidia.com/cuda/cublas/index.html#cublas-lt-t-gt-getrfbatched
const
auto
&
a_dims
=
a
.
dims
();
const
int
a_rank
=
a_dims
.
size
();
int
n
=
a_dims
[
a_rank
-
1
];
int
lda
=
n
;
int
batch_size
=
a_rank
>
2
?
a
.
numel
()
/
(
n
*
n
)
:
1
;
const
auto
&
b_dims
=
b
.
dims
();
const
int
b_rank
=
b_dims
.
size
();
int
nrhs
=
b_dims
[
b_rank
-
1
];
int
ldb
=
b_dims
[
b_rank
-
2
];
// make sure the out dims is right
out
->
Resize
(
b_dims
);
out
->
mutable_data
<
T
>
(
context
.
GetPlace
());
// copy input A to a temporary tensor tmp_a,
// LU factorization, written back to original matrix A, so in the beginning,
// it's necessary to create a temporary tensor tmp_a.
Tensor
tmp_a
(
a
.
dtype
());
tmp_a
.
Resize
(
a
.
dims
());
tmp_a
.
mutable_data
<
T
>
(
context
.
GetPlace
());
framework
::
TensorCopy
(
a
,
context
.
GetPlace
(),
&
tmp_a
);
// copy input B to a temporary tensor tmp_b, and transpose tmp_b,
// because cuBlas assumes column-major while Paddle uses row-majar.
Tensor
tmp_b
(
b
.
type
());
const
auto
&
new_dims_vec
=
getNewDimsVec
(
b_dims
);
tmp_b
.
Resize
(
phi
::
make_ddim
(
new_dims_vec
));
tmp_b
.
mutable_data
<
T
>
(
context
.
GetPlace
());
phi
::
funcs
::
TransposeNormal
<
platform
::
CUDADeviceContext
,
T
>
trans
;
std
::
vector
<
int
>
new_axis
=
getNewAxis
(
b_rank
);
trans
(
context
,
b
,
&
tmp_b
,
new_axis
);
const
T
*
a_data_in_gpu
=
tmp_a
.
data
<
T
>
();
const
T
*
b_data_in_gpu
=
tmp_b
.
data
<
T
>
();
std
::
vector
<
const
T
*>
cpu_ptrs
(
batch_size
*
2
);
for
(
int
i
=
0
;
i
<
batch_size
;
++
i
)
{
cpu_ptrs
[
i
]
=
a_data_in_gpu
+
i
*
n
*
n
;
cpu_ptrs
[
i
+
batch_size
]
=
b_data_in_gpu
+
i
*
n
*
nrhs
;
}
// Copy the addresses of A and tmp_b from host to device.
memory
::
allocation
::
AllocationPtr
tmp_gpu_ptrs_data
=
memory
::
Alloc
(
context
,
cpu_ptrs
.
size
()
*
sizeof
(
T
*
));
memory
::
Copy
(
context
.
GetPlace
(),
tmp_gpu_ptrs_data
->
ptr
(),
platform
::
CPUPlace
(),
static_cast
<
void
*>
(
cpu_ptrs
.
data
()),
cpu_ptrs
.
size
()
*
sizeof
(
T
*
),
context
.
stream
());
T
**
gpu_tmp_b_ptrs
=
reinterpret_cast
<
T
**>
(
tmp_gpu_ptrs_data
->
ptr
())
+
batch_size
;
// Allocate device memory for BatchedGETRF's info and pivots.
int
num_ints
=
n
<
32
?
batch_size
:
batch_size
*
(
n
+
1
);
memory
::
allocation
::
AllocationPtr
tmp_gpu_info_data
=
memory
::
Alloc
(
context
,
num_ints
*
sizeof
(
int
));
int
*
gpu_info_ptr
=
reinterpret_cast
<
int
*>
(
tmp_gpu_info_data
->
ptr
());
auto
blas
=
phi
::
funcs
::
GetBlas
<
platform
::
CUDADeviceContext
,
T
>
(
context
);
// only for singular checking
std
::
vector
<
int
>
info
;
info
.
resize
(
batch_size
);
int
*
gpu_pivot_ptr
=
reinterpret_cast
<
int
*>
(
tmp_gpu_info_data
->
ptr
())
+
batch_size
;
// This function performs the LU factorization of each matrix A by the
// equation A = L * U. L and U are written back to original matrix A,
// and diagonal elements of L are discarded.
blas
.
BatchedGETRF
(
n
,
reinterpret_cast
<
T
**>
(
tmp_gpu_ptrs_data
->
ptr
()),
gpu_pivot_ptr
,
gpu_info_ptr
,
batch_size
);
// check whether BatchedGETRF is executed successfully or not
memory
::
Copy
(
platform
::
CPUPlace
(),
info
.
data
(),
context
.
GetPlace
(),
gpu_info_ptr
,
sizeof
(
int
)
*
batch_size
,
context
.
stream
());
for
(
int
i
=
0
;
i
<
batch_size
;
++
i
)
{
PADDLE_ENFORCE_EQ
(
info
[
i
],
0
,
platform
::
errors
::
PreconditionNotMet
(
"For batch [%d]: U(%d, %d) is zero, singular U. "
"Please check the matrix value and change it to a "
"non-singular matrix"
,
i
,
info
[
i
],
info
[
i
]));
}
// hold the result code from BatchedGETRS
int
host_info
=
0
;
// to solve the equation after LU factorization
CBLAS_TRANSPOSE
transA
=
CblasTrans
;
blas
.
BatchedGETRS
(
transA
,
n
,
nrhs
,
reinterpret_cast
<
const
T
**>
(
tmp_gpu_ptrs_data
->
ptr
()),
lda
,
gpu_pivot_ptr
,
gpu_tmp_b_ptrs
,
ldb
,
&
host_info
,
batch_size
);
// check whether BatchedGETRS is executed successfully or not
PADDLE_ENFORCE_EQ
(
host_info
,
0
,
platform
::
errors
::
InvalidArgument
(
"The [%d]'th argument to cublas*getrsBatched had "
"an illegal value."
,
-
host_info
));
// transpose tmp_b to get the final result in row-major form.
phi
::
funcs
::
TransposeNormal
<
platform
::
CUDADeviceContext
,
T
>
trans2
;
trans2
(
context
,
tmp_b
,
out
,
new_axis
);
#else
compute_solve_eigen
<
platform
::
CUDADeviceContext
,
T
>
(
context
,
a
,
b
,
out
);
#endif
}
};
template
class
MatrixSolveFunctor
<
platform
::
CUDADeviceContext
,
float
>;
template
class
MatrixSolveFunctor
<
platform
::
CUDADeviceContext
,
double
>;
}
// namespace math
}
// namespace operators
}
// namespace paddle
paddle/fluid/operators/solve_op.h
浏览文件 @
ae8ca764
...
@@ -20,11 +20,11 @@ limitations under the License. */
...
@@ -20,11 +20,11 @@ limitations under the License. */
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/framework/tensor_util.h"
#include "paddle/fluid/framework/tensor_util.h"
#include "paddle/fluid/operators/eigen/eigen_function.h"
#include "paddle/fluid/operators/eigen/eigen_function.h"
#include "paddle/fluid/operators/math/matrix_solve.h"
#include "paddle/fluid/operators/reduce_ops/reduce_sum_op.h"
#include "paddle/fluid/operators/reduce_ops/reduce_sum_op.h"
#include "paddle/fluid/operators/squeeze_op.h"
#include "paddle/fluid/operators/squeeze_op.h"
#include "paddle/phi/kernels/funcs/blas/blas.h"
#include "paddle/phi/kernels/funcs/blas/blas.h"
#include "paddle/phi/kernels/funcs/math_function.h"
#include "paddle/phi/kernels/funcs/math_function.h"
#include "paddle/phi/kernels/funcs/matrix_solve.h"
#if defined(__NVCC__) || defined(__HIPCC__)
#if defined(__NVCC__) || defined(__HIPCC__)
#include "paddle/fluid/operators/reduce_ops/reduce_op.cu.h"
#include "paddle/fluid/operators/reduce_ops/reduce_op.cu.h"
#endif
#endif
...
@@ -351,7 +351,7 @@ static void linalg_solve(const framework::ExecutionContext& context,
...
@@ -351,7 +351,7 @@ static void linalg_solve(const framework::ExecutionContext& context,
out
->
mutable_data
<
T
>
(
context
.
GetPlace
());
out
->
mutable_data
<
T
>
(
context
.
GetPlace
());
auto
&
dev_ctx
=
context
.
template
device_context
<
DeviceContext
>();
auto
&
dev_ctx
=
context
.
template
device_context
<
DeviceContext
>();
math
::
MatrixSolveFunctor
<
DeviceContext
,
T
>
mat_solve
;
phi
::
funcs
::
MatrixSolveFunctor
<
DeviceContext
,
T
>
mat_solve
;
// input y can be vector or matrix
// input y can be vector or matrix
// but need to be unsqueezed if y is a vector
// but need to be unsqueezed if y is a vector
...
@@ -425,67 +425,6 @@ static void linalg_solve(const framework::ExecutionContext& context,
...
@@ -425,67 +425,6 @@ static void linalg_solve(const framework::ExecutionContext& context,
}
}
}
}
// for TransposeNormal
static
std
::
vector
<
int
>
getNewAxis
(
const
int
b_rank
)
{
std
::
vector
<
int
>
axis_1
=
{
0
};
std
::
vector
<
int
>
axis_2
=
{
1
,
0
};
std
::
vector
<
int
>
axis_3
=
{
0
,
2
,
1
};
std
::
vector
<
int
>
axis_4
=
{
0
,
1
,
3
,
2
};
std
::
vector
<
int
>
axis_5
=
{
0
,
1
,
2
,
4
,
3
};
std
::
vector
<
int
>
axis_6
=
{
0
,
1
,
2
,
3
,
5
,
4
};
std
::
vector
<
int
>
axis_7
=
{
0
,
1
,
2
,
3
,
4
,
6
,
5
};
std
::
vector
<
int
>
axis_8
=
{
0
,
1
,
2
,
3
,
4
,
5
,
7
,
6
};
std
::
vector
<
int
>
axis_9
=
{
0
,
1
,
2
,
3
,
4
,
5
,
6
,
8
,
7
};
switch
(
b_rank
)
{
case
1
:
return
axis_1
;
break
;
case
2
:
return
axis_2
;
break
;
case
3
:
return
axis_3
;
break
;
case
4
:
return
axis_4
;
break
;
case
5
:
return
axis_5
;
break
;
case
6
:
return
axis_6
;
break
;
case
7
:
return
axis_7
;
break
;
case
8
:
return
axis_8
;
break
;
default:
return
axis_9
;
}
}
// for Resize
static
std
::
vector
<
int64_t
>
getNewDimsVec
(
const
DDim
&
b_dims
)
{
std
::
vector
<
int64_t
>
b_dims_vec
=
phi
::
vectorize
(
b_dims
);
int
size
=
b_dims_vec
.
size
();
if
(
size
>=
2
)
{
// swap the last 2 elements in b_dims_vec
int64_t
temp
=
b_dims_vec
[
size
-
1
];
b_dims_vec
[
size
-
1
]
=
b_dims_vec
[
size
-
2
];
b_dims_vec
[
size
-
2
]
=
temp
;
return
b_dims_vec
;
}
PADDLE_ENFORCE_NE
(
b_dims_vec
.
empty
(),
true
,
platform
::
errors
::
PreconditionNotMet
(
"The size of tensor b must not be %d after getting new dims"
,
0
));
// if b_dims_vec.size() == 1, just retun original vec
return
b_dims_vec
;
}
template
<
typename
DeviceContext
,
typename
T
>
template
<
typename
DeviceContext
,
typename
T
>
class
SolveKernel
:
public
framework
::
OpKernel
<
T
>
{
class
SolveKernel
:
public
framework
::
OpKernel
<
T
>
{
public:
public:
...
@@ -553,11 +492,11 @@ class SolveGradKernel : public framework::OpKernel<T> {
...
@@ -553,11 +492,11 @@ class SolveGradKernel : public framework::OpKernel<T> {
tmp_dy
.
mutable_data
<
T
>
(
ctx
.
GetPlace
());
tmp_dy
.
mutable_data
<
T
>
(
ctx
.
GetPlace
());
Tensor
tmp_input
(
input
->
dtype
());
Tensor
tmp_input
(
input
->
dtype
());
const
auto
&
new_dims_vec
=
getNewDimsVec
(
input
->
dims
());
const
auto
&
new_dims_vec
=
phi
::
funcs
::
getNewDimsVec
(
input
->
dims
());
tmp_input
.
Resize
(
phi
::
make_ddim
(
new_dims_vec
));
tmp_input
.
Resize
(
phi
::
make_ddim
(
new_dims_vec
));
tmp_input
.
mutable_data
<
T
>
(
ctx
.
GetPlace
());
tmp_input
.
mutable_data
<
T
>
(
ctx
.
GetPlace
());
phi
::
funcs
::
TransposeNormal
<
DeviceContext
,
T
>
trans
;
phi
::
funcs
::
TransposeNormal
<
DeviceContext
,
T
>
trans
;
std
::
vector
<
int
>
new_axis
=
getNewAxis
(
input
->
dims
().
size
());
std
::
vector
<
int
>
new_axis
=
phi
::
funcs
::
getNewAxis
(
input
->
dims
().
size
());
auto
&
dev_ctx
=
ctx
.
template
device_context
<
DeviceContext
>();
auto
&
dev_ctx
=
ctx
.
template
device_context
<
DeviceContext
>();
trans
(
dev_ctx
,
*
input
,
&
tmp_input
,
new_axis
);
trans
(
dev_ctx
,
*
input
,
&
tmp_input
,
new_axis
);
...
...
paddle/phi/kernels/funcs/CMakeLists.txt
浏览文件 @
ae8ca764
...
@@ -14,3 +14,4 @@ math_library(matrix_inverse DEPS dense_tensor eigen3 blas)
...
@@ -14,3 +14,4 @@ math_library(matrix_inverse DEPS dense_tensor eigen3 blas)
math_library
(
pooling DEPS dense_tensor
)
math_library
(
pooling DEPS dense_tensor
)
math_library
(
segment_pooling
)
math_library
(
segment_pooling
)
math_library
(
sequence2batch
)
math_library
(
sequence2batch
)
math_library
(
matrix_solve DEPS dense_tensor eigen3 blas math_function
)
paddle/
fluid/operators/math
/matrix_solve.cc
→
paddle/
phi/kernels/funcs
/matrix_solve.cc
浏览文件 @
ae8ca764
/* Copyright (c) 202
1
PaddlePaddle Authors. All Rights Reserved.
/* Copyright (c) 202
2
PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
you may not use this file except in compliance with the License.
...
@@ -12,30 +12,21 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
...
@@ -12,30 +12,21 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
See the License for the specific language governing permissions and
limitations under the License. */
limitations under the License. */
#include "paddle/fluid/operators/math/matrix_solve.h"
#include "paddle/phi/kernels/funcs/matrix_solve.h"
#include "Eigen/Core"
namespace
phi
{
#include "Eigen/LU"
namespace
funcs
{
#include "paddle/phi/kernels/funcs/blas/blas.h"
template
<
typename
Context
,
typename
T
>
namespace
paddle
{
void
MatrixSolveFunctor
<
Context
,
T
>::
operator
()(
const
Context
&
dev_ctx
,
namespace
operators
{
const
DenseTensor
&
a
,
namespace
math
{
const
DenseTensor
&
b
,
DenseTensor
*
out
)
{
template
<
typename
T
>
compute_solve_eigen
<
Context
,
T
>
(
dev_ctx
,
a
,
b
,
out
);
class
MatrixSolveFunctor
<
phi
::
CPUContext
,
T
>
{
}
public:
void
operator
()(
const
phi
::
CPUContext
&
dev_ctx
,
template
class
MatrixSolveFunctor
<
CPUContext
,
float
>;
const
framework
::
Tensor
&
a
,
template
class
MatrixSolveFunctor
<
CPUContext
,
double
>;
const
framework
::
Tensor
&
b
,
framework
::
Tensor
*
out
)
{
}
// namespace funcs
compute_solve_eigen
<
phi
::
CPUContext
,
T
>
(
dev_ctx
,
a
,
b
,
out
);
}
// namespace phi
}
};
template
class
MatrixSolveFunctor
<
phi
::
CPUContext
,
float
>;
template
class
MatrixSolveFunctor
<
phi
::
CPUContext
,
double
>;
}
// namespace math
}
// namespace operators
}
// namespace paddle
paddle/phi/kernels/funcs/matrix_solve.cu
0 → 100644
浏览文件 @
ae8ca764
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/phi/kernels/funcs/matrix_solve.h"
#include "paddle/phi/core/tensor_utils.h"
#include "paddle/phi/kernels/funcs/blas/blas.h"
#include "paddle/phi/kernels/funcs/math_function.h"
namespace
phi
{
namespace
funcs
{
template
<
typename
Context
,
typename
T
>
void
MatrixSolveFunctor
<
Context
,
T
>::
operator
()(
const
Context
&
context
,
const
DenseTensor
&
a
,
const
DenseTensor
&
b
,
DenseTensor
*
out
)
{
#ifndef PADDLE_WITH_HIP
// solve the equation: Ax = B,
// use cuBlas cublas<S/D>getrfBatched funcion to performs the LU
// factorization of each matrix A,
// and then use cuBlas cublas<S/D>getriBatched function to solve the
// equation after LU factorization.
// ref:
// https://docs.nvidia.com/cuda/cublas/index.html#cublas-lt-t-gt-getrfbatched
const
auto
&
a_dims
=
a
.
dims
();
const
int
a_rank
=
a_dims
.
size
();
int
n
=
a_dims
[
a_rank
-
1
];
int
lda
=
n
;
int
batch_size
=
a_rank
>
2
?
a
.
numel
()
/
(
n
*
n
)
:
1
;
const
auto
&
b_dims
=
b
.
dims
();
const
int
b_rank
=
b_dims
.
size
();
int
nrhs
=
b_dims
[
b_rank
-
1
];
int
ldb
=
b_dims
[
b_rank
-
2
];
// make sure the out dims is right
out
->
Resize
(
b_dims
);
context
.
template
Alloc
<
T
>(
out
);
// copy input A to a temporary tensor tmp_a,
// LU factorization, written back to original matrix A, so in the beginning,
// it's necessary to create a temporary tensor tmp_a.
DenseTensor
tmp_a
(
a
.
dtype
());
tmp_a
.
Resize
(
a
.
dims
());
context
.
template
Alloc
<
T
>(
&
tmp_a
);
paddle
::
framework
::
TensorCopy
(
a
,
context
.
GetPlace
(),
&
tmp_a
);
// copy input B to a temporary tensor tmp_b, and transpose tmp_b,
// because cuBlas assumes column-major while Paddle uses row-majar.
DenseTensor
tmp_b
(
b
.
type
());
const
auto
&
new_dims_vec
=
getNewDimsVec
(
b_dims
);
tmp_b
.
Resize
(
phi
::
make_ddim
(
new_dims_vec
));
context
.
template
Alloc
<
T
>(
&
tmp_b
);
phi
::
funcs
::
TransposeNormal
<
Context
,
T
>
trans
;
std
::
vector
<
int
>
new_axis
=
getNewAxis
(
b_rank
);
trans
(
context
,
b
,
&
tmp_b
,
new_axis
);
const
T
*
a_data_in_gpu
=
tmp_a
.
data
<
T
>
();
const
T
*
b_data_in_gpu
=
tmp_b
.
data
<
T
>
();
std
::
vector
<
const
T
*>
cpu_ptrs
(
batch_size
*
2
);
for
(
int
i
=
0
;
i
<
batch_size
;
++
i
)
{
cpu_ptrs
[
i
]
=
a_data_in_gpu
+
i
*
n
*
n
;
cpu_ptrs
[
i
+
batch_size
]
=
b_data_in_gpu
+
i
*
n
*
nrhs
;
}
// Copy the addresses of A and tmp_b from host to device.
paddle
::
memory
::
allocation
::
AllocationPtr
tmp_gpu_ptrs_data
=
paddle
::
memory
::
Alloc
(
context
,
cpu_ptrs
.
size
()
*
sizeof
(
T
*
));
paddle
::
memory
::
Copy
(
context
.
GetPlace
(),
tmp_gpu_ptrs_data
->
ptr
(),
phi
::
CPUPlace
(),
static_cast
<
void
*>
(
cpu_ptrs
.
data
()),
cpu_ptrs
.
size
()
*
sizeof
(
T
*
),
context
.
stream
());
T
**
gpu_tmp_b_ptrs
=
reinterpret_cast
<
T
**>
(
tmp_gpu_ptrs_data
->
ptr
())
+
batch_size
;
// Allocate device memory for BatchedGETRF's info and pivots.
int
num_ints
=
n
<
32
?
batch_size
:
batch_size
*
(
n
+
1
);
paddle
::
memory
::
allocation
::
AllocationPtr
tmp_gpu_info_data
=
paddle
::
memory
::
Alloc
(
context
,
num_ints
*
sizeof
(
int
));
int
*
gpu_info_ptr
=
reinterpret_cast
<
int
*>
(
tmp_gpu_info_data
->
ptr
());
auto
blas
=
phi
::
funcs
::
GetBlas
<
Context
,
T
>
(
context
);
// only for singular checking
std
::
vector
<
int
>
info
;
info
.
resize
(
batch_size
);
int
*
gpu_pivot_ptr
=
reinterpret_cast
<
int
*>
(
tmp_gpu_info_data
->
ptr
())
+
batch_size
;
// This function performs the LU factorization of each matrix A by the
// equation A = L * U. L and U are written back to original matrix A,
// and diagonal elements of L are discarded.
blas
.
BatchedGETRF
(
n
,
reinterpret_cast
<
T
**>
(
tmp_gpu_ptrs_data
->
ptr
()),
gpu_pivot_ptr
,
gpu_info_ptr
,
batch_size
);
// check whether BatchedGETRF is executed successfully or not
paddle
::
memory
::
Copy
(
phi
::
CPUPlace
(),
info
.
data
(),
context
.
GetPlace
(),
gpu_info_ptr
,
sizeof
(
int
)
*
batch_size
,
context
.
stream
());
for
(
int
i
=
0
;
i
<
batch_size
;
++
i
)
{
PADDLE_ENFORCE_EQ
(
info
[
i
],
0
,
phi
::
errors
::
PreconditionNotMet
(
"For batch [%d]: U(%d, %d) is zero, singular U. "
"Please check the matrix value and change it to a "
"non-singular matrix"
,
i
,
info
[
i
],
info
[
i
]));
}
// hold the result code from BatchedGETRS
int
host_info
=
0
;
// to solve the equation after LU factorization
CBLAS_TRANSPOSE
transA
=
CblasTrans
;
blas
.
BatchedGETRS
(
transA
,
n
,
nrhs
,
reinterpret_cast
<
const
T
**>
(
tmp_gpu_ptrs_data
->
ptr
()),
lda
,
gpu_pivot_ptr
,
gpu_tmp_b_ptrs
,
ldb
,
&
host_info
,
batch_size
);
// check whether BatchedGETRS is executed successfully or not
PADDLE_ENFORCE_EQ
(
host_info
,
0
,
phi
::
errors
::
InvalidArgument
(
"The [%d]'th argument to cublas*getrsBatched had "
"an illegal value."
,
-
host_info
));
// transpose tmp_b to get the final result in row-major form.
phi
::
funcs
::
TransposeNormal
<
Context
,
T
>
trans2
;
trans2
(
context
,
tmp_b
,
out
,
new_axis
);
#else
compute_solve_eigen
<
Context
,
T
>
(
context
,
a
,
b
,
out
);
#endif
}
template
class
MatrixSolveFunctor
<
GPUContext
,
float
>;
template
class
MatrixSolveFunctor
<
GPUContext
,
double
>;
// TODO(wuweilong): remove these instantiations later
template
class
MatrixSolveFunctor
<
paddle
::
platform
::
CUDADeviceContext
,
float
>;
template
class
MatrixSolveFunctor
<
paddle
::
platform
::
CUDADeviceContext
,
double
>;
}
// namespace funcs
}
// namespace phi
paddle/
fluid/operators/math
/matrix_solve.h
→
paddle/
phi/kernels/funcs
/matrix_solve.h
浏览文件 @
ae8ca764
...
@@ -18,18 +18,79 @@ limitations under the License. */
...
@@ -18,18 +18,79 @@ limitations under the License. */
#include "Eigen/Core"
#include "Eigen/Core"
#include "Eigen/LU"
#include "Eigen/LU"
#include "paddle/fluid/framework/tensor.h"
#include "paddle/phi/backends/all_context.h"
#include "paddle/fluid/platform/device_context.h"
#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/core/enforce.h"
namespace
paddle
{
namespace
operators
{
namespace
phi
{
namespace
math
{
namespace
funcs
{
template
<
typename
DeviceContext
,
typename
T
>
// for TransposeNormal
void
compute_solve_eigen
(
const
DeviceContext
&
context
,
static
std
::
vector
<
int
>
getNewAxis
(
const
int
b_rank
)
{
const
framework
::
Tensor
&
a
,
std
::
vector
<
int
>
axis_1
=
{
0
};
const
framework
::
Tensor
&
b
,
std
::
vector
<
int
>
axis_2
=
{
1
,
0
};
framework
::
Tensor
*
out
)
{
std
::
vector
<
int
>
axis_3
=
{
0
,
2
,
1
};
std
::
vector
<
int
>
axis_4
=
{
0
,
1
,
3
,
2
};
std
::
vector
<
int
>
axis_5
=
{
0
,
1
,
2
,
4
,
3
};
std
::
vector
<
int
>
axis_6
=
{
0
,
1
,
2
,
3
,
5
,
4
};
std
::
vector
<
int
>
axis_7
=
{
0
,
1
,
2
,
3
,
4
,
6
,
5
};
std
::
vector
<
int
>
axis_8
=
{
0
,
1
,
2
,
3
,
4
,
5
,
7
,
6
};
std
::
vector
<
int
>
axis_9
=
{
0
,
1
,
2
,
3
,
4
,
5
,
6
,
8
,
7
};
switch
(
b_rank
)
{
case
1
:
return
axis_1
;
break
;
case
2
:
return
axis_2
;
break
;
case
3
:
return
axis_3
;
break
;
case
4
:
return
axis_4
;
break
;
case
5
:
return
axis_5
;
break
;
case
6
:
return
axis_6
;
break
;
case
7
:
return
axis_7
;
break
;
case
8
:
return
axis_8
;
break
;
default:
return
axis_9
;
}
}
// for Resize
static
std
::
vector
<
int64_t
>
getNewDimsVec
(
const
DDim
&
b_dims
)
{
std
::
vector
<
int64_t
>
b_dims_vec
=
phi
::
vectorize
(
b_dims
);
int
size
=
b_dims_vec
.
size
();
if
(
size
>=
2
)
{
// swap the last 2 elements in b_dims_vec
int64_t
temp
=
b_dims_vec
[
size
-
1
];
b_dims_vec
[
size
-
1
]
=
b_dims_vec
[
size
-
2
];
b_dims_vec
[
size
-
2
]
=
temp
;
return
b_dims_vec
;
}
PADDLE_ENFORCE_NE
(
b_dims_vec
.
empty
(),
true
,
phi
::
errors
::
PreconditionNotMet
(
"The size of tensor b must not be %d after getting new dims"
,
0
));
// if b_dims_vec.size() == 1, just retun original vec
return
b_dims_vec
;
}
template
<
typename
Context
,
typename
T
>
void
compute_solve_eigen
(
const
Context
&
context
,
const
DenseTensor
&
a
,
const
DenseTensor
&
b
,
DenseTensor
*
out
)
{
using
Matrix
=
using
Matrix
=
Eigen
::
Matrix
<
T
,
Eigen
::
Dynamic
,
Eigen
::
Dynamic
,
Eigen
::
RowMajor
>
;
Eigen
::
Matrix
<
T
,
Eigen
::
Dynamic
,
Eigen
::
Dynamic
,
Eigen
::
RowMajor
>
;
using
EigenMatrixMap
=
Eigen
::
Map
<
Matrix
>
;
using
EigenMatrixMap
=
Eigen
::
Map
<
Matrix
>
;
...
@@ -51,7 +112,7 @@ void compute_solve_eigen(const DeviceContext& context,
...
@@ -51,7 +112,7 @@ void compute_solve_eigen(const DeviceContext& context,
const
T
*
b_ptr
=
b
.
data
<
T
>
();
const
T
*
b_ptr
=
b
.
data
<
T
>
();
out
->
Resize
(
b_mat_dims
);
// make sure the out dims is right
out
->
Resize
(
b_mat_dims
);
// make sure the out dims is right
T
*
out_ptr
=
out
->
mutable_data
<
T
>
(
context
.
GetPlace
()
);
T
*
out_ptr
=
context
.
template
Alloc
<
T
>(
out
);
if
(
a_batch_size
==
b_batch_size
)
{
if
(
a_batch_size
==
b_batch_size
)
{
for
(
int
i
=
0
;
i
<
a_batch_size
;
++
i
)
{
for
(
int
i
=
0
;
i
<
a_batch_size
;
++
i
)
{
ConstEigenMatrixMap
a_mat
(
a_ptr
+
i
*
n
*
n
,
n
,
n
);
ConstEigenMatrixMap
a_mat
(
a_ptr
+
i
*
n
*
n
,
n
,
n
);
...
@@ -63,13 +124,13 @@ void compute_solve_eigen(const DeviceContext& context,
...
@@ -63,13 +124,13 @@ void compute_solve_eigen(const DeviceContext& context,
PADDLE_ENFORCE_GT
(
PADDLE_ENFORCE_GT
(
min_abs_pivot
,
min_abs_pivot
,
static_cast
<
T
>
(
0
),
static_cast
<
T
>
(
0
),
p
latform
::
errors
::
InvalidArgument
(
"Input is not invertible."
));
p
hi
::
errors
::
InvalidArgument
(
"Input is not invertible."
));
out_mat
.
noalias
()
=
lu
.
solve
(
b_mat
);
out_mat
.
noalias
()
=
lu
.
solve
(
b_mat
);
}
}
}
else
{
}
else
{
PADDLE_ENFORCE_EQ
(
a_batch_size
,
PADDLE_ENFORCE_EQ
(
a_batch_size
,
b_batch_size
,
b_batch_size
,
p
latform
::
errors
::
InvalidArgument
(
p
hi
::
errors
::
InvalidArgument
(
"All input tensors must have the same rank."
));
"All input tensors must have the same rank."
));
}
}
}
}
...
@@ -114,22 +175,21 @@ void SolveLinearSystem(T* matrix_data,
...
@@ -114,22 +175,21 @@ void SolveLinearSystem(T* matrix_data,
lu_decomposition
.
matrixLU
().
diagonal
().
cwiseAbs
().
minCoeff
();
lu_decomposition
.
matrixLU
().
diagonal
().
cwiseAbs
().
minCoeff
();
PADDLE_ENFORCE_GT
(
min_abs_piv
,
PADDLE_ENFORCE_GT
(
min_abs_piv
,
Treal
(
0
),
Treal
(
0
),
p
latform
::
errors
::
InvalidArgument
(
p
hi
::
errors
::
InvalidArgument
(
"Something's wrong with SolveLinearSystem. "
));
"Something's wrong with SolveLinearSystem. "
));
output
=
lu_decomposition
.
solve
(
input_rhs
);
output
=
lu_decomposition
.
solve
(
input_rhs
);
}
}
}
}
template
<
typename
Device
Context
,
typename
T
>
template
<
typename
Context
,
typename
T
>
class
MatrixSolveFunctor
{
class
MatrixSolveFunctor
{
public:
public:
void
operator
()(
const
Device
Context
&
context
,
void
operator
()(
const
Context
&
context
,
const
framework
::
Tensor
&
a
,
const
Dense
Tensor
&
a
,
const
framework
::
Tensor
&
b
,
const
Dense
Tensor
&
b
,
framework
::
Tensor
*
out
);
Dense
Tensor
*
out
);
};
};
}
// namespace math
}
// namespace funcs
}
// namespace operators
}
// namespace phi
}
// namespace paddle
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录