Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
机器未来
Paddle
提交
e24ca55e
P
Paddle
项目概览
机器未来
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1
Issue
1
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
e24ca55e
编写于
3月 11, 2022
作者:
zhouweiwei2014
提交者:
GitHub
3月 11, 2022
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
[Phi]migrate cholesky_solve op to phi (#40387)
上级
dc773828
变更
21
显示空白变更内容
内联
并排
Showing
21 changed file
with
775 addition
and
733 deletion
+775
-733
paddle/fluid/operators/cholesky_solve_op.cc
paddle/fluid/operators/cholesky_solve_op.cc
+9
-59
paddle/fluid/operators/cholesky_solve_op.cu
paddle/fluid/operators/cholesky_solve_op.cu
+0
-136
paddle/fluid/operators/cholesky_solve_op.h
paddle/fluid/operators/cholesky_solve_op.h
+0
-252
paddle/fluid/operators/triangular_solve_op.h
paddle/fluid/operators/triangular_solve_op.h
+0
-40
paddle/fluid/platform/dynload/CMakeLists.txt
paddle/fluid/platform/dynload/CMakeLists.txt
+0
-2
paddle/fluid/platform/dynload/lapack.cc
paddle/fluid/platform/dynload/lapack.cc
+0
-27
paddle/fluid/platform/dynload/lapack.h
paddle/fluid/platform/dynload/lapack.h
+0
-68
paddle/phi/backends/dynload/lapack.h
paddle/phi/backends/dynload/lapack.h
+2
-2
paddle/phi/infermeta/binary.cc
paddle/phi/infermeta/binary.cc
+54
-0
paddle/phi/infermeta/binary.h
paddle/phi/infermeta/binary.h
+5
-0
paddle/phi/kernels/cholesky_solve_grad_kernel.h
paddle/phi/kernels/cholesky_solve_grad_kernel.h
+31
-0
paddle/phi/kernels/cholesky_solve_kernel.h
paddle/phi/kernels/cholesky_solve_kernel.h
+28
-0
paddle/phi/kernels/cpu/cholesky_solve_grad_kernel.cc
paddle/phi/kernels/cpu/cholesky_solve_grad_kernel.cc
+25
-0
paddle/phi/kernels/cpu/cholesky_solve_kernel.cc
paddle/phi/kernels/cpu/cholesky_solve_kernel.cc
+42
-0
paddle/phi/kernels/funcs/lapack/CMakeLists.txt
paddle/phi/kernels/funcs/lapack/CMakeLists.txt
+1
-1
paddle/phi/kernels/funcs/lapack/lapack_function.cc
paddle/phi/kernels/funcs/lapack/lapack_function.cc
+139
-146
paddle/phi/kernels/gpu/cholesky_solve_grad_kernel.cu
paddle/phi/kernels/gpu/cholesky_solve_grad_kernel.cu
+30
-0
paddle/phi/kernels/gpu/cholesky_solve_kernel.cu
paddle/phi/kernels/gpu/cholesky_solve_kernel.cu
+141
-0
paddle/phi/kernels/impl/cholesky_solve_grad_kernel_impl.h
paddle/phi/kernels/impl/cholesky_solve_grad_kernel_impl.h
+134
-0
paddle/phi/kernels/impl/cholesky_solve_kernel_impl.h
paddle/phi/kernels/impl/cholesky_solve_kernel_impl.h
+104
-0
paddle/phi/ops/compat/cholesky_solve_sig.cc
paddle/phi/ops/compat/cholesky_solve_sig.cc
+30
-0
未找到文件。
paddle/fluid/operators/cholesky_solve_op.cc
浏览文件 @
e24ca55e
...
...
@@ -12,8 +12,9 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/cholesky_solve_op.h"
#include "paddle/fluid/operators/solve_op.h"
#include "paddle/fluid/framework/infershape_utils.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/phi/infermeta/binary.h"
namespace
paddle
{
namespace
operators
{
...
...
@@ -39,50 +40,6 @@ class CholeskySolveOpMaker : public framework::OpProtoAndCheckerMaker {
class
CholeskySolveOp
:
public
framework
::
OperatorWithKernel
{
public:
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
// Checks that inputs X, Y and output Out exist, validates their trailing
// matrix dimensions, and sets the shape of Out to the broadcast batch
// dimensions of X and Y followed by the last two dimensions of X.
void InferShape(framework::InferShapeContext *context) const override {
  OP_INOUT_CHECK(context->HasInput("X"), "Input", "X", "CholeskySolve");
  OP_INOUT_CHECK(context->HasInput("Y"), "Input", "Y", "CholeskySolve");
  OP_INOUT_CHECK(context->HasOutput("Out"), "Output", "Out", "CholeskySolve");

  auto y_dims = context->GetInputDim("Y");
  auto x_dims = context->GetInputDim("X");
  int y_rank = y_dims.size();
  int x_rank = x_dims.size();

  // Both operands must be at least 2-D (batched matrices).
  PADDLE_ENFORCE_GE(y_rank, 2,
                    platform::errors::InvalidArgument(
                        "the rank of input Y must greater or equal to 2"));
  PADDLE_ENFORCE_GE(x_rank, 2,
                    platform::errors::InvalidArgument(
                        "the rank of input X must greater or equal to 2"));

  // The last two dimensions of Y must be equal.
  PADDLE_ENFORCE_EQ(y_dims[y_rank - 1], y_dims[y_rank - 2],
                    platform::errors::InvalidArgument(
                        "input Matrix Y should be square matrix,"
                        "But Got last shape of %ld x %ld",
                        y_dims[y_rank - 1], y_dims[y_rank - 2]));
  // The row count of X must match that of Y.
  PADDLE_ENFORCE_EQ(
      x_dims[x_rank - 2], y_dims[y_rank - 2],
      platform::errors::InvalidArgument(
          "the first dim of input X must equal to the dim of input Y,"
          "But Got %ld and %ld",
          x_dims[x_rank - 2], y_dims[y_rank - 2]));

  // Strip the trailing matrix dimensions to obtain the batch portions.
  std::vector<int64_t> y_dims_vec = phi::vectorize(y_dims);
  std::vector<int64_t> x_dims_vec = phi::vectorize(x_dims);
  std::vector<int64_t> y_batch(y_dims_vec.begin(), y_dims_vec.end() - 2);
  std::vector<int64_t> x_batch(x_dims_vec.begin(), x_dims_vec.end() - 2);

  // Out = broadcast batch dims + the matrix dims of X.
  std::vector<int64_t> out_dims = get_broadcast_batch_portion(y_batch, x_batch);
  out_dims.push_back(x_dims_vec[x_rank - 2]);
  out_dims.push_back(x_dims_vec[x_rank - 1]);

  context->SetOutputDim("Out", phi::make_ddim(out_dims));
}
protected:
framework
::
OpKernelType
GetExpectedKernelType
(
...
...
@@ -151,22 +108,15 @@ class CholeskySolveGradOp : public framework::OperatorWithKernel {
}
// namespace operators
}
// namespace paddle
namespace
ops
=
paddle
::
operators
;
DECLARE_INFER_SHAPE_FUNCTOR
(
cholesky_solve
,
CholeskySolveInferShapeFunctor
,
PD_INFER_META
(
phi
::
CholeskySolveInferMeta
));
REGISTER_OPERATOR
(
cholesky_solve
,
ops
::
CholeskySolveOp
,
ops
::
CholeskySolveOpMaker
,
ops
::
CholeskySolveOpVarTypeInference
,
ops
::
CholeskySolveOpGradMaker
<
paddle
::
framework
::
OpDesc
>
,
ops
::
CholeskySolveOpGradMaker
<
paddle
::
imperative
::
OpBase
>
);
ops
::
CholeskySolveOpGradMaker
<
paddle
::
imperative
::
OpBase
>
,
CholeskySolveInferShapeFunctor
);
REGISTER_OPERATOR
(
cholesky_solve_grad
,
ops
::
CholeskySolveGradOp
);
REGISTER_OP_CPU_KERNEL
(
cholesky_solve
,
ops
::
CholeskySolveKernel
<
paddle
::
platform
::
CPUDeviceContext
,
float
>
,
ops
::
CholeskySolveKernel
<
paddle
::
platform
::
CPUDeviceContext
,
double
>
);
REGISTER_OP_CPU_KERNEL
(
cholesky_solve_grad
,
ops
::
CholeskySolveGradKernel
<
paddle
::
platform
::
CPUDeviceContext
,
float
>
,
ops
::
CholeskySolveGradKernel
<
paddle
::
platform
::
CPUDeviceContext
,
double
>
);
// Complex<> is not supported because of TensorExpand, which is used to
// broadcast the input Tensor
paddle/fluid/operators/cholesky_solve_op.cu
已删除
100644 → 0
浏览文件 @
dc773828
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#ifndef PADDLE_WITH_HIP
// HIP not support cusolver
#include "paddle/fluid/memory/memory.h"
#include "paddle/fluid/operators/cholesky_solve_op.h"
#include "paddle/fluid/platform/dynload/cusolver.h"
#include "paddle/fluid/platform/enforce.h"
namespace
paddle
{
namespace
operators
{
using
Tensor
=
framework
::
Tensor
;
using
CUDADeviceContext
=
paddle
::
platform
::
CUDADeviceContext
;
// Type-dispatched wrapper around the cusolverDn<t>potrs family. Only the
// declaration lives here; each supported element type provides an explicit
// specialization.
template <typename T>
void cusolver_potrs(const cusolverDnHandle_t &cusolverH,
                    cublasFillMode_t uplo,
                    int n,
                    int nrhs,
                    T *Adata,
                    int lda,
                    T *Bdata,
                    int ldb,
                    int *devInfo);
// float specialization: forwards to cusolverDnSpotrs and raises on a non-success
// status.
template <>
void cusolver_potrs<float>(const cusolverDnHandle_t &cusolverH,
                           cublasFillMode_t uplo,
                           int n,
                           int nrhs,
                           float *Adata,
                           int lda,
                           float *Bdata,
                           int ldb,
                           int *devInfo) {
  PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::cusolverDnSpotrs(
      cusolverH, uplo, n, nrhs, Adata, lda, Bdata, ldb, devInfo));
}
// double specialization: forwards to cusolverDnDpotrs and raises on a
// non-success status.
template <>
void cusolver_potrs<double>(const cusolverDnHandle_t &cusolverH,
                            cublasFillMode_t uplo,
                            int n,
                            int nrhs,
                            double *Adata,
                            int lda,
                            double *Bdata,
                            int ldb,
                            int *devInfo) {
  PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::cusolverDnDpotrs(
      cusolverH, uplo, n, nrhs, Adata, lda, Bdata, ldb, devInfo));
}
// complex<float> specialization: reinterprets the Paddle complex buffers as
// cuComplex and forwards to cusolverDnCpotrs.
template <>
void cusolver_potrs<platform::complex<float>>(
    const cusolverDnHandle_t &cusolverH,
    cublasFillMode_t uplo,
    int n,
    int nrhs,
    platform::complex<float> *Adata,
    int lda,
    platform::complex<float> *Bdata,
    int ldb,
    int *devInfo) {
  PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::cusolverDnCpotrs(
      cusolverH,
      uplo,
      n,
      nrhs,
      reinterpret_cast<const cuComplex *>(Adata),
      lda,
      reinterpret_cast<cuComplex *>(Bdata),
      ldb,
      devInfo));
}
// complex<double> specialization: reinterprets the Paddle complex buffers as
// cuDoubleComplex and forwards to cusolverDnZpotrs.
template <>
void cusolver_potrs<platform::complex<double>>(
    const cusolverDnHandle_t &cusolverH,
    cublasFillMode_t uplo,
    int n,
    int nrhs,
    platform::complex<double> *Adata,
    int lda,
    platform::complex<double> *Bdata,
    int ldb,
    int *devInfo) {
  PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::cusolverDnZpotrs(
      cusolverH,
      uplo,
      n,
      nrhs,
      reinterpret_cast<const cuDoubleComplex *>(Adata),
      lda,
      reinterpret_cast<cuDoubleComplex *>(Bdata),
      ldb,
      devInfo));
}
// CUDA specialization of CholeskySolveFunctor: runs one potrs solve through
// cuSOLVER and then synchronizes the device.
template <typename T>
class CholeskySolveFunctor<paddle::platform::CUDADeviceContext, T> {
 public:
  void operator()(const platform::CUDADeviceContext &dev_ctx,
                  bool upper,
                  int n,
                  int nrhs,
                  T *Adata,
                  int lda,
                  T *Bdata,
                  int *devInfo) {
    // Translate the boolean flag into the cuBLAS fill mode.
    cublasFillMode_t fill_mode = CUBLAS_FILL_MODE_LOWER;
    if (upper) {
      fill_mode = CUBLAS_FILL_MODE_UPPER;
    }

    // The solver handle is owned by the device context.
    auto handle = dev_ctx.cusolver_dn_handle();

    // Solve A * X = B; lda is also passed as the leading dimension of B.
    cusolver_potrs<T>(
        handle, fill_mode, n, nrhs, Adata, lda, Bdata, lda, devInfo);

    PADDLE_ENFORCE_GPU_SUCCESS(cudaDeviceSynchronize());
  }
};
// CUDA specialization of MatrixReduceSumFunctor: sum-reduces `in` over the
// batch axes that were broadcast relative to `out`, writing the result into
// `out`. The last two (matrix) axes are never reduced.
template <typename T>
class MatrixReduceSumFunctor<platform::CUDADeviceContext, T> {
 public:
  void operator()(const Tensor &in,
                  Tensor *out,
                  const framework::ExecutionContext &ctx) {
    // For example: in's dim = [5, 3, 2, 7, 3] ; out's dim = [3, 1, 7, 3]
    // out_reduce_dim should be [0, 2]
    const std::vector<std::int64_t> in_dims = phi::vectorize(in.dims());
    const auto in_size = in_dims.size();
    const std::vector<std::int64_t> out_dims = phi::vectorize(out->dims());
    const auto out_size = out_dims.size();

    // Left-pad out's shape with 1s so it has the same rank as in.
    std::vector<std::int64_t> out_bst_dims(in_size);
    std::fill(out_bst_dims.data(),
              out_bst_dims.data() + in_size - out_size,
              1);
    std::copy(out_dims.data(),
              out_dims.data() + out_size,
              out_bst_dims.data() + in_size - out_size);

    // Collect the batch axes where in is expanded but out is 1.
    // `idx + 2 < in_size` matches the former `idx <= in_size - 3` for
    // rank >= 3 but does not wrap around when the unsigned rank is < 3.
    std::vector<int> out_reduce_dims;
    for (size_t idx = 0; idx + 2 < in_size; ++idx) {
      if (in_dims[idx] != 1 && out_bst_dims[idx] == 1) {
        out_reduce_dims.push_back(static_cast<int>(idx));
      }
    }

    gpuStream_t stream = ctx.cuda_device_context().stream();
    TensorReduceImpl<T, T, kps::AddFunctor, kps::IdentityFunctor<T>>(
        ctx.cuda_device_context(),
        in,
        out,
        kps::IdentityFunctor<T>(),
        out_reduce_dims,
        stream);
  }
};
}
// namespace operators
}
// namespace paddle
namespace ops = paddle::operators;

