Commit b91e8eec (unverified)
Authored Sep 24, 2021 by jiangcheng; committed via GitHub on Sep 24, 2021.

add gradient kernel of det op and slogdet op (#36013)

* add gradient kernel of det op and slogdet op
* fix CI APPROVAL problem

Parent: 787273ed
4 changed files with 266 additions and 75 deletions (+266 -75):

paddle/fluid/operators/determinant_op.cc                      +7    -4
paddle/fluid/operators/determinant_op.cu                      +0    -36
paddle/fluid/operators/determinant_op.h                       +246  -16
python/paddle/fluid/tests/unittests/test_determinant_op.py    +13   -19
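With this change the det and slogdet operators gain working backward kernels in place of the previous "unimplemented" placeholders. As a rough illustration of what that enables, a minimal dygraph sketch (not part of the diff; it assumes the public paddle.linalg.det / paddle.linalg.slogdet APIs that dispatch to these operators):

    # Hypothetical usage sketch; assumes paddle.linalg.det / slogdet dispatch to the
    # determinant / slogdeterminant operators touched by this commit.
    import paddle

    x = paddle.rand([3, 5, 5], dtype='float64')
    x.stop_gradient = False

    d = paddle.linalg.det(x)       # batched determinant, shape [3]
    d.sum().backward()             # runs DeterminantGradKernel instead of raising
    print(x.grad.shape)            # [3, 5, 5]

    sl = paddle.linalg.slogdet(x)  # shape [2, 3]: stacked sign and log|det|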
paddle/fluid/operators/determinant_op.cc  (+7 -4)

@@ -48,6 +48,8 @@ class DeterminantGradOp : public framework::OperatorWithKernel {
     OP_INOUT_CHECK(ctx->HasInput("Input"), "Input", "Input",
                    "DeterminantGradOp");
     OP_INOUT_CHECK(ctx->HasInput("Out"), "Input", "Out", "DeterminantGradOp");
+    OP_INOUT_CHECK(ctx->HasInput(framework::GradVarName("Out")), "Input",
+                   framework::GradVarName("Out"), "DeterminantGradOp");
     OP_INOUT_CHECK(ctx->HasOutput(framework::GradVarName("Input")), "Output",
                    framework::GradVarName("Input"), "DeterminantGradOp");

@@ -117,7 +119,8 @@ class SlogDeterminantGradOp : public framework::OperatorWithKernel {
                    "SlogDeterminantGradOp");
     OP_INOUT_CHECK(ctx->HasInput("Out"), "Input", "Out",
                    "SlogDeterminantGradOp");
+    OP_INOUT_CHECK(ctx->HasInput(framework::GradVarName("Out")), "Input",
+                   framework::GradVarName("Out"), "SlogDeterminantGradOp");
     OP_INOUT_CHECK(ctx->HasOutput(framework::GradVarName("Input")), "Output",
                    framework::GradVarName("Input"), "SlogDeterminantGradOp");

@@ -179,7 +182,7 @@ REGISTER_OPERATOR(slogdeterminant, ops::SlogDeterminantOp,
                   ops::SlogDeterminantGradOpMaker<paddle::imperative::OpBase>);

 REGISTER_OPERATOR(slogdeterminant_grad,
-                  ops::DeterminantGradOp)  // reuse det grad op
+                  ops::SlogDeterminantGradOp)  // reuse det grad op

 REGISTER_OP_CPU_KERNEL(
     slogdeterminant, ops::SlogDeterminantKernel<plat::CPUDeviceContext, float>,

@@ -187,5 +190,5 @@ REGISTER_OP_CPU_KERNEL(
 REGISTER_OP_CPU_KERNEL(
     slogdeterminant_grad,
-    ops::DeterminantGradKernel<plat::CPUDeviceContext, float>,
-    ops::DeterminantGradKernel<plat::CPUDeviceContext, double>);
+    ops::SlogDeterminantGradKernel<plat::CPUDeviceContext, float>,
+    ops::SlogDeterminantGradKernel<plat::CPUDeviceContext, double>);
paddle/fluid/operators/determinant_op.cu  (+0 -36)

@@ -14,42 +14,6 @@ limitations under the License. */

 #include "paddle/fluid/framework/op_registry.h"
 #include "paddle/fluid/operators/determinant_op.h"
-#include "paddle/fluid/platform/cuda_primitives.h"
-
-namespace paddle {
-namespace operators {
-
-using platform::PADDLE_CUDA_NUM_THREADS;
-using Tensor = framework::Tensor;
-
-template <typename T>
-__global__ void DeterminantGrad(const size_t numel, T* out) {
-  int tid = threadIdx.x + blockIdx.x * blockDim.x;
-  if (tid < numel) {
-    out[tid] = static_cast<T>(1);
-  }
-}
-
-template <typename T>
-class DeterminantGradCUDAKernel : public framework::OpKernel<T> {
- public:
-  void Compute(const framework::ExecutionContext& context) const override {
-    const auto* dout = context.Input<Tensor>(framework::GradVarName("Out"));
-    const T* dout_data = dout->data<T>();
-    auto dout_dim = vectorize(dout->dims());
-
-    auto* dx = context.Output<Tensor>(framework::GradVarName("Input"));
-    T* dx_data = dx->mutable_data<T>(context.GetPlace());
-
-    int64_t numel = dx->numel();
-    for (int64_t idx = 0; idx < numel; idx++) {
-      dx_data[idx] = static_cast<T>(1);
-    }
-  }
-};
-
-}  // namespace operators
-}  // namespace paddle
-
 namespace ops = paddle::operators;
 namespace plat = paddle::platform;
paddle/fluid/operators/determinant_op.h  (+246 -16)

@@ -19,7 +19,11 @@
 #include <cmath>
 #include <vector>
 #include "paddle/fluid/framework/op_registry.h"
+#include "paddle/fluid/operators/math/complex_functors.h"
+#include "paddle/fluid/operators/math/matrix_inverse.h"
+#include "paddle/fluid/operators/svd_helper.h"
 #include "paddle/fluid/platform/enforce.h"
+#include "paddle/fluid/platform/for_range.h"

 namespace paddle {
 namespace operators {
@@ -48,11 +52,10 @@ class EigenMatrix<double> {
 inline int64_t GetBatchCount(const framework::DDim dims) {
   int64_t batch_count = 1;
   auto dim_size = dims.size();
-  PADDLE_ENFORCE_GT(dim_size, 2,
-                    platform::errors::InvalidArgument(
-                        "To get the number of batch square matrices, "
-                        "the size of dimension should greater than 2.",
-                        dim_size));
+  PADDLE_ENFORCE_GE(
+      dim_size, 2,
+      platform::errors::InvalidArgument(
+          "the input matrix dimension size should greater than 2."));

   // Cumulative multiplying each dimension until the last 2 to get the batch
   // count,
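GetBatchCount treats everything except the trailing two dimensions as batch dimensions and multiplies them together. A one-line NumPy equivalent, for illustration only (not part of the diff):

    import numpy as np

    # Batch count of a stacked-matrix tensor: product of all leading dimensions.
    shape = (3, 3, 3, 5, 5)                 # e.g. the new TestDeterminantOp case
    batch_count = int(np.prod(shape[:-2]))  # 27 independent 5x5 matrices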
@@ -77,7 +80,7 @@ struct DeterminantFunctor {
       auto end_iter = input_vec.begin() + (i + 1) * rank * rank;
       std::vector<T> sub_vec(begin_iter,
                              end_iter);  // get every square matrix data
-      Eigen::MatrixXf matrix(rank, rank);
+      typename EigenMatrix<T>::MatrixType matrix(rank, rank);
       for (int64_t i = 0; i < rank; ++i) {
         for (int64_t j = 0; j < rank; ++j) {
           matrix(i, j) = sub_vec[rank * i + j];
@@ -109,41 +112,169 @@ class DeterminantKernel : public framework::OpKernel<T> {
                           "the input matrix should be square matrix."));
     auto rank = input_dim[input_dim_size - 1];  // square matrix length
     DeterminantFunctor<T>()(*input, context, rank, batch_count, output);
-    auto output_dims =
-        framework::slice_ddim(input->dims(), 0, input_dim_size - 2);
     if (input_dim_size > 2) {
+      auto output_dims =
+          framework::slice_ddim(input->dims(), 0, input_dim_size - 2);
       output->Resize(output_dims);
+    } else {
+      // when input is a two-dimension matrix, The det value is a number.
+      output->Resize({1});
     }
     VLOG(2) << "output dim:" << output->dims();
   }
 };
+
+template <typename T>
+struct FoundZeroFunctor {
+  FoundZeroFunctor(const T* x, int64_t numel, bool* res)
+      : x_(x), numel_(numel), res_(res) {}
+  HOSTDEVICE void operator()(size_t idx) const {
+    if (*res_ || idx >= static_cast<size_t>(numel_)) {
+      // founded zero number
+      return;
+    }
+    *res_ = (x_[idx] == static_cast<T>(0));
+  }
+  const T* x_;
+  int64_t numel_;
+  bool* res_;
+};
+
+template <typename DeviceContext, typename T>
+inline bool CheckMatrixInvertible(const framework::ExecutionContext& ctx,
+                                  const framework::Tensor* det) {
+  auto& dev_ctx = ctx.template device_context<DeviceContext>();
+  auto numel = det->numel();
+
+  framework::Tensor dev_tensor;
+  auto* data = dev_tensor.mutable_data<bool>({1}, ctx.GetPlace());
+
+  // set false
+  math::SetConstant<DeviceContext, bool> zero;
+  zero(dev_ctx, &dev_tensor, false);
+
+  // find whether zero
+  platform::ForRange<DeviceContext> for_range(dev_ctx, numel);
+  FoundZeroFunctor<T> functor(det->data<T>(), numel, data);
+  for_range(functor);
+
+  // copy to host
+  dev_ctx.Wait();
+  framework::Tensor cpu_tensor;
+  framework::TensorCopy(dev_tensor, platform::CPUPlace(), &cpu_tensor);
+
+  // if founded zero, the matrix is not invertible
+  // else the matrix is invertible
+  auto* res = cpu_tensor.data<bool>();
+  return !(*res);
+}
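CheckMatrixInvertible simply scans the forward determinant values for an exact zero and treats the batch as non-invertible when one is found. The NumPy analogue of that test (sketch for illustration only):

    import numpy as np

    dets = np.linalg.det(np.random.rand(4, 3, 3))  # forward det values of a batch
    invertible = bool(np.all(dets != 0))           # False if any determinant is exactly 0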

 template <typename DeviceContext, typename T>
 class DeterminantGradKernel : public framework::OpKernel<T> {
  public:
   void Compute(const framework::ExecutionContext& context) const override {
-    PADDLE_THROW(platform::errors::Unimplemented(
-        "Not support DeterminantGrad at this time."));
+    auto& dev_ctx = context.template device_context<DeviceContext>();
+    const auto* input = context.Input<framework::Tensor>("Input");
+    const auto* det = context.Input<framework::Tensor>("Out");
+    const auto* grad =
+        context.Input<framework::Tensor>(framework::GradVarName("Out"));
+    auto* ddet =
+        context.Output<framework::Tensor>(framework::GradVarName("Input"));
+
+    auto input_dims_size = input->dims().size();
+    if (input_dims_size > 2) {
+      PADDLE_ENFORCE_EQ(
+          grad->dims().size() + 2, input_dims_size,
+          platform::errors::InvalidArgument(
+              "The grad tensor of det dims size should 2 less than"
+              " input tensor's, but here differ %d",
+              input_dims_size - grad->dims().size()));
+    } else if (input_dims_size == 2) {
+      // input dims size 2 and grad dims size 1 is possible
+      PADDLE_ENFORCE_EQ(
+          grad->dims().size(), 1,
+          platform::errors::InvalidArgument(
+              "The grad tensor of det dims size should 2 less than"
+              " input tensor's, but here differ %d",
+              input_dims_size - grad->dims().size()));
+    } else {
+      // checked in forward, pass
+    }
+
+    // Check Whether the matrix is invertible
+    // (matrix A not invertible) == (det(A)=0)
+    if (!CheckMatrixInvertible<DeviceContext, T>(context, det)) {
+      // The matrix is not invertible
+      VLOG(3) << "The input matrix not invertible!";
+      ddet->Resize(input->dims());
+      ddet->mutable_data<T>(context.GetPlace());
+      math::SetConstant<DeviceContext, T> zero;
+      zero(dev_ctx, ddet, static_cast<T>(0.0f));
+      return;
+    }
+
+    // The matrix is invertible
+    // let |A| = Determinant(A)
+    // Ref to https://people.maths.ox.ac.uk/gilesm/files/NA-08-01.pdf
+    // we set d|A| = unsqueeze(dA * |A|, [-1, -2]) * inverse(A).transpose(-2,
+    // -1)
+
+    math::DeviceIndependenceTensorOperations<DeviceContext, T> helper(context);
+
+    // First: inverse(A)
+    framework::Tensor inverse_A;
+    // A must be square matrices!
+    inverse_A.Resize(input->dims());
+    inverse_A.mutable_data<T>(context.GetPlace());
+
+    math::MatrixInverseFunctor<DeviceContext, T> mat_inv;
+    mat_inv(dev_ctx, *input, &inverse_A);
+
+    VLOG(3) << "inverse(A) dims: " << inverse_A.dims();
+
+    // Second: inverse(A).transpose(-2, -1)
+    framework::Tensor transpose_inverse_A = helper.Transpose(inverse_A);
+    VLOG(3) << "(dA * |A|).transpose(-2, -1) dims: "
+            << transpose_inverse_A.dims();
+
+    // Third: dA * |A|
+    auto mul_dA_detA = helper.Mul(*grad, *det);
+    VLOG(3) << "dA * |A| dims: " << mul_dA_detA.dims();
+
+    // Fourth: unsqueeze(dA * |A|, [-1, -2])
+    auto unsqueeze1 = helper.Unsqueeze(mul_dA_detA, -1);
+    auto unsqueeze2 = helper.Unsqueeze(unsqueeze1, -2);
+    VLOG(3) << "unsqueezed(dA * |A|) dims: " << unsqueeze2.dims();
+
+    // Finally: unsqueeze(dA * |A|) * inverse(A)
+    auto res = helper.Mul(unsqueeze2, transpose_inverse_A);
+    VLOG(3) << "unsqueeze(dA * |A|) * inverse(A) dims: " << res.dims();
+
+    framework::TensorCopy(res, context.GetPlace(), ddet);
+
+    ddet->Resize(input->dims());
+    VLOG(3) << "d|A| dims: " << ddet->dims();
   }
 };
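The comment block in the kernel encodes the matrix-calculus identity d det(A)/dA = det(A) * inverse(A)^T (see the Giles reference cited above). A small NumPy check of that formula against finite differences, for illustration only (not part of the diff):

    # Illustrative check of the formula DeterminantGradKernel implements:
    # d det(A) / dA = det(A) * inv(A).T for an invertible square A.
    import numpy as np

    rng = np.random.default_rng(0)
    A = rng.standard_normal((4, 4))

    analytic = np.linalg.det(A) * np.linalg.inv(A).T

    eps = 1e-6
    numeric = np.zeros_like(A)
    for i in range(4):
        for j in range(4):
            E = np.zeros_like(A)
            E[i, j] = eps
            numeric[i, j] = (np.linalg.det(A + E) - np.linalg.det(A - E)) / (2 * eps)

    assert np.allclose(analytic, numeric, atol=1e-5)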
 template <typename T>
 struct SlogDeterminantFunctor {
   void operator()(const Tensor& input, const framework::ExecutionContext ctx,
-                  int rank, int batch_count, Tensor* output) {
+                  int64_t rank, int64_t batch_count, Tensor* output) {
     std::vector<T> input_vec;
     std::vector<T> sign_vec;
     std::vector<T> log_vec;
     std::vector<T> output_vec;
     framework::TensorToVector(input, ctx.device_context(), &input_vec);
-    for (int i = 0; i < batch_count; ++i) {  // maybe can be parallel
+    for (int64_t i = 0; i < batch_count; ++i) {  // maybe can be parallel
       auto begin_iter = input_vec.begin() + i * rank * rank;
       auto end_iter = input_vec.begin() + (i + 1) * rank * rank;
       std::vector<T> sub_vec(begin_iter,
                              end_iter);  // get every square matrix data
       typename EigenMatrix<T>::MatrixType matrix(rank, rank);
-      for (int i = 0; i < rank; ++i) {
-        for (int j = 0; j < rank; ++j) {
+      for (int64_t i = 0; i < rank; ++i) {
+        for (int64_t j = 0; j < rank; ++j) {
           matrix(i, j) = sub_vec[rank * i + j];
         }
       }
@@ -185,6 +316,10 @@ class SlogDeterminantKernel : public framework::OpKernel<T> {
     auto rank = input_dim[input_dim_size - 1];  // square matrix length
     SlogDeterminantFunctor<T>()(*input, context, rank, batch_count, output);
     std::vector<int> output_dim_vec(input_dim.begin(), input_dim.end() - 2);
+    if (input_dim.size() == static_cast<size_t>(2)) {
+      // when input is a two-dimension matrix, The det value is a number.
+      output_dim_vec = {1};
+    }
     output_dim_vec.insert(output_dim_vec.begin(),
                           2);  // make the output dims as same as numpy
     auto output_dims = framework::make_ddim(output_dim_vec);
@@ -197,8 +332,103 @@ template <typename DeviceContext, typename T>
 class SlogDeterminantGradKernel : public framework::OpKernel<T> {
  public:
   void Compute(const framework::ExecutionContext& context) const override {
-    PADDLE_THROW(platform::errors::Unimplemented(
-        "Not support SlogDeterminantGrad at this time."));
+    auto& dev_ctx = context.template device_context<DeviceContext>();
+    const auto* input = context.Input<framework::Tensor>("Input");
+    const auto* slogdet = context.Input<framework::Tensor>("Out");
+    const auto* grad =
+        context.Input<framework::Tensor>(framework::GradVarName("Out"));
+    auto* dslogdet =
+        context.Output<framework::Tensor>(framework::GradVarName("Input"));
+
+    PADDLE_ENFORCE_EQ(grad->dims()[0], 2,
+                      platform::errors::InvalidArgument(
+                          "The grad tensor of SlogDet should contain two"
+                          " grad: sign and absslogdet, but here %ld.",
+                          grad->dims()[0]));
+    if (input->dims().size() > 2) {
+      PADDLE_ENFORCE_EQ(
+          grad->dims().size() + 1, input->dims().size(),
+          platform::errors::InvalidArgument(
+              "The grad tensor of slogdet dims size should 1 less than"
+              " input tensor's, but here differ %d",
+              input->dims().size() - grad->dims().size()));
+    }
+
+    // Check Whether the matrix is invertible
+    // (matrix A not invertible) == (absslogdet(A)=0)
+    auto slogdet_vec = slogdet->Split(1, 0);
+    auto absslogdet_val = slogdet_vec[0];
+    if (!CheckMatrixInvertible<DeviceContext, T>(context, &absslogdet_val)) {
+      // The matrix is not invertible
+      VLOG(3) << "The input matrix not invertible!";
+      dslogdet->Resize(input->dims());
+      dslogdet->mutable_data<T>(context.GetPlace());
+      math::SetConstant<DeviceContext, T> zero;
+      zero(dev_ctx, dslogdet, std::numeric_limits<T>::quiet_NaN());
+      return;
+    }
+
+    // The matrix is invertible
+    // let sl|A| = SlogDeterminant(A)
+    // Ref to https://people.maths.ox.ac.uk/gilesm/files/NA-08-01.pdf
+    // we set dsl|A| = unsqueeze(dslA, [-1, -2]) *
+    // inverse(A).conj().transpose(-2, -1)
+
+    math::DeviceIndependenceTensorOperations<DeviceContext, T> helper(context);
+
+    // First: inverse(A)
+    framework::Tensor inverse_A;
+    // A must be square matrices!
+    inverse_A.Resize(input->dims());
+    inverse_A.mutable_data<T>(context.GetPlace());
+
+    math::MatrixInverseFunctor<DeviceContext, T> mat_inv;
+    mat_inv(dev_ctx, *input, &inverse_A);
+
+    VLOG(3) << "inverse(A) dims: " << inverse_A.dims();
+
+    // Second: inverse(A).conj()
+    framework::Tensor conj_inverse_A;
+    conj_inverse_A.Resize(inverse_A.dims());
+    auto numel = input->numel();
+    auto* conj_data = conj_inverse_A.mutable_data<T>(context.GetPlace(),
+                                                     size_t(numel * sizeof(T)));
+
+    platform::ForRange<DeviceContext> for_range(dev_ctx, numel);
+    math::ConjFunctor<T> functor(inverse_A.data<T>(), numel, conj_data);
+    for_range(functor);
+
+    VLOG(3) << "inverse(A).conj() dims: " << conj_inverse_A.dims();
+
+    // Third: inverse(A).conj().transpose(-2, -1)
+    framework::Tensor transpose_inverse_A = helper.Transpose(conj_inverse_A);
+    VLOG(3) << "inverse(A).conj().transpose(-2, -1) dims: "
+            << transpose_inverse_A.dims();
+
+    // Fourth: split grad value to [sign_grad, absslogdet_grad]
+    auto grad_vec = grad->Split(1, 0);
+    auto det_grad = grad_vec[1];
+
+    // remmove useless first dimension
+    int det_grad_size = det_grad.dims().size();
+    std::vector<int> det_grad_vec;
+    for (int i = 1; i < det_grad_size; ++i) {
+      det_grad_vec.emplace_back(det_grad.dims()[i]);
+    }
+    det_grad.Resize(det_grad.dims().reshape(det_grad_vec));
+
+    // Fifth: unsqueeze(dslA, [-1, -2])
+    auto unsqueeze1 = helper.Unsqueeze(det_grad, -1);
+    auto unsqueeze2 = helper.Unsqueeze(unsqueeze1, -2);
+    VLOG(3) << "unsqueezed(dslA, [-1, -2]) dims: " << unsqueeze2.dims();
+
+    // Finally: unsqueeze(dslA) * inverse(A)
+    auto res = helper.Mul(unsqueeze2, transpose_inverse_A);
+    VLOG(3) << "unsqueeze(dslA) * inverse(A) dims: " << res.dims();
+
+    framework::TensorCopy(res, context.GetPlace(), dslogdet);
+    dslogdet->Resize(input->dims());
+    VLOG(3) << "dsl|A| dims: " << dslogdet->dims();
   }
 };
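The slogdet backward splits the incoming gradient into [sign_grad, absslogdet_grad] and propagates only the log|det| part, using d log|det(A)|/dA = inverse(A).conj().T (the sign output is piecewise constant and carries no gradient). A NumPy check of that identity for a real matrix, illustrative only:

    # Illustrative check of the identity SlogDeterminantGradKernel relies on:
    # d log|det(A)| / dA = inv(A).conj().T
    import numpy as np

    def logabsdet(M):
        return np.linalg.slogdet(M)[1]   # log|det(M)|

    rng = np.random.default_rng(0)
    A = rng.standard_normal((4, 4))

    analytic = np.linalg.inv(A).conj().T

    eps = 1e-6
    numeric = np.zeros_like(A)
    for i in range(4):
        for j in range(4):
            E = np.zeros_like(A)
            E[i, j] = eps
            numeric[i, j] = (logabsdet(A + E) - logabsdet(A - E)) / (2 * eps)

    assert np.allclose(analytic, numeric, atol=1e-5)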
python/paddle/fluid/tests/unittests/test_determinant_op.py  (+13 -19)

@@ -16,7 +16,7 @@ from __future__ import print_function

 import unittest
 import numpy as np
-from op_test import OpTest, skip_check_grad_ci
+from op_test import OpTest
 import paddle
 import paddle.nn.functional as F
 import paddle.fluid as fluid

@@ -26,7 +26,6 @@ import paddle.tensor as tensor
 paddle.enable_static()


-@skip_check_grad_ci(reason="determinant grad is in progress.")
 class TestDeterminantOp(OpTest):
     def setUp(self):
         self.init_data()

@@ -37,11 +36,11 @@ class TestDeterminantOp(OpTest):
         self.check_output()

     def test_check_grad(self):
-        pass
+        self.check_grad(['Input'], ['Out'])

     def init_data(self):
         np.random.seed(0)
-        self.case = np.random.rand(3, 3, 3, 3, 3).astype('float64')
+        self.case = np.random.rand(3, 3, 3, 5, 5).astype('float64')
         self.inputs = {'Input': self.case}
         self.target = np.linalg.det(self.case)

@@ -49,30 +48,25 @@ class TestDeterminantOp(OpTest):
 class TestDeterminantOpCase1(TestDeterminantOp):
     def init_data(self):
         np.random.seed(0)
-        self.case = np.random.rand(3, 3, 3, 3).astype(np.float32)
+        self.case = np.random.rand(10, 10).astype('float32')
         self.inputs = {'Input': self.case}
         self.target = np.linalg.det(self.case)

-    def test_check_grad(self):
-        pass
-

 class TestDeterminantOpCase2(TestDeterminantOp):
     def init_data(self):
         np.random.seed(0)
-        self.case = np.random.rand(4, 2, 4, 4).astype('float64')
+        # not invertible matrix
+        self.case = np.ones([4, 2, 4, 4]).astype('float64')
         self.inputs = {'Input': self.case}
         self.target = np.linalg.det(self.case)

-    def test_check_grad(self):
-        pass
-

 class TestDeterminantAPI(unittest.TestCase):
     def setUp(self):
-        self.shape = [3, 3, 3, 3]
         np.random.seed(0)
-        self.x = np.random.rand(3, 3, 3, 3).astype(np.float32)
+        self.shape = [3, 3, 5, 5]
+        self.x = np.random.random(self.shape).astype(np.float32)
         self.place = paddle.CPUPlace()

     def test_api_static(self):

@@ -96,7 +90,6 @@ class TestDeterminantAPI(unittest.TestCase):
 paddle.enable_static()


-@skip_check_grad_ci(reason="slogdeterminant grad is in progress.")
 class TestSlogDeterminantOp(OpTest):
     def setUp(self):
         self.op_type = "slogdeterminant"

@@ -107,11 +100,12 @@ class TestSlogDeterminantOp(OpTest):
         self.check_output()

     def test_check_grad(self):
-        pass
+        # the slog det's grad value is always huge
+        self.check_grad(['Input'], ['Out'], max_relative_error=0.1)

     def init_data(self):
         np.random.seed(0)
-        self.case = np.random.rand(3, 3, 3, 3).astype('float64')
+        self.case = np.random.rand(4, 5, 5).astype('float64')
         self.inputs = {'Input': self.case}
         self.target = np.array(np.linalg.slogdet(self.case))

@@ -126,9 +120,9 @@ class TestSlogDeterminantOpCase1(TestSlogDeterminantOp):
 class TestSlogDeterminantAPI(unittest.TestCase):
     def setUp(self):
-        self.shape = [3, 3, 3, 3]
         np.random.seed(0)
-        self.x = np.random.rand(3, 3, 3, 3).astype(np.float32)
+        self.shape = [3, 3, 5, 5]
+        self.x = np.random.random(self.shape).astype(np.float32)
         self.place = paddle.CPUPlace()

     def test_api_static(self):
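Note that TestDeterminantOpCase2 now feeds np.ones([4, 2, 4, 4]), a deliberately singular batch, so the gradient check also exercises the non-invertible branch of DeterminantGradKernel (which fills the gradient with zeros instead of calling the matrix inverse). A quick NumPy illustration of why that input hits this branch (illustrative only):

    import numpy as np

    case = np.ones([4, 2, 4, 4]).astype('float64')  # the new TestDeterminantOpCase2 input
    print(np.linalg.det(case))                      # all zeros: an all-ones matrix is singular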