// CUDA kernels for the forward op (float and double only).
REGISTER_OP_CUDA_KERNEL(
    cholesky_solve,
    ops::CholeskySolveKernel<paddle::platform::CUDADeviceContext, float>,
    ops::CholeskySolveKernel<paddle::platform::CUDADeviceContext, double>);

// CUDA kernels for the backward op (float and double only).
REGISTER_OP_CUDA_KERNEL(
    cholesky_solve_grad,
    ops::CholeskySolveGradKernel<paddle::platform::CUDADeviceContext, float>,
    ops::CholeskySolveGradKernel<paddle::platform::CUDADeviceContext, double>);
#endif // not PADDLE_WITH_HIP
paddle/fluid/operators/cholesky_solve_op.h
已删除
100644 → 0
浏览文件 @
dc773828
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/solve_op.h"
#include "paddle/fluid/operators/triangular_solve_op.h"
#include "paddle/fluid/platform/complex.h"
#include "paddle/phi/kernels/funcs/lapack/lapack_function.h"
#include "paddle/phi/kernels/math_kernel.h"
#include "paddle/phi/kernels/transpose_kernel.h"
namespace
paddle
{
namespace
operators
{
// namespace operators
// Primary template of the per-matrix Cholesky solve functor. It only declares
// the call operator; device-specific specializations supply the definition.
template <typename DeviceContext, typename T>
class CholeskySolveFunctor {
 public:
  void operator()(const platform::DeviceContext &dev_ctx,
                  bool upper,
                  int n,
                  int nrhs,
                  T *Adata,
                  int lda,
                  T *Bdata,
                  int *devInfo);
};
// CPU specialization of CholeskySolveFunctor: delegates one solve to
// phi::funcs::lapackCholeskySolve.
template <typename T>
class CholeskySolveFunctor<paddle::platform::CPUDeviceContext, T> {
 public:
  void operator()(const platform::CPUDeviceContext &dev_ctx,
                  bool upper,
                  int n,
                  int nrhs,
                  T *Adata,
                  int lda,
                  T *Bdata,
                  int *devInfo) {
    // 'U' selects the upper triangle, 'L' the lower one.
    char triangle = 'L';
    if (upper) {
      triangle = 'U';
    }
    // lda is also passed as the leading dimension of B.
    phi::funcs::lapackCholeskySolve<T>(
        triangle, n, nrhs, Adata, lda, Bdata, lda, devInfo);
  }
};
// Batched Cholesky solve shared by the forward and backward kernels.
//
// Steps, as performed below:
//   1. Broadcast `uin` and `bin` to a common batch shape.
//   2. Conjugate each operand and transpose its last two dimensions
//      (presumably to obtain the column-major layout the solver expects --
//      TODO confirm).
//   3. Copy the prepared right-hand side into `out` and invoke
//      CholeskySolveFunctor once per batch entry on that buffer.
//   4. Conjugate `out` and transpose its last two dimensions back.
//
// ctx    execution context providing the device context.
// uin    factor tensor (the kernels pass input "Y").
// bin    right-hand-side tensor (the kernels pass input "X").
// out    receives the solution.
// upper  forwarded to the solver to select the triangle of `uin`.
template <typename DeviceContext, typename T>
void cholesky_solve_fn(const paddle::framework::ExecutionContext &ctx,
                       const framework::Tensor &uin,
                       const framework::Tensor &bin,
                       framework::Tensor *out,
                       bool upper) {
  const auto &dev_ctx = ctx.template device_context<DeviceContext>();

  // framework::Tensor broadcast
  std::vector<int64_t> u_bst_dims_vec;
  std::vector<int64_t> b_bst_dims_vec;
  std::tie(u_bst_dims_vec, b_bst_dims_vec) = get_broadcast_dims(uin, bin);
  framework::Tensor u_bst(uin.type());
  TensorExpand<T, DeviceContext>(dev_ctx, uin, &u_bst, u_bst_dims_vec);

  framework::Tensor b_bst(bin.type());
  TensorExpand<T, DeviceContext>(dev_ctx, bin, &b_bst, b_bst_dims_vec);

  auto &phi_dev_ctx = static_cast<
      const typename framework::ConvertToPhiContext<DeviceContext>::TYPE &>(
      dev_ctx);

  // calculate u's conjugate for complex
  framework::Tensor u_conj(u_bst.type());
  platform::ForRange<DeviceContext> u_for_range(dev_ctx, u_bst.numel());
  phi::funcs::ConjFunctor<T> u_functor(
      u_bst.data<T>(),
      u_bst.numel(),
      u_conj.mutable_data<T>(u_bst.dims(), dev_ctx.GetPlace()));
  u_for_range(u_functor);
  u_conj = phi::TransposeLast2Dim<T>(phi_dev_ctx, u_conj);

  // calculate b's conjugate for complex
  framework::Tensor b_conj(b_bst.type());
  platform::ForRange<DeviceContext> b_for_range(dev_ctx, b_bst.numel());
  phi::funcs::ConjFunctor<T> b_functor(
      b_bst.data<T>(),
      b_bst.numel(),
      b_conj.mutable_data<T>(b_bst.dims(), dev_ctx.GetPlace()));
  b_for_range(b_functor);
  b_conj = phi::TransposeLast2Dim<T>(phi_dev_ctx, b_conj);

  auto ut_data = u_conj.mutable_data<T>(dev_ctx.GetPlace());
  auto uindims = u_bst.dims();
  auto bindims = b_bst.dims();
  int uinrank = uindims.size();
  int binrank = bindims.size();

  int n = uindims[uinrank - 2];
  int nrhs = bindims[binrank - 1];
  int ldab = std::max(1, n);

  // The solver works in place on the right-hand side, so seed `out` with it.
  framework::TensorCopy(b_conj, dev_ctx.GetPlace(), out);
  T *out_data = out->mutable_data<T>(dev_ctx.GetPlace());

  // One status integer per batch entry.
  auto info_dims = phi::slice_ddim(bindims, 0, binrank - 2);
  auto batchsize = product(info_dims);
  framework::Tensor tmp;
  // Keep the batch count 64-bit instead of narrowing it to int.
  std::vector<int64_t> tmpdim(1, batchsize);
  tmp.Resize(phi::make_ddim(tmpdim));
  int *info = tmp.mutable_data<int>(dev_ctx.GetPlace());

  // Per-batch element strides are computed in 64 bits: the former
  // `b * n * n` / `b * n * nrhs` int products could overflow for large
  // batched inputs.
  const int64_t u_stride = static_cast<int64_t>(n) * n;
  const int64_t out_stride = static_cast<int64_t>(n) * nrhs;

  CholeskySolveFunctor<DeviceContext, T> functor;
  for (int64_t b = 0; b < batchsize; ++b) {
    auto uin_data_item = ut_data + b * u_stride;
    auto out_data_item = out_data + b * out_stride;
    auto info_item = info + b;
    functor(dev_ctx,
            upper,
            n,
            nrhs,
            uin_data_item,
            ldab,
            out_data_item,
            info_item);
  }

  // calculate out's conjugate for complex
  platform::ForRange<DeviceContext> out_for_range(dev_ctx, out->numel());
  phi::funcs::ConjFunctor<T> out_functor(
      out->data<T>(),
      out->numel(),
      out->mutable_data<T>(out->dims(), dev_ctx.GetPlace()));
  out_for_range(out_functor);
  *out = phi::TransposeLast2Dim<T>(phi_dev_ctx, *out);
}
// Forward fluid kernel for cholesky_solve: fetches the operands from the
// execution context and hands them to cholesky_solve_fn.
template <typename DeviceContext, typename T>
class CholeskySolveKernel : public framework::OpKernel<T> {
 public:
  void Compute(const paddle::framework::ExecutionContext &ctx) const override {
    // Input "Y" is passed as the factor, input "X" as the right-hand side.
    auto *factor = ctx.Input<framework::Tensor>("Y");
    auto *rhs = ctx.Input<framework::Tensor>("X");
    auto *result = ctx.Output<framework::Tensor>("Out");
    auto is_upper = ctx.Attr<bool>("upper");

    cholesky_solve_fn<DeviceContext, T>(ctx, *factor, *rhs, result, is_upper);
  }
};
// Backward fluid kernel for cholesky_solve. When the gradient of Out is
// present it fills the gradients of X and Y; when it is absent the kernel
// only performs the input broadcasts and returns.
template <typename DeviceContext, typename T>
class CholeskySolveGradKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext &ctx) const override {
    auto *bin = ctx.Input<framework::Tensor>("X");
    auto *uin = ctx.Input<framework::Tensor>("Y");
    auto *out = ctx.Input<framework::Tensor>("Out");
    auto *dout = ctx.Input<framework::Tensor>(framework::GradVarName("Out"));
    auto *db = ctx.Output<framework::Tensor>(framework::GradVarName("X"));
    auto *du = ctx.Output<framework::Tensor>(framework::GradVarName("Y"));
    auto upper = ctx.Attr<bool>("upper");

    const auto &dev_ctx = ctx.template device_context<DeviceContext>();
    auto &phi_dev_ctx = static_cast<
        const typename framework::ConvertToPhiContext<DeviceContext>::TYPE &>(
        dev_ctx);

    // Broadcast both inputs to their common batch shape.
    std::vector<int64_t> u_bst_dims_vec;
    std::vector<int64_t> b_bst_dims_vec;
    std::tie(u_bst_dims_vec, b_bst_dims_vec) = get_broadcast_dims(*uin, *bin);

    framework::Tensor u_bst(uin->type());
    TensorExpand<T, DeviceContext>(dev_ctx, *uin, &u_bst, u_bst_dims_vec);

    framework::Tensor db_bst(bin->type());
    TensorExpand<T, DeviceContext>(dev_ctx, *bin, &db_bst, b_bst_dims_vec);

    // Nothing more to do without an upstream gradient.
    if (!dout) {
      return;
    }

    // ---- gradient w.r.t. X: solve with dout as the right-hand side ----
    db->mutable_data<T>(dev_ctx.GetPlace());
    cholesky_solve_fn<DeviceContext, T>(ctx, u_bst, *dout, &db_bst, upper);

    if (db_bst.dims() == db->dims()) {
      framework::TensorCopy(db_bst, dev_ctx.GetPlace(), dev_ctx, db);
    } else {
      // X was broadcast: fold the expanded batch axes back by summation.
      MatrixReduceSumFunctor<DeviceContext, T> reduce_db;
      reduce_db(db_bst, db, ctx);
      db->Resize(bin->dims());
    }

    // ---- gradient w.r.t. Y ----
    auto blas = phi::funcs::GetBlas<DeviceContext, T>(ctx);

    // calculate out's conjugate for complex
    framework::Tensor out_conj(out->type());
    platform::ForRange<DeviceContext> out_for_range(dev_ctx, out->numel());
    phi::funcs::ConjFunctor<T> out_functor(
        out->data<T>(),
        out->numel(),
        out_conj.mutable_data<T>(out->dims(), dev_ctx.GetPlace()));
    out_for_range(out_functor);
    out_conj = phi::TransposeLast2Dim<T>(phi_dev_ctx, out_conj);

    // commonterm = db_bst * out_conj
    framework::Tensor commonterm(out->type());
    auto outdims = out_conj.dims();
    auto dbdims = db_bst.dims();
    auto mat_dim_a = phi::funcs::CreateMatrixDescriptor(outdims, 0, false);
    auto mat_dim_b = phi::funcs::CreateMatrixDescriptor(dbdims, 0, false);
    auto cmtdim = outdims;
    cmtdim[cmtdim.size() - 2] = dbdims[dbdims.size() - 2];
    commonterm.Resize(cmtdim);
    commonterm.mutable_data<T>(dev_ctx.GetPlace());
    blas.MatMul(db_bst,
                mat_dim_b,
                out_conj,
                mat_dim_a,
                static_cast<T>(1),
                &commonterm,
                static_cast<T>(0));

    // calculate commonterm's conjugate for complex
    framework::Tensor commonterm_conj(commonterm.type());
    platform::ForRange<DeviceContext> commonterm_for_range(
        dev_ctx, commonterm.numel());
    phi::funcs::ConjFunctor<T> commonterm_functor(
        commonterm.data<T>(),
        commonterm.numel(),
        commonterm_conj.mutable_data<T>(commonterm.dims(),
                                        dev_ctx.GetPlace()));
    commonterm_for_range(commonterm_functor);
    commonterm_conj = phi::TransposeLast2Dim<T>(phi_dev_ctx, commonterm_conj);

    // commonterm += commonterm_conj (elementwise add with axis -1).
    phi::AddRawKernel<T>(
        phi_dev_ctx, commonterm, commonterm_conj, -1, &commonterm);

    auto mat_dim_u =
        phi::funcs::CreateMatrixDescriptor(u_bst.dims(), 0, false);
    auto mat_dim_c =
        phi::funcs::CreateMatrixDescriptor(commonterm.dims(), 0, false);

    Tensor du_bst(uin->type());
    // get upper or lower triangular
    du_bst.Resize(u_bst.dims());
    du_bst.mutable_data<T>(dev_ctx.GetPlace());

    // The operand order of the product depends on which triangle is used.
    if (upper) {
      blas.MatMul(u_bst,
                  mat_dim_u,
                  commonterm,
                  mat_dim_c,
                  static_cast<T>(-1),
                  &du_bst,
                  static_cast<T>(0));
    } else {
      blas.MatMul(commonterm,
                  mat_dim_c,
                  u_bst,
                  mat_dim_u,
                  static_cast<T>(-1),
                  &du_bst,
                  static_cast<T>(0));
    }

    // Triangular masking reads du_bst and writes into u_bst's buffer, which
    // therefore carries the Y gradient from here on.
    const auto &udims = u_bst.dims();
    const auto H = udims[udims.size() - 2];
    const auto W = udims[udims.size() - 1];
    platform::ForRange<DeviceContext> x_for_range(dev_ctx, u_bst.numel());
    TrilTriuCompute<T> tril_triu_computer(
        du_bst.data<T>(), 0, !upper, H, W, u_bst.data<T>());
    x_for_range(tril_triu_computer);

    du->mutable_data<T>(dev_ctx.GetPlace());
    if (u_bst.dims() == du->dims()) {
      framework::TensorCopy(u_bst, dev_ctx.GetPlace(), dev_ctx, du);
    } else {
      // Y was broadcast: fold the expanded batch axes back by summation.
      MatrixReduceSumFunctor<DeviceContext, T> reduce_du;
      reduce_du(u_bst, du, ctx);
      du->Resize(uin->dims());
    }
  }
};
}
// namespace operators
}
// namespace paddle
paddle/fluid/operators/triangular_solve_op.h
浏览文件 @
e24ca55e
...
...
@@ -60,45 +60,5 @@ static void triangular_solve(const DeviceContext &context, const Tensor &x,
unitriangular
);
}
// Primary template of the broadcast-undoing sum reduction. Only the call
// operator is declared; device-specific specializations define it.
template <typename DeviceContext, typename T>
class MatrixReduceSumFunctor {
 public:
  void operator()(const Tensor &input,
                  Tensor *output,
                  const framework::ExecutionContext &ctx);
};
// CPU specialization of MatrixReduceSumFunctor: sum-reduces `in` over the
// batch axes that were broadcast relative to `out`. The last two (matrix)
// axes are never reduced. `out` keeps its original shape on return.
template <typename T>
class MatrixReduceSumFunctor<platform::CPUDeviceContext, T> {
 public:
  void operator()(const Tensor &in,
                  Tensor *out,
                  const framework::ExecutionContext &ctx) {
    // For example: in's dim = [5, 3, 2, 7, 3] ; out's dim = [3, 1, 7, 3]
    // out_reduce_dim should be [0, 2]
    const std::vector<std::int64_t> in_dims = phi::vectorize(in.dims());
    const auto in_size = in_dims.size();
    const std::vector<std::int64_t> out_dims = phi::vectorize(out->dims());
    const auto out_size = out_dims.size();

    // Left-pad out's shape with 1s so it has the same rank as in.
    std::vector<std::int64_t> out_bst_dims(in_size);
    std::fill(out_bst_dims.data(),
              out_bst_dims.data() + in_size - out_size,
              1);
    std::copy(out_dims.data(),
              out_dims.data() + out_size,
              out_bst_dims.data() + in_size - out_size);
    // The reduction below runs with the padded, same-rank shape.
    out->Resize(phi::make_ddim(out_bst_dims));

    // Collect the batch axes where in is expanded but out is 1.
    // `idx + 2 < in_size` matches the former `idx <= in_size - 3` for
    // rank >= 3 but does not wrap around when the unsigned rank is < 3.
    std::vector<int> out_reduce_dims;
    for (size_t idx = 0; idx + 2 < in_size; ++idx) {
      if (in_dims[idx] != 1 && out_bst_dims[idx] == 1) {
        out_reduce_dims.push_back(static_cast<int>(idx));
      }
    }

    ReduceKernelFunctor<platform::CPUDeviceContext, T, SumFunctor>(
        &in, out, out_reduce_dims, true, false, ctx)
        .template apply<T>();

    // Restore the caller-visible shape.
    out->Resize(phi::make_ddim(out_dims));
  }
};
}
// namespace operators
}
// namespace paddle
paddle/fluid/platform/dynload/CMakeLists.txt
浏览文件 @
e24ca55e
...
...
@@ -46,8 +46,6 @@ if (WITH_MKLML)
cc_library
(
dynload_mklml SRCS mklml.cc DEPS dynamic_loader mklml phi_dynload_mklml
)
endif
()
cc_library
(
dynload_lapack SRCS lapack.cc DEPS dynamic_loader phi_dynload_lapack
)
add_dependencies
(
dynload_lapack extern_lapack
)
# TODO(TJ): add iomp, mkldnn?
if
(
MKL_FOUND AND WITH_ONEMKL
)
...
...
paddle/fluid/platform/dynload/lapack.cc
已删除
100644 → 0
浏览文件 @
dc773828
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/platform/dynload/lapack.h"
namespace paddle {
namespace platform {
namespace dynload {

// Defines one wrapper object per routine listed in LAPACK_ROUTINE_EACH.
#define DEFINE_WRAP(__name) DynLoad__##__name __name

LAPACK_ROUTINE_EACH(DEFINE_WRAP);

}  // namespace dynload
}  // namespace platform
}  // namespace paddle
paddle/fluid/platform/dynload/lapack.h
已删除
100644 → 0
浏览文件 @
dc773828
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <complex>
#include <mutex>
#include "paddle/phi/backends/dynload/lapack.h"
#include "paddle/phi/common/complex.h"
namespace paddle {
namespace platform {
namespace dynload {

/**
 * For every LAPACK routine listed below, alias the phi::dynload wrapper type
 * into this namespace and declare an extern wrapper object of that type.
 */
#define DYNAMIC_LOAD_LAPACK_WRAP(__name)                     \
  using DynLoad__##__name = phi::dynload::DynLoad__##__name; \
  extern DynLoad__##__name __name

#define DECLARE_DYNAMIC_LOAD_LAPACK_WRAP(__name) \
  DYNAMIC_LOAD_LAPACK_WRAP(__name)

// Applies __macro to each dynamically loaded LAPACK routine.
#define LAPACK_ROUTINE_EACH(__macro) \
  __macro(dgetrf_);                  \
  __macro(sgetrf_);                  \
  __macro(zheevd_);                  \
  __macro(cheevd_);                  \
  __macro(dsyevd_);                  \
  __macro(ssyevd_);                  \
  __macro(dgeev_);                   \
  __macro(sgeev_);                   \
  __macro(zgeev_);                   \
  __macro(cgeev_);                   \
  __macro(dgels_);                   \
  __macro(sgels_);                   \
  __macro(dgelsd_);                  \
  __macro(sgelsd_);                  \
  __macro(dgelsy_);                  \
  __macro(sgelsy_);                  \
  __macro(dgelss_);                  \
  __macro(sgelss_);                  \
  __macro(zpotrs_);                  \
  __macro(cpotrs_);                  \
  __macro(dpotrs_);                  \
  __macro(spotrs_);

LAPACK_ROUTINE_EACH(DECLARE_DYNAMIC_LOAD_LAPACK_WRAP);

#undef DYNAMIC_LOAD_LAPACK_WRAP

}  // namespace dynload
}  // namespace platform
}  // namespace paddle
paddle/phi/backends/dynload/lapack.h
浏览文件 @
e24ca55e
...
...
@@ -20,8 +20,8 @@ limitations under the License. */
#include "paddle/phi/backends/dynload/dynamic_loader.h"
#include "paddle/phi/backends/dynload/port.h"
//
Note(zhouwei): because lapack doesn't provide appropriate header file.
// should expose API statement yourself.
//
Because lapack doesn't provide appropriate header file,
//
we
should expose API statement yourself.
// getrf_(For example)
extern
"C"
void
dgetrf_
(
...
...
paddle/phi/infermeta/binary.cc
浏览文件 @
e24ca55e
...
...
@@ -274,6 +274,60 @@ void HuberLossInferMeta(const MetaTensor& input,
out
->
share_lod
(
input
);
}
// Infers the output meta for cholesky_solve.
//
// x      right-hand-side tensor; rank must be >= 2.
// y      factor tensor; rank must be >= 2 and its last two dims must match.
// upper  not used for shape inference.
// out    receives dims = broadcast batch dims of x and y followed by the last
//        two dims of x; dtype, layout and LoD are taken from x.
//
// Raises InvalidArgument when a rank or dimension check fails.
void CholeskySolveInferMeta(const MetaTensor& x,
                            const MetaTensor& y,
                            bool upper,
                            MetaTensor* out) {
  auto x_dims = x.dims();
  auto y_dims = y.dims();

  auto x_dims_n = x_dims.size();
  auto y_dims_n = y_dims.size();

  // Each message now names the input that is actually being checked (the
  // X and Y texts were previously swapped).
  PADDLE_ENFORCE_GE(x_dims_n,
                    2,
                    phi::errors::InvalidArgument(
                        "the rank of input X must be greater than or equal "
                        "to 2"));
  PADDLE_ENFORCE_GE(y_dims_n,
                    2,
                    phi::errors::InvalidArgument(
                        "the rank of input Y must be greater than or equal "
                        "to 2"));
  PADDLE_ENFORCE_EQ(
      y_dims[y_dims_n - 1],
      y_dims[y_dims_n - 2],
      phi::errors::InvalidArgument(
          "input Matrix Y should be square matrix, "
          "But Got last shape of %ld x %ld",
          y_dims[y_dims_n - 1],
          y_dims[y_dims_n - 2]));
  PADDLE_ENFORCE_EQ(
      x_dims[x_dims_n - 2],
      y_dims[y_dims_n - 2],
      phi::errors::InvalidArgument(
          "the first dim of Matrix X must be equal to "
          "the first dim of Matrix Y, "
          "But Got %ld and %ld",
          x_dims[x_dims_n - 2],
          y_dims[y_dims_n - 2]));

  // Strip the trailing matrix dims to obtain the batch portions.
  std::vector<int64_t> x_dims_vec = phi::vectorize(x_dims);
  std::vector<int64_t> y_dims_vec = phi::vectorize(y_dims);

  std::vector<int64_t> x_dims_vec_cut(x_dims_vec.begin(),
                                      x_dims_vec.end() - 2);
  std::vector<int64_t> y_dims_vec_cut(y_dims_vec.begin(),
                                      y_dims_vec.end() - 2);

  std::vector<int64_t> expand_batch_portion =
      funcs::MatrixGetBroadcastBatchPortion(x_dims_vec_cut, y_dims_vec_cut);

  std::vector<int64_t> x_broadcast_dims({expand_batch_portion});
  x_broadcast_dims.insert(
      x_broadcast_dims.end(),
      {x_dims_vec[x_dims_n - 2], x_dims_vec[x_dims_n - 1]});

  // dim of 'out' is the same as 'X' after broadcast
  out->set_dims(phi::make_ddim(x_broadcast_dims));
  out->set_dtype(x.dtype());
  out->set_layout(x.layout());
  out->share_lod(x);
}
void
TriangularSolveInferMeta
(
const
MetaTensor
&
x
,
const
MetaTensor
&
y
,
bool
upper
,
...
...
paddle/phi/infermeta/binary.h
浏览文件 @
e24ca55e
...
...
@@ -62,6 +62,11 @@ void HuberLossInferMeta(const MetaTensor& input_meta,
MetaTensor
*
residual
,
MetaConfig
config
=
MetaConfig
());
// Shape/dtype inference for cholesky_solve; writes the inferred meta to `out`.
void CholeskySolveInferMeta(const MetaTensor& x,
                            const MetaTensor& y,
                            bool upper,
                            MetaTensor* out);
void
TriangularSolveInferMeta
(
const
MetaTensor
&
x
,
const
MetaTensor
&
y
,
bool
upper
,
...
...
paddle/phi/kernels/cholesky_solve_grad_kernel.h
0 → 100644
浏览文件 @
e24ca55e
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/phi/core/dense_tensor.h"
namespace phi {

// Backward kernel declaration for cholesky_solve.
//
// dev_ctx  device context the kernel runs on.
// x, y     forward inputs.
// out      forward output.
// dout     gradient flowing into `out`.
// upper    forward attribute of the same name.
// dx, dy   receive the gradients of x and y.
template <typename T, typename Context>
void CholeskySolveGradKernel(const Context& dev_ctx,
                             const DenseTensor& x,
                             const DenseTensor& y,
                             const DenseTensor& out,
                             const DenseTensor& dout,
                             bool upper,
                             DenseTensor* dx,
                             DenseTensor* dy);

}  // namespace phi
paddle/phi/kernels/cholesky_solve_kernel.h
0 → 100644
浏览文件 @
e24ca55e
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/phi/core/dense_tensor.h"
namespace
phi
{
/**
 * @brief Forward kernel of the cholesky_solve operator (declaration).
 *
 * @tparam T        Element type of the tensors.
 * @tparam Context  Device context type.
 *
 * @param dev_ctx  Device context the kernel runs on.
 * @param x        First input tensor.
 * @param y        Second input tensor.
 * @param upper    Selects the upper (true) or lower (false) triangle.
 * @param out      Output tensor that receives the result.
 */
template <typename T, typename Context>
void CholeskySolveKernel(const Context& dev_ctx,
                         const DenseTensor& x,
                         const DenseTensor& y,
                         bool upper,
                         DenseTensor* out);
}
// namespace phi
paddle/phi/kernels/cpu/cholesky_solve_grad_kernel.cc
0 → 100644
浏览文件 @
e24ca55e
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/impl/cholesky_solve_grad_kernel_impl.h"
#include "paddle/phi/backends/cpu/cpu_context.h"
#include "paddle/phi/core/kernel_registry.h"
// Registers phi::CholeskySolveGradKernel as the CPU `cholesky_solve_grad`
// kernel (ALL_LAYOUT) for the float and double element types.
PD_REGISTER_KERNEL(cholesky_solve_grad,
                   CPU,
                   ALL_LAYOUT,
                   phi::CholeskySolveGradKernel,
                   float,
                   double) {}
paddle/phi/kernels/cpu/cholesky_solve_kernel.cc
0 → 100644
浏览文件 @
e24ca55e
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/impl/cholesky_solve_kernel_impl.h"
#include "paddle/phi/backends/cpu/cpu_context.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/funcs/lapack/lapack_function.h"
namespace
phi
{
/**
 * @brief CPU specialization of CholeskySolveFunctor.
 *
 * Forwards a single solve to funcs::lapackCholeskySolve<T>.
 *
 * @param dev_ctx  CPU context (not used by this specialization).
 * @param upper    true selects 'U', false selects 'L' as the LAPACK uplo flag.
 * @param M, N     Forwarded as the second and third LAPACK arguments.
 * @param Adata    Pointer forwarded as the `a` matrix.
 * @param lda      Leading dimension; also forwarded as the `ldb` argument.
 * @param Bdata    Pointer forwarded as the `b` matrix.
 * @param devInfo  Receives the LAPACK info value.
 */
template <typename T>
class CholeskySolveFunctor<T, CPUContext> {
 public:
  void operator()(const CPUContext& dev_ctx,
                  bool upper,
                  int M,
                  int N,
                  T* Adata,
                  int lda,
                  T* Bdata,
                  int* devInfo) {
    // LAPACK triangle selector.
    char triangle = 'L';
    if (upper) {
      triangle = 'U';
    }
    // The same leading dimension is used for both matrices.
    funcs::lapackCholeskySolve<T>(
        triangle, M, N, Adata, lda, Bdata, lda, devInfo);
  }
};
}
// namespace phi
// Registers phi::CholeskySolveKernel as the CPU `cholesky_solve` kernel
// (ALL_LAYOUT) for the float and double element types.
PD_REGISTER_KERNEL(cholesky_solve,
                   CPU,
                   ALL_LAYOUT,
                   phi::CholeskySolveKernel,
                   float,
                   double) {}
paddle/phi/kernels/funcs/lapack/CMakeLists.txt
浏览文件 @
e24ca55e
math_library
(
lapack_function DEPS dynload_lapack
)
math_library
(
lapack_function DEPS
phi_
dynload_lapack
)
paddle/phi/kernels/funcs/lapack/lapack_function.cc
浏览文件 @
e24ca55e
...
...
@@ -13,7 +13,7 @@
// limitations under the License.
#include "paddle/phi/kernels/funcs/lapack/lapack_function.h"
#include "paddle/
fluid/platform
/dynload/lapack.h"
#include "paddle/
phi/backends
/dynload/lapack.h"
#include "paddle/phi/common/complex.h"
namespace
phi
{
...
...
@@ -22,12 +22,12 @@ namespace funcs {
// LU (for example)
template
<
>
void
lapackLu
<
double
>
(
int
m
,
int
n
,
double
*
a
,
int
lda
,
int
*
ipiv
,
int
*
info
)
{
paddle
::
platform
::
dynload
::
dgetrf_
(
&
m
,
&
n
,
a
,
&
lda
,
ipiv
,
info
);
dynload
::
dgetrf_
(
&
m
,
&
n
,
a
,
&
lda
,
ipiv
,
info
);
}
template
<
>
void
lapackLu
<
float
>
(
int
m
,
int
n
,
float
*
a
,
int
lda
,
int
*
ipiv
,
int
*
info
)
{
paddle
::
platform
::
dynload
::
sgetrf_
(
&
m
,
&
n
,
a
,
&
lda
,
ipiv
,
info
);
dynload
::
sgetrf_
(
&
m
,
&
n
,
a
,
&
lda
,
ipiv
,
info
);
}
// eigh
...
...
@@ -47,7 +47,7 @@ void lapackEigh<float>(char jobz,
int
*
info
)
{
(
void
)
rwork
;
// unused
(
void
)
lrwork
;
// unused
paddle
::
platform
::
dynload
::
ssyevd_
(
dynload
::
ssyevd_
(
&
jobz
,
&
uplo
,
&
n
,
a
,
&
lda
,
w
,
work
,
&
lwork
,
iwork
,
&
liwork
,
info
);
}
...
...
@@ -67,7 +67,7 @@ void lapackEigh<double>(char jobz,
int
*
info
)
{
(
void
)
rwork
;
// unused
(
void
)
lrwork
;
// unused
paddle
::
platform
::
dynload
::
dsyevd_
(
dynload
::
dsyevd_
(
&
jobz
,
&
uplo
,
&
n
,
a
,
&
lda
,
w
,
work
,
&
lwork
,
iwork
,
&
liwork
,
info
);
}
...
...
@@ -86,8 +86,7 @@ void lapackEigh<phi::dtype::complex<float>, float>(
int
*
iwork
,
int
liwork
,
int
*
info
)
{
paddle
::
platform
::
dynload
::
cheevd_
(
&
jobz
,
dynload
::
cheevd_
(
&
jobz
,
&
uplo
,
&
n
,
reinterpret_cast
<
std
::
complex
<
float
>
*>
(
a
),
...
...
@@ -117,8 +116,7 @@ void lapackEigh<phi::dtype::complex<double>, double>(
int
*
iwork
,
int
liwork
,
int
*
info
)
{
paddle
::
platform
::
dynload
::
zheevd_
(
&
jobz
,
dynload
::
zheevd_
(
&
jobz
,
&
uplo
,
&
n
,
reinterpret_cast
<
std
::
complex
<
double
>
*>
(
a
),
...
...
@@ -152,7 +150,7 @@ void lapackEig<double>(char jobvl,
double
*
wr
=
w
;
double
*
wi
=
w
+
n
;
(
void
)
rwork
;
// unused
paddle
::
platform
::
dynload
::
dgeev_
(
&
jobvl
,
dynload
::
dgeev_
(
&
jobvl
,
&
jobvr
,
&
n
,
a
,
...
...
@@ -186,7 +184,7 @@ void lapackEig<float>(char jobvl,
float
*
wr
=
w
;
float
*
wi
=
w
+
n
;
(
void
)
rwork
;
// unused
paddle
::
platform
::
dynload
::
sgeev_
(
&
jobvl
,
dynload
::
sgeev_
(
&
jobvl
,
&
jobvr
,
&
n
,
a
,
...
...
@@ -218,8 +216,7 @@ void lapackEig<phi::dtype::complex<double>, double>(
int
lwork
,
double
*
rwork
,
int
*
info
)
{
paddle
::
platform
::
dynload
::
zgeev_
(
&
jobvl
,
dynload
::
zgeev_
(
&
jobvl
,
&
jobvr
,
&
n
,
reinterpret_cast
<
std
::
complex
<
double
>
*>
(
a
),
...
...
@@ -251,8 +248,7 @@ void lapackEig<phi::dtype::complex<float>, float>(
int
lwork
,
float
*
rwork
,
int
*
info
)
{
paddle
::
platform
::
dynload
::
cgeev_
(
&
jobvl
,
dynload
::
cgeev_
(
&
jobvl
,
&
jobvr
,
&
n
,
reinterpret_cast
<
std
::
complex
<
float
>
*>
(
a
),
...
...
@@ -280,8 +276,7 @@ void lapackGels<double>(char trans,
double
*
work
,
int
lwork
,
int
*
info
)
{
paddle
::
platform
::
dynload
::
dgels_
(
&
trans
,
&
m
,
&
n
,
&
nrhs
,
a
,
&
lda
,
b
,
&
ldb
,
work
,
&
lwork
,
info
);
dynload
::
dgels_
(
&
trans
,
&
m
,
&
n
,
&
nrhs
,
a
,
&
lda
,
b
,
&
ldb
,
work
,
&
lwork
,
info
);
}
template
<
>
...
...
@@ -296,8 +291,7 @@ void lapackGels<float>(char trans,
float
*
work
,
int
lwork
,
int
*
info
)
{
paddle
::
platform
::
dynload
::
sgels_
(
&
trans
,
&
m
,
&
n
,
&
nrhs
,
a
,
&
lda
,
b
,
&
ldb
,
work
,
&
lwork
,
info
);
dynload
::
sgels_
(
&
trans
,
&
m
,
&
n
,
&
nrhs
,
a
,
&
lda
,
b
,
&
ldb
,
work
,
&
lwork
,
info
);
}
template
<
>
...
...
@@ -316,7 +310,7 @@ void lapackGelsd<double>(int m,
double
*
rwork
,
int
*
iwork
,
int
*
info
)
{
paddle
::
platform
::
dynload
::
dgelsd_
(
&
m
,
dynload
::
dgelsd_
(
&
m
,
&
n
,
&
nrhs
,
a
,
...
...
@@ -348,7 +342,7 @@ void lapackGelsd<float>(int m,
float
*
rwork
,
int
*
iwork
,
int
*
info
)
{
paddle
::
platform
::
dynload
::
sgelsd_
(
&
m
,
dynload
::
sgelsd_
(
&
m
,
&
n
,
&
nrhs
,
a
,
...
...
@@ -379,7 +373,7 @@ void lapackGelsy<double>(int m,
int
lwork
,
double
*
rwork
,
int
*
info
)
{
paddle
::
platform
::
dynload
::
dgelsy_
(
dynload
::
dgelsy_
(
&
m
,
&
n
,
&
nrhs
,
a
,
&
lda
,
b
,
&
ldb
,
jpvt
,
&
rcond
,
rank
,
work
,
&
lwork
,
info
);
}
...
...
@@ -398,7 +392,7 @@ void lapackGelsy<float>(int m,
int
lwork
,
float
*
rwork
,
int
*
info
)
{
paddle
::
platform
::
dynload
::
sgelsy_
(
dynload
::
sgelsy_
(
&
m
,
&
n
,
&
nrhs
,
a
,
&
lda
,
b
,
&
ldb
,
jpvt
,
&
rcond
,
rank
,
work
,
&
lwork
,
info
);
}
...
...
@@ -417,7 +411,7 @@ void lapackGelss<double>(int m,
int
lwork
,
double
*
rwork
,
int
*
info
)
{
paddle
::
platform
::
dynload
::
dgelss_
(
dynload
::
dgelss_
(
&
m
,
&
n
,
&
nrhs
,
a
,
&
lda
,
b
,
&
ldb
,
s
,
&
rcond
,
rank
,
work
,
&
lwork
,
info
);
}
...
...
@@ -436,7 +430,7 @@ void lapackGelss<float>(int m,
int
lwork
,
float
*
rwork
,
int
*
info
)
{
paddle
::
platform
::
dynload
::
sgelss_
(
dynload
::
sgelss_
(
&
m
,
&
n
,
&
nrhs
,
a
,
&
lda
,
b
,
&
ldb
,
s
,
&
rcond
,
rank
,
work
,
&
lwork
,
info
);
}
...
...
@@ -450,8 +444,7 @@ void lapackCholeskySolve<phi::dtype::complex<double>>(
phi
::
dtype
::
complex
<
double
>
*
b
,
int
ldb
,
int
*
info
)
{
paddle
::
platform
::
dynload
::
zpotrs_
(
&
uplo
,
dynload
::
zpotrs_
(
&
uplo
,
&
n
,
&
nrhs
,
reinterpret_cast
<
std
::
complex
<
double
>
*>
(
a
),
...
...
@@ -471,7 +464,7 @@ void lapackCholeskySolve<phi::dtype::complex<float>>(
phi
::
dtype
::
complex
<
float
>
*
b
,
int
ldb
,
int
*
info
)
{
paddle
::
platform
::
dynload
::
cpotrs_
(
&
uplo
,
dynload
::
cpotrs_
(
&
uplo
,
&
n
,
&
nrhs
,
reinterpret_cast
<
std
::
complex
<
float
>
*>
(
a
),
...
...
@@ -490,7 +483,7 @@ void lapackCholeskySolve<double>(char uplo,
double
*
b
,
int
ldb
,
int
*
info
)
{
paddle
::
platform
::
dynload
::
dpotrs_
(
&
uplo
,
&
n
,
&
nrhs
,
a
,
&
lda
,
b
,
&
ldb
,
info
);
dynload
::
dpotrs_
(
&
uplo
,
&
n
,
&
nrhs
,
a
,
&
lda
,
b
,
&
ldb
,
info
);
}
template
<
>
...
...
@@ -502,7 +495,7 @@ void lapackCholeskySolve<float>(char uplo,
float
*
b
,
int
ldb
,
int
*
info
)
{
paddle
::
platform
::
dynload
::
spotrs_
(
&
uplo
,
&
n
,
&
nrhs
,
a
,
&
lda
,
b
,
&
ldb
,
info
);
dynload
::
spotrs_
(
&
uplo
,
&
n
,
&
nrhs
,
a
,
&
lda
,
b
,
&
ldb
,
info
);
}
}
// namespace funcs
...
...
paddle/phi/kernels/gpu/cholesky_solve_grad_kernel.cu
0 → 100644
浏览文件 @
e24ca55e
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef PADDLE_WITH_HIP
// The backward kernel reuses the forward kernel, which is not supported on HIP.
#include "paddle/phi/kernels/impl/cholesky_solve_grad_kernel_impl.h"
#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/core/kernel_registry.h"
// Registers phi::CholeskySolveGradKernel as the GPU `cholesky_solve_grad`
// kernel (ALL_LAYOUT) for the float and double element types.
PD_REGISTER_KERNEL(cholesky_solve_grad,  // cuda_only
                   GPU,
                   ALL_LAYOUT,
                   phi::CholeskySolveGradKernel,
                   float,
                   double) {}
#endif // not PADDLE_WITH_HIP
paddle/phi/kernels/gpu/cholesky_solve_kernel.cu
0 → 100644
浏览文件 @
e24ca55e
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef PADDLE_WITH_HIP
// HIP does not support cusolver.
#include "paddle/phi/kernels/impl/cholesky_solve_kernel_impl.h"
#include "paddle/fluid/platform/enforce.h"
#include "paddle/phi/backends/dynload/cusolver.h"
#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/common/complex.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/funcs/lapack/lapack_function.h"
namespace
phi
{
/**
 * @brief Type-dispatched wrapper around the cuSOLVER `potrs` routines.
 *
 * Only the explicit specializations that follow provide a definition; this
 * primary template is declared but not defined.
 *
 * @param handle   cuSOLVER handle.
 * @param uplo     Fill mode forwarded to cuSOLVER.
 * @param M, N     Size arguments forwarded to cuSOLVER in this order.
 * @param Adata    Pointer to the `A` operand.
 * @param lda      Leading dimension of `A`.
 * @param Bdata    Pointer to the `B` operand.
 * @param ldb      Leading dimension of `B`.
 * @param devInfo  Pointer forwarded as cuSOLVER's info argument.
 */
template <typename T>
void cusolver_potrs(const solverHandle_t& handle,
                    cublasFillMode_t uplo,
                    int M,
                    int N,
                    T* Adata,
                    int lda,
                    T* Bdata,
                    int ldb,
                    int* devInfo);
// float specialization: forwards to cusolverDnSpotrs and checks the returned
// status with PADDLE_ENFORCE_GPU_SUCCESS.
template <>
void cusolver_potrs<float>(const solverHandle_t& handle,
                           cublasFillMode_t uplo,
                           int M,
                           int N,
                           float* Adata,
                           int lda,
                           float* Bdata,
                           int ldb,
                           int* devInfo) {
  PADDLE_ENFORCE_GPU_SUCCESS(dynload::cusolverDnSpotrs(
      handle, uplo, M, N, Adata, lda, Bdata, ldb, devInfo));
}
// double specialization: forwards to cusolverDnDpotrs and checks the returned
// status with PADDLE_ENFORCE_GPU_SUCCESS.
template <>
void cusolver_potrs<double>(const solverHandle_t& handle,
                            cublasFillMode_t uplo,
                            int M,
                            int N,
                            double* Adata,
                            int lda,
                            double* Bdata,
                            int ldb,
                            int* devInfo) {
  PADDLE_ENFORCE_GPU_SUCCESS(dynload::cusolverDnDpotrs(
      handle, uplo, M, N, Adata, lda, Bdata, ldb, devInfo));
}
// complex<float> specialization: reinterprets the phi complex pointers as
// cuComplex and forwards to cusolverDnCpotrs; the returned status is checked
// with PADDLE_ENFORCE_GPU_SUCCESS.
template <>
void cusolver_potrs<phi::dtype::complex<float>>(
    const solverHandle_t& handle,
    cublasFillMode_t uplo,
    int M,
    int N,
    phi::dtype::complex<float>* Adata,
    int lda,
    phi::dtype::complex<float>* Bdata,
    int ldb,
    int* devInfo) {
  // A is passed as a const pointer, B as a mutable one.
  const cuComplex* a_ptr = reinterpret_cast<const cuComplex*>(Adata);
  cuComplex* b_ptr = reinterpret_cast<cuComplex*>(Bdata);
  PADDLE_ENFORCE_GPU_SUCCESS(dynload::cusolverDnCpotrs(
      handle, uplo, M, N, a_ptr, lda, b_ptr, ldb, devInfo));
}
// complex<double> specialization: reinterprets the phi complex pointers as
// cuDoubleComplex and forwards to cusolverDnZpotrs; the returned status is
// checked with PADDLE_ENFORCE_GPU_SUCCESS.
//
// The handle parameter is spelled `solverHandle_t`, exactly as in the primary
// template and the other specializations, so that all overloads agree on the
// same type name.
template <>
void cusolver_potrs<phi::dtype::complex<double>>(
    const solverHandle_t& handle,
    cublasFillMode_t uplo,
    int M,
    int N,
    phi::dtype::complex<double>* Adata,
    int lda,
    phi::dtype::complex<double>* Bdata,
    int ldb,
    int* devInfo) {
  // A is passed as a const pointer, B as a mutable one.
  const cuDoubleComplex* a_ptr =
      reinterpret_cast<const cuDoubleComplex*>(Adata);
  cuDoubleComplex* b_ptr = reinterpret_cast<cuDoubleComplex*>(Bdata);
  PADDLE_ENFORCE_GPU_SUCCESS(dynload::cusolverDnZpotrs(
      handle, uplo, M, N, a_ptr, lda, b_ptr, ldb, devInfo));
}
template
<
typename
T
>
class
CholeskySolveFunctor
<
T
,
GPUContext
>
{
public:
void
operator
()(
const
GPUContext
&
dev_ctx
,
bool
upper
,
int
M
,
int
N
,
T
*
Adata
,
int
lda
,
T
*
Bdata
,
int
*
devInfo
)
{
cublasFillMode_t
uplo
=
upper
?
CUBLAS_FILL_MODE_UPPER
:
CUBLAS_FILL_MODE_LOWER
;
auto
handle
=
dev_ctx
.
cusolver_dn_handle
();
cusolver_potrs
<
T
>
(
handle
,
uplo
,
M
,
N
,
Adata
,
lda
,
Bdata
,
lda
,
devInfo
);
}
};
}
// namespace phi
// Registers phi::CholeskySolveKernel as the GPU `cholesky_solve` kernel
// (ALL_LAYOUT) for the float and double element types.
PD_REGISTER_KERNEL(cholesky_solve,  // cuda_only
                   GPU,
                   ALL_LAYOUT,
                   phi::CholeskySolveKernel,
                   float,
                   double) {}
#endif // not PADDLE_WITH_HIP
paddle/phi/kernels/impl/cholesky_solve_grad_kernel_impl.h
0 → 100644
浏览文件 @
e24ca55e
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/phi/kernels/cholesky_solve_grad_kernel.h"
#include "paddle/phi/kernels/cholesky_solve_kernel.h"
#include "paddle/phi/kernels/complex_kernel.h"
#include "paddle/phi/kernels/copy_kernel.h"
#include "paddle/phi/kernels/empty_kernel.h"
#include "paddle/phi/kernels/expand_kernel.h"
#include "paddle/phi/kernels/funcs/blas/blas.h"
#include "paddle/phi/kernels/funcs/common_shape.h"
#include "paddle/phi/kernels/funcs/complex_functors.h"
#include "paddle/phi/kernels/funcs/matrix_reduce.h"
#include "paddle/phi/kernels/math_kernel.h"
#include "paddle/phi/kernels/transpose_kernel.h"
// See Note [ Why still include the fluid headers? ]
#include "paddle/fluid/operators/tril_triu_op.h"
namespace
phi
{
template
<
typename
T
,
typename
Context
>
void
CholeskySolveGradKernel
(
const
Context
&
dev_ctx
,
const
DenseTensor
&
x
,
const
DenseTensor
&
y
,
const
DenseTensor
&
out
,
const
DenseTensor
&
dout
,
bool
upper
,
DenseTensor
*
dx
,
DenseTensor
*
dy
)
{
// get broadcast dim
std
::
vector
<
int64_t
>
x_bst_dims_vec
;
std
::
vector
<
int64_t
>
y_bst_dims_vec
;
std
::
tie
(
x_bst_dims_vec
,
y_bst_dims_vec
)
=
funcs
::
MatrixGetBroadcastDims
(
x
,
y
);
ScalarArray
x_bst_dims
(
x_bst_dims_vec
);
ScalarArray
y_bst_dims
(
y_bst_dims_vec
);
// Tensor broadcast to temp 'y_bst'
DenseTensor
y_bst
=
phi
::
Empty
<
T
,
Context
>
(
dev_ctx
,
y_bst_dims
);
ExpandKernel
<
T
,
Context
>
(
dev_ctx
,
y
,
y_bst_dims
,
&
y_bst
);
// reuse forward to calculate dx_bst, which is broad_cast of dx
DenseTensor
dx_bst
=
phi
::
Empty
<
T
,
Context
>
(
dev_ctx
,
x_bst_dims
);
CholeskySolveKernel
<
T
,
Context
>
(
dev_ctx
,
dout
,
y_bst
,
upper
,
&
dx_bst
);
// get 'dx' according to 'dx_bst'
dx
->
Resize
(
x
.
dims
());
dev_ctx
.
template
Alloc
<
T
>(
dx
);
if
(
dx_bst
.
dims
()
==
x
.
dims
())
{
Copy
<
Context
>
(
dev_ctx
,
dx_bst
,
dev_ctx
.
GetPlace
(),
false
,
dx
);
}
else
{
funcs
::
MatrixReduceSumFunctor
<
T
,
Context
>
functor
;
functor
(
dev_ctx
,
dx_bst
,
dx
);
dx
->
Resize
(
x
.
dims
());
}
// calculate out's conjugate for complex
DenseTensor
out_conj
=
Conj
<
T
,
Context
>
(
dev_ctx
,
out
);
out_conj
=
phi
::
TransposeLast2Dim
<
T
>
(
dev_ctx
,
out_conj
);
DenseTensor
commonterm
=
phi
::
Empty
<
T
,
Context
>
(
dev_ctx
,
y_bst_dims
);
auto
blas
=
phi
::
funcs
::
GetBlas
<
Context
,
T
>
(
dev_ctx
);
blas
.
MatMul
(
dx_bst
,
phi
::
funcs
::
CreateMatrixDescriptor
(
dx_bst
.
dims
(),
0
,
false
),
out_conj
,
phi
::
funcs
::
CreateMatrixDescriptor
(
out_conj
.
dims
(),
0
,
false
),
static_cast
<
T
>
(
1
),
&
commonterm
,
static_cast
<
T
>
(
0
));
// calculate commonterm's conjugate for complex
DenseTensor
commonterm_conj
=
Conj
<
T
,
Context
>
(
dev_ctx
,
commonterm
);
commonterm_conj
=
phi
::
TransposeLast2Dim
<
T
>
(
dev_ctx
,
commonterm_conj
);
phi
::
AddRawKernel
<
T
>
(
dev_ctx
,
commonterm
,
commonterm_conj
,
-
1
,
&
commonterm
);
DenseTensor
dy_bst
=
phi
::
Empty
<
T
,
Context
>
(
dev_ctx
,
y_bst_dims
);
if
(
upper
)
{
blas
.
MatMul
(
y_bst
,
phi
::
funcs
::
CreateMatrixDescriptor
(
y_bst
.
dims
(),
0
,
false
),
commonterm
,
phi
::
funcs
::
CreateMatrixDescriptor
(
commonterm
.
dims
(),
0
,
false
),
static_cast
<
T
>
(
-
1
),
&
dy_bst
,
static_cast
<
T
>
(
0
));
}
else
{
blas
.
MatMul
(
commonterm
,
phi
::
funcs
::
CreateMatrixDescriptor
(
commonterm
.
dims
(),
0
,
false
),
y_bst
,
phi
::
funcs
::
CreateMatrixDescriptor
(
y_bst
.
dims
(),
0
,
false
),
static_cast
<
T
>
(
-
1
),
&
dy_bst
,
static_cast
<
T
>
(
0
));
}
// get upper or lower of 'dy_bst'
DenseTensor
dy_bst_upper
=
phi
::
Empty
<
T
,
Context
>
(
dev_ctx
,
y_bst_dims
);
int
y_bst_ndim
=
y_bst_dims_vec
.
size
();
const
auto
H
=
y_bst_dims_vec
[
y_bst_ndim
-
2
];
const
auto
W
=
y_bst_dims_vec
[
y_bst_ndim
-
1
];
phi
::
funcs
::
ForRange
<
Context
>
y_for_range
(
dev_ctx
,
dy_bst
.
numel
());
paddle
::
operators
::
TrilTriuCompute
<
T
>
tril_triu_functor
(
dy_bst
.
data
<
T
>
(),
0
,
!
upper
,
H
,
W
,
dy_bst_upper
.
data
<
T
>
());
y_for_range
(
tril_triu_functor
);
// get 'dy' according to 'dy_bst'
dy
->
Resize
(
y
.
dims
());
dev_ctx
.
template
Alloc
<
T
>(
dy
);
if
(
dy_bst_upper
.
dims
()
==
y
.
dims
())
{
Copy
<
Context
>
(
dev_ctx
,
dy_bst_upper
,
dev_ctx
.
GetPlace
(),
false
,
dy
);
}
else
{
funcs
::
MatrixReduceSumFunctor
<
T
,
Context
>
functor
;
functor
(
dev_ctx
,
dy_bst_upper
,
dy
);
dy
->
Resize
(
y
.
dims
());
}
}
}
// namespace phi
paddle/phi/kernels/impl/cholesky_solve_kernel_impl.h
0 → 100644
浏览文件 @
e24ca55e
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/phi/kernels/cholesky_solve_kernel.h"
#include "paddle/phi/kernels/complex_kernel.h"
#include "paddle/phi/kernels/copy_kernel.h"
#include "paddle/phi/kernels/empty_kernel.h"
#include "paddle/phi/kernels/expand_kernel.h"
#include "paddle/phi/kernels/funcs/common_shape.h"
#include "paddle/phi/kernels/transpose_kernel.h"
namespace
phi
{
/**
 * @brief Per-device functor that performs one solve.
 *
 * This primary template only declares operator(); a definition must be
 * supplied for each device context type that is used.
 *
 * @param dev_ctx  Device context.
 * @param upper    Selects the upper (true) or lower (false) triangle.
 * @param M, N     Size arguments of the solve.
 * @param Adata    Pointer to the `A` operand.
 * @param lda      Leading dimension.
 * @param Bdata    Pointer to the `B` operand.
 * @param devInfo  Receives the solver's info value.
 */
template <typename T, typename Context>
class CholeskySolveFunctor {
 public:
  void operator()(const Context& dev_ctx,
                  bool upper,
                  int M,
                  int N,
                  T* Adata,
                  int lda,
                  T* Bdata,
                  int* devInfo);
};
/**
 * @brief Forward kernel of cholesky_solve.
 *
 * Steps, in order:
 *   1. Broadcast `x` and `y` to their broadcast shapes.
 *   2. Conjugate both and swap their last two dimensions.
 *   3. Call CholeskySolveFunctor once per batch entry. The transformed `y`
 *      is passed as `A`, and a copy of the transformed `x` is passed as `B`.
 *   4. Swap the last two dimensions of the result back and conjugate it
 *      into `out`, which gets the broadcast shape of `x`.
 *
 * The per-batch info values written by the functor are stored in a
 * temporary tensor and are not inspected here.
 *
 * @param dev_ctx  Device context.
 * @param x        First input tensor.
 * @param y        Second input tensor.
 * @param upper    Selects the upper (true) or lower (false) triangle.
 * @param out      Output tensor.
 */
template <typename T, typename Context>
void CholeskySolveKernel(const Context& dev_ctx,
                         const DenseTensor& x,
                         const DenseTensor& y,
                         bool upper,
                         DenseTensor* out) {
  // get broadcast dim
  std::vector<int64_t> x_bst_dims_vec;
  std::vector<int64_t> y_bst_dims_vec;
  std::tie(x_bst_dims_vec, y_bst_dims_vec) =
      funcs::MatrixGetBroadcastDims(x, y);
  ScalarArray x_bst_dims(x_bst_dims_vec);
  ScalarArray y_bst_dims(y_bst_dims_vec);

  // Tensor broadcast to temp 'y_bst'
  DenseTensor y_bst = phi::Empty<T, Context>(dev_ctx, y_bst_dims);
  ExpandKernel<T, Context>(dev_ctx, y, y_bst_dims, &y_bst);

  // Tensor broadcast to temp 'x_bst'
  DenseTensor x_bst = phi::Empty<T, Context>(dev_ctx, x_bst_dims);
  ExpandKernel<T, Context>(dev_ctx, x, x_bst_dims, &x_bst);

  // calculate y_bst's conjugate for complex
  DenseTensor y_bst_conj = Conj<T, Context>(dev_ctx, y_bst);
  y_bst_conj = phi::TransposeLast2Dim<T>(dev_ctx, y_bst_conj);
  T* y_bst_conj_data = y_bst_conj.data<T>();

  // calculate x_bst's conjugate for complex
  DenseTensor x_bst_conj = Conj<T, Context>(dev_ctx, x_bst);
  x_bst_conj = phi::TransposeLast2Dim<T>(dev_ctx, x_bst_conj);

  // copy x_bst's conjugate to 'result'
  DenseTensor result;
  Copy<Context>(dev_ctx, x_bst_conj, dev_ctx.GetPlace(), false, &result);
  T* res_data = result.data<T>();

  // CPU use lapack, GPU use cusolver
  int x_bst_ndim = x_bst_dims_vec.size();
  int M = static_cast<int>(x_bst_dims_vec[x_bst_ndim - 2]);
  int N = static_cast<int>(x_bst_dims_vec[x_bst_ndim - 1]);
  int batchsize = product(phi::slice_ddim(x_bst.dims(), 0, x_bst_ndim - 2));

  DenseTensor info =
      phi::Empty<int, Context>(dev_ctx, ScalarArray({batchsize}));
  int* info_data = info.data<int>();

  // Per-batch element strides are kept in 64 bits: the products
  // i * M * M and i * M * N can exceed the range of int for large inputs.
  const int64_t a_stride = static_cast<int64_t>(M) * M;
  const int64_t b_stride = static_cast<int64_t>(M) * N;
  const int lda = std::max(1, M);

  CholeskySolveFunctor<T, Context> functor;
  for (int i = 0; i < batchsize; ++i) {
    functor(dev_ctx,
            upper,
            M,
            N,
            y_bst_conj_data + i * a_stride,
            lda,
            res_data + i * b_stride,
            info_data + i);
  }

  // calculate out's conjugate for complex
  result = phi::TransposeLast2Dim<T>(dev_ctx, result);
  out->Resize(phi::make_ddim(x_bst_dims_vec));
  ConjKernel<T, Context>(dev_ctx, result, out);
}
}
// namespace phi
paddle/phi/ops/compat/cholesky_solve_sig.cc
0 → 100644
浏览文件 @
e24ca55e
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/core/compat/op_utils.h"
namespace
phi
{
// Maps the fluid `cholesky_solve_grad` op onto the phi kernel of the same
// name:
//   inputs     : X, Y, Out, Out@GRAD (via GradVarName("Out"))
//   attributes : upper
//   outputs    : X@GRAD, Y@GRAD (via GradVarName)
// `ctx` is not consulted; the mapping is unconditional.
KernelSignature CholeskySolveGradOpArgumentMapping(
    const ArgumentMappingContext& ctx) {
  return KernelSignature("cholesky_solve_grad",
                         {"X", "Y", "Out", GradVarName("Out")},
                         {"upper"},
                         {GradVarName("X"), GradVarName("Y")});
}
}
// namespace phi
// Registers the argument-mapping function for the `cholesky_solve_grad` op.
PD_REGISTER_ARG_MAPPING_FN(cholesky_solve_grad,
                           phi::CholeskySolveGradOpArgumentMapping);
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录