Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
Paddle
提交
9ce45ddd
P
Paddle
项目概览
PaddlePaddle
/
Paddle
大约 1 年 前同步成功
通知
2299
Star
20931
Fork
5422
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1423
列表
看板
标记
里程碑
合并请求
543
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1,423
Issue
1,423
列表
看板
标记
里程碑
合并请求
543
合并请求
543
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
9ce45ddd
编写于
9月 22, 2021
作者:
H
huangxu96
提交者:
GitHub
9月 22, 2021
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Det &Slogdet (#34992)
Add new API : paddle.linalg.det & paddle.linalg.slogdet API Alias:paddle.det& paddle.slogdet
上级
00e0e358
变更
7
显示空白变更内容
内联
并排
Showing
7 changed file
with
734 addition
and
1 deletion
+734
-1
paddle/fluid/operators/determinant_op.cc
paddle/fluid/operators/determinant_op.cc
+191
-0
paddle/fluid/operators/determinant_op.cu
paddle/fluid/operators/determinant_op.cu
+72
-0
paddle/fluid/operators/determinant_op.h
paddle/fluid/operators/determinant_op.h
+206
-0
python/paddle/__init__.py
python/paddle/__init__.py
+2
-0
python/paddle/fluid/tests/unittests/test_determinant_op.py
python/paddle/fluid/tests/unittests/test_determinant_op.py
+155
-0
python/paddle/linalg.py
python/paddle/linalg.py
+4
-0
python/paddle/tensor/linalg.py
python/paddle/tensor/linalg.py
+104
-1
未找到文件。
paddle/fluid/operators/determinant_op.cc
0 → 100644
浏览文件 @
9ce45ddd
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/operators/determinant_op.h"
namespace
paddle
{
namespace
operators
{
class
DeterminantOp
:
public
framework
::
OperatorWithKernel
{
public:
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
OP_INOUT_CHECK
(
ctx
->
HasInput
(
"Input"
),
"Input"
,
"Input"
,
"determinant"
);
OP_INOUT_CHECK
(
ctx
->
HasOutput
(
"Out"
),
"Output"
,
"Out"
,
"determinant"
);
}
};
class
DeterminantOpMaker
:
public
framework
::
OpProtoAndCheckerMaker
{
public:
void
Make
()
override
{
AddInput
(
"Input"
,
"(Tensor) The input tensor of determinant."
);
AddOutput
(
"Out"
,
"(Tensor) The output Tensor containing the determinant"
"value of a square matrix or batches of square matrices "
);
AddComment
(
R"DOC(
Determinant Operator.)DOC"
);
}
};
class
DeterminantGradOp
:
public
framework
::
OperatorWithKernel
{
public:
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
OP_INOUT_CHECK
(
ctx
->
HasInput
(
"Input"
),
"Input"
,
"Input"
,
"DeterminantGradOp"
);
OP_INOUT_CHECK
(
ctx
->
HasInput
(
"Out"
),
"Input"
,
"Out"
,
"DeterminantGradOp"
);
OP_INOUT_CHECK
(
ctx
->
HasOutput
(
framework
::
GradVarName
(
"Input"
)),
"Output"
,
framework
::
GradVarName
(
"Input"
),
"DeterminantGradOp"
);
ctx
->
SetOutputDim
(
framework
::
GradVarName
(
"Input"
),
ctx
->
GetInputDim
(
"Input"
));
}
protected:
framework
::
OpKernelType
GetExpectedKernelType
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
return
framework
::
OpKernelType
(
OperatorWithKernel
::
IndicateVarDataType
(
ctx
,
framework
::
GradVarName
(
"Out"
)),
ctx
.
GetPlace
());
}
};
template
<
typename
T
>
class
DeterminantGradOpMaker
:
public
framework
::
SingleGradOpMaker
<
T
>
{
public:
using
framework
::
SingleGradOpMaker
<
T
>::
SingleGradOpMaker
;
protected:
void
Apply
(
GradOpPtr
<
T
>
grad_op
)
const
override
{
grad_op
->
SetType
(
"determinant_grad"
);
grad_op
->
SetInput
(
"Input"
,
this
->
Input
(
"Input"
));
grad_op
->
SetInput
(
"Out"
,
this
->
Output
(
"Out"
));
grad_op
->
SetInput
(
framework
::
GradVarName
(
"Out"
),
this
->
OutputGrad
(
"Out"
));
grad_op
->
SetOutput
(
framework
::
GradVarName
(
"Input"
),
this
->
InputGrad
(
"Input"
));
grad_op
->
SetAttrMap
(
this
->
Attrs
());
}
};
DECLARE_NO_NEED_BUFFER_VARS_INFERER
(
DeterminantGradNoNeedBufferVarsInferer
,
"Input"
);
class
SlogDeterminantOp
:
public
framework
::
OperatorWithKernel
{
public:
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
OP_INOUT_CHECK
(
ctx
->
HasInput
(
"Input"
),
"Input"
,
"Input"
,
"determinant"
);
OP_INOUT_CHECK
(
ctx
->
HasOutput
(
"Out"
),
"Output"
,
"Out"
,
"determinant"
);
}
};
class
SlogDeterminantOpMaker
:
public
framework
::
OpProtoAndCheckerMaker
{
public:
void
Make
()
override
{
AddInput
(
"Input"
,
"(Tensor) The input tensor of SlogDeterminant."
);
AddOutput
(
"Out"
,
"(Tensor) The output tensor containing the sign of the"
"determinant and the natural logarithm"
"of the absolute value of determinant,"
);
AddComment
(
R"DOC(
SlogDeterminant Operator.)DOC"
);
}
};
class
SlogDeterminantGradOp
:
public
framework
::
OperatorWithKernel
{
public:
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
OP_INOUT_CHECK
(
ctx
->
HasInput
(
"Input"
),
"Input"
,
"Input"
,
"SlogDeterminantGradOp"
);
OP_INOUT_CHECK
(
ctx
->
HasInput
(
"Out"
),
"Input"
,
"Out"
,
"SlogDeterminantGradOp"
);
OP_INOUT_CHECK
(
ctx
->
HasOutput
(
framework
::
GradVarName
(
"Input"
)),
"Output"
,
framework
::
GradVarName
(
"Input"
),
"SlogDeterminantGradOp"
);
ctx
->
SetOutputDim
(
framework
::
GradVarName
(
"Input"
),
ctx
->
GetInputDim
(
"Input"
));
}
protected:
framework
::
OpKernelType
GetExpectedKernelType
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
return
framework
::
OpKernelType
(
OperatorWithKernel
::
IndicateVarDataType
(
ctx
,
framework
::
GradVarName
(
"Out"
)),
ctx
.
GetPlace
());
}
};
template
<
typename
T
>
class
SlogDeterminantGradOpMaker
:
public
framework
::
SingleGradOpMaker
<
T
>
{
public:
using
framework
::
SingleGradOpMaker
<
T
>::
SingleGradOpMaker
;
protected:
void
Apply
(
GradOpPtr
<
T
>
grad_op
)
const
override
{
grad_op
->
SetType
(
"slogdeterminant_grad"
);
grad_op
->
SetInput
(
"Input"
,
this
->
Input
(
"Input"
));
grad_op
->
SetInput
(
"Out"
,
this
->
Output
(
"Out"
));
grad_op
->
SetInput
(
framework
::
GradVarName
(
"Out"
),
this
->
OutputGrad
(
"Out"
));
grad_op
->
SetOutput
(
framework
::
GradVarName
(
"Input"
),
this
->
InputGrad
(
"Input"
));
grad_op
->
SetAttrMap
(
this
->
Attrs
());
}
};
DECLARE_NO_NEED_BUFFER_VARS_INFERER
(
SlogDeterminantGradNoNeedBufferVarsInferer
,
"Input"
);
}
// namespace operators
}
// namespace paddle
namespace
ops
=
paddle
::
operators
;
namespace
plat
=
paddle
::
platform
;
REGISTER_OPERATOR
(
determinant
,
ops
::
DeterminantOp
,
ops
::
DeterminantOpMaker
,
ops
::
DeterminantGradOpMaker
<
paddle
::
framework
::
OpDesc
>
,
ops
::
DeterminantGradOpMaker
<
paddle
::
imperative
::
OpBase
>
);
REGISTER_OPERATOR
(
determinant_grad
,
ops
::
DeterminantGradOp
)
REGISTER_OP_CPU_KERNEL
(
determinant
,
ops
::
DeterminantKernel
<
plat
::
CPUDeviceContext
,
float
>
,
ops
::
DeterminantKernel
<
plat
::
CPUDeviceContext
,
double
>
);
REGISTER_OP_CPU_KERNEL
(
determinant_grad
,
ops
::
DeterminantGradKernel
<
plat
::
CPUDeviceContext
,
float
>
,
ops
::
DeterminantGradKernel
<
plat
::
CPUDeviceContext
,
double
>
);
REGISTER_OPERATOR
(
slogdeterminant
,
ops
::
SlogDeterminantOp
,
ops
::
SlogDeterminantOpMaker
,
ops
::
SlogDeterminantGradOpMaker
<
paddle
::
framework
::
OpDesc
>
,
ops
::
SlogDeterminantGradOpMaker
<
paddle
::
imperative
::
OpBase
>
);
REGISTER_OPERATOR
(
slogdeterminant_grad
,
ops
::
DeterminantGradOp
)
// reuse det grad op
REGISTER_OP_CPU_KERNEL
(
slogdeterminant
,
ops
::
SlogDeterminantKernel
<
plat
::
CPUDeviceContext
,
float
>
,
ops
::
SlogDeterminantKernel
<
plat
::
CPUDeviceContext
,
double
>
);
REGISTER_OP_CPU_KERNEL
(
slogdeterminant_grad
,
ops
::
DeterminantGradKernel
<
plat
::
CPUDeviceContext
,
float
>
,
ops
::
DeterminantGradKernel
<
plat
::
CPUDeviceContext
,
double
>
);
paddle/fluid/operators/determinant_op.cu
0 → 100644
浏览文件 @
9ce45ddd
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/determinant_op.h"
#include "paddle/fluid/platform/cuda_primitives.h"
namespace
paddle
{
namespace
operators
{
using
platform
::
PADDLE_CUDA_NUM_THREADS
;
using
Tensor
=
framework
::
Tensor
;
template
<
typename
T
>
__global__
void
DeterminantGrad
(
const
size_t
numel
,
T
*
out
)
{
int
tid
=
threadIdx
.
x
+
blockIdx
.
x
*
blockDim
.
x
;
if
(
tid
<
numel
)
{
out
[
tid
]
=
static_cast
<
T
>
(
1
);
}
}
template
<
typename
T
>
class
DeterminantGradCUDAKernel
:
public
framework
::
OpKernel
<
T
>
{
public:
void
Compute
(
const
framework
::
ExecutionContext
&
context
)
const
override
{
const
auto
*
dout
=
context
.
Input
<
Tensor
>
(
framework
::
GradVarName
(
"Out"
));
const
T
*
dout_data
=
dout
->
data
<
T
>
();
auto
dout_dim
=
vectorize
(
dout
->
dims
());
auto
*
dx
=
context
.
Output
<
Tensor
>
(
framework
::
GradVarName
(
"Input"
));
T
*
dx_data
=
dx
->
mutable_data
<
T
>
(
context
.
GetPlace
());
int64_t
numel
=
dx
->
numel
();
for
(
int64_t
idx
=
0
;
idx
<
numel
;
idx
++
)
{
dx_data
[
idx
]
=
static_cast
<
T
>
(
1
);
}
}
};
}
// namespace operators
}
// namespace paddle
namespace
ops
=
paddle
::
operators
;
namespace
plat
=
paddle
::
platform
;
REGISTER_OP_CUDA_KERNEL
(
determinant
,
ops
::
DeterminantKernel
<
plat
::
CUDADeviceContext
,
float
>
,
ops
::
DeterminantKernel
<
plat
::
CUDADeviceContext
,
double
>
);
REGISTER_OP_CUDA_KERNEL
(
determinant_grad
,
ops
::
DeterminantGradKernel
<
plat
::
CUDADeviceContext
,
float
>
,
ops
::
DeterminantGradKernel
<
plat
::
CUDADeviceContext
,
double
>
);
REGISTER_OP_CUDA_KERNEL
(
slogdeterminant
,
ops
::
SlogDeterminantKernel
<
plat
::
CUDADeviceContext
,
float
>
,
ops
::
SlogDeterminantKernel
<
plat
::
CUDADeviceContext
,
double
>
);
REGISTER_OP_CUDA_KERNEL
(
slogdeterminant_grad
,
ops
::
SlogDeterminantGradKernel
<
plat
::
CUDADeviceContext
,
float
>
,
ops
::
SlogDeterminantGradKernel
<
plat
::
CUDADeviceContext
,
double
>
);
paddle/fluid/operators/determinant_op.h
0 → 100644
浏览文件 @
9ce45ddd
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <Eigen/Dense>
#include <Eigen/LU>
#include <algorithm>
#include <cmath>
#include <vector>
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/platform/enforce.h"
namespace
paddle
{
namespace
operators
{
using
Tensor
=
framework
::
Tensor
;
template
<
typename
T
>
T
sign
(
T
val
)
{
return
static_cast
<
T
>
(
T
(
0
)
<
val
)
-
(
val
<
T
(
0
));
}
template
<
typename
T
>
class
EigenMatrix
{};
template
<
>
class
EigenMatrix
<
float
>
{
public:
using
MatrixType
=
Eigen
::
MatrixXf
;
};
template
<
>
class
EigenMatrix
<
double
>
{
public:
using
MatrixType
=
Eigen
::
MatrixXd
;
};
inline
int64_t
GetBatchCount
(
const
framework
::
DDim
dims
)
{
int64_t
batch_count
=
1
;
auto
dim_size
=
dims
.
size
();
PADDLE_ENFORCE_GT
(
dim_size
,
2
,
platform
::
errors
::
InvalidArgument
(
"To get the number of batch square matrices, "
"the size of dimension should greater than 2."
,
dim_size
));
// Cumulative multiplying each dimension until the last 2 to get the batch
// count,
// for example a tensor with shape [3,3,3,3], the batch count of matrices is
// 9.
for
(
int64_t
i
=
0
;
i
<
dims
.
size
()
-
2
;
i
++
)
{
batch_count
*=
dims
[
i
];
}
return
batch_count
;
}
template
<
typename
T
>
struct
DeterminantFunctor
{
void
operator
()(
const
Tensor
&
input
,
const
framework
::
ExecutionContext
ctx
,
int64_t
rank
,
int64_t
batch_count
,
Tensor
*
output
)
{
std
::
vector
<
T
>
input_vec
;
std
::
vector
<
T
>
output_vec
;
framework
::
TensorToVector
(
input
,
ctx
.
device_context
(),
&
input_vec
);
for
(
int64_t
i
=
0
;
i
<
batch_count
;
++
i
)
{
// maybe can be parallel
auto
begin_iter
=
input_vec
.
begin
()
+
i
*
rank
*
rank
;
auto
end_iter
=
input_vec
.
begin
()
+
(
i
+
1
)
*
rank
*
rank
;
std
::
vector
<
T
>
sub_vec
(
begin_iter
,
end_iter
);
// get every square matrix data
Eigen
::
MatrixXf
matrix
(
rank
,
rank
);
for
(
int64_t
i
=
0
;
i
<
rank
;
++
i
)
{
for
(
int64_t
j
=
0
;
j
<
rank
;
++
j
)
{
matrix
(
i
,
j
)
=
sub_vec
[
rank
*
i
+
j
];
}
}
output_vec
.
push_back
(
matrix
.
determinant
());
}
framework
::
TensorFromVector
(
output_vec
,
output
);
}
};
template
<
typename
DeviceContext
,
typename
T
>
class
DeterminantKernel
:
public
framework
::
OpKernel
<
T
>
{
public:
void
Compute
(
const
framework
::
ExecutionContext
&
context
)
const
override
{
auto
*
input
=
context
.
Input
<
framework
::
Tensor
>
(
"Input"
);
auto
input_dim
=
vectorize
(
input
->
dims
());
auto
input_dim_size
=
input_dim
.
size
();
auto
*
output
=
context
.
Output
<
framework
::
Tensor
>
(
"Out"
);
auto
batch_count
=
GetBatchCount
(
input
->
dims
());
VLOG
(
2
)
<<
"input dim:"
<<
input
->
dims
();
PADDLE_ENFORCE_GE
(
input_dim_size
,
2
,
platform
::
errors
::
InvalidArgument
(
"the input matrix dimension size should greater than 2."
));
PADDLE_ENFORCE_EQ
(
input_dim
[
input_dim_size
-
1
],
input_dim
[
input_dim_size
-
2
],
platform
::
errors
::
InvalidArgument
(
"the input matrix should be square matrix."
));
auto
rank
=
input_dim
[
input_dim_size
-
1
];
// square matrix length
DeterminantFunctor
<
T
>
()(
*
input
,
context
,
rank
,
batch_count
,
output
);
if
(
input_dim_size
>
2
)
{
auto
output_dims
=
framework
::
slice_ddim
(
input
->
dims
(),
0
,
input_dim_size
-
2
);
output
->
Resize
(
output_dims
);
}
VLOG
(
2
)
<<
"output dim:"
<<
output
->
dims
();
}
};
template
<
typename
DeviceContext
,
typename
T
>
class
DeterminantGradKernel
:
public
framework
::
OpKernel
<
T
>
{
public:
void
Compute
(
const
framework
::
ExecutionContext
&
context
)
const
override
{
PADDLE_THROW
(
platform
::
errors
::
Unimplemented
(
"Not support DeterminantGrad at this time."
));
}
};
template
<
typename
T
>
struct
SlogDeterminantFunctor
{
void
operator
()(
const
Tensor
&
input
,
const
framework
::
ExecutionContext
ctx
,
int
rank
,
int
batch_count
,
Tensor
*
output
)
{
std
::
vector
<
T
>
input_vec
;
std
::
vector
<
T
>
sign_vec
;
std
::
vector
<
T
>
log_vec
;
std
::
vector
<
T
>
output_vec
;
framework
::
TensorToVector
(
input
,
ctx
.
device_context
(),
&
input_vec
);
for
(
int
i
=
0
;
i
<
batch_count
;
++
i
)
{
// maybe can be parallel
auto
begin_iter
=
input_vec
.
begin
()
+
i
*
rank
*
rank
;
auto
end_iter
=
input_vec
.
begin
()
+
(
i
+
1
)
*
rank
*
rank
;
std
::
vector
<
T
>
sub_vec
(
begin_iter
,
end_iter
);
// get every square matrix data
typename
EigenMatrix
<
T
>::
MatrixType
matrix
(
rank
,
rank
);
for
(
int
i
=
0
;
i
<
rank
;
++
i
)
{
for
(
int
j
=
0
;
j
<
rank
;
++
j
)
{
matrix
(
i
,
j
)
=
sub_vec
[
rank
*
i
+
j
];
}
}
VLOG
(
2
)
<<
"det value: "
<<
matrix
.
determinant
();
VLOG
(
2
)
<<
"matrix val: "
<<
matrix
;
auto
det_val
=
matrix
.
determinant
();
sign_vec
.
push_back
(
sign
(
det_val
));
det_val
>=
0
?
log_vec
.
push_back
(
std
::
log
(
det_val
))
:
log_vec
.
push_back
(
std
::
log
(
std
::
abs
(
det_val
)));
// for computing log value of a negative value.
}
// merge sign_vec and log_vec as final output_vec
output_vec
.
insert
(
output_vec
.
end
(),
sign_vec
.
begin
(),
sign_vec
.
end
());
output_vec
.
insert
(
output_vec
.
end
(),
log_vec
.
begin
(),
log_vec
.
end
());
framework
::
TensorFromVector
(
output_vec
,
output
);
}
};
template
<
typename
DeviceContext
,
typename
T
>
class
SlogDeterminantKernel
:
public
framework
::
OpKernel
<
T
>
{
public:
void
Compute
(
const
framework
::
ExecutionContext
&
context
)
const
override
{
auto
*
input
=
context
.
Input
<
framework
::
Tensor
>
(
"Input"
);
auto
input_dim
=
vectorize
(
input
->
dims
());
auto
input_dim_size
=
input_dim
.
size
();
auto
*
output
=
context
.
Output
<
framework
::
Tensor
>
(
"Out"
);
auto
batch_count
=
GetBatchCount
(
input
->
dims
());
VLOG
(
2
)
<<
"input dim:"
<<
input
->
dims
();
PADDLE_ENFORCE_GE
(
input_dim_size
,
2
,
platform
::
errors
::
InvalidArgument
(
"the input matrix dimension size should greater than 2."
));
PADDLE_ENFORCE_EQ
(
input_dim
[
input_dim_size
-
1
],
input_dim
[
input_dim_size
-
2
],
platform
::
errors
::
InvalidArgument
(
"the input matrix should be square matrix."
));
auto
rank
=
input_dim
[
input_dim_size
-
1
];
// square matrix length
SlogDeterminantFunctor
<
T
>
()(
*
input
,
context
,
rank
,
batch_count
,
output
);
std
::
vector
<
int
>
output_dim_vec
(
input_dim
.
begin
(),
input_dim
.
end
()
-
2
);
output_dim_vec
.
insert
(
output_dim_vec
.
begin
(),
2
);
// make the output dims as same as numpy
auto
output_dims
=
framework
::
make_ddim
(
output_dim_vec
);
output
->
Resize
(
output_dims
);
VLOG
(
2
)
<<
"output dim:"
<<
output
->
dims
();
}
};
template
<
typename
DeviceContext
,
typename
T
>
class
SlogDeterminantGradKernel
:
public
framework
::
OpKernel
<
T
>
{
public:
void
Compute
(
const
framework
::
ExecutionContext
&
context
)
const
override
{
PADDLE_THROW
(
platform
::
errors
::
Unimplemented
(
"Not support SlogDeterminantGrad at this time."
));
}
};
}
// namespace operators
}
// namespace paddle
python/paddle/__init__.py
浏览文件 @
9ce45ddd
...
...
@@ -101,6 +101,8 @@ from .tensor.linalg import cholesky # noqa: F401
from
.tensor.linalg
import
bmm
# noqa: F401
from
.tensor.linalg
import
histogram
# noqa: F401
from
.tensor.linalg
import
mv
# noqa: F401
from
.tensor.linalg
import
det
# noqa: F401
from
.tensor.linalg
import
slogdet
# noqa: F401
from
.tensor.linalg
import
multi_dot
# noqa: F401
from
.tensor.linalg
import
matrix_power
# noqa: F401
from
.tensor.linalg
import
svd
# noqa: F401
...
...
python/paddle/fluid/tests/unittests/test_determinant_op.py
0 → 100644
浏览文件 @
9ce45ddd
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from
__future__
import
print_function
import
unittest
import
numpy
as
np
from
op_test
import
OpTest
,
skip_check_grad_ci
import
paddle
import
paddle.nn.functional
as
F
import
paddle.fluid
as
fluid
import
paddle.fluid.core
as
core
import
paddle.tensor
as
tensor
paddle
.
enable_static
()
@
skip_check_grad_ci
(
reason
=
"determinant grad is in progress."
)
class
TestDeterminantOp
(
OpTest
):
def
setUp
(
self
):
self
.
init_data
()
self
.
op_type
=
"determinant"
self
.
outputs
=
{
'Out'
:
self
.
target
}
def
test_check_output
(
self
):
self
.
check_output
()
def
test_check_grad
(
self
):
pass
def
init_data
(
self
):
np
.
random
.
seed
(
0
)
self
.
case
=
np
.
random
.
rand
(
3
,
3
,
3
,
3
,
3
).
astype
(
'float64'
)
self
.
inputs
=
{
'Input'
:
self
.
case
}
self
.
target
=
np
.
linalg
.
det
(
self
.
case
)
class
TestDeterminantOpCase1
(
TestDeterminantOp
):
def
init_data
(
self
):
np
.
random
.
seed
(
0
)
self
.
case
=
np
.
random
.
rand
(
3
,
3
,
3
,
3
).
astype
(
np
.
float32
)
self
.
inputs
=
{
'Input'
:
self
.
case
}
self
.
target
=
np
.
linalg
.
det
(
self
.
case
)
def
test_check_grad
(
self
):
pass
class
TestDeterminantOpCase2
(
TestDeterminantOp
):
def
init_data
(
self
):
np
.
random
.
seed
(
0
)
self
.
case
=
np
.
random
.
rand
(
4
,
2
,
4
,
4
).
astype
(
'float64'
)
self
.
inputs
=
{
'Input'
:
self
.
case
}
self
.
target
=
np
.
linalg
.
det
(
self
.
case
)
def
test_check_grad
(
self
):
pass
class
TestDeterminantAPI
(
unittest
.
TestCase
):
def
setUp
(
self
):
self
.
shape
=
[
3
,
3
,
3
,
3
]
np
.
random
.
seed
(
0
)
self
.
x
=
np
.
random
.
rand
(
3
,
3
,
3
,
3
).
astype
(
np
.
float32
)
self
.
place
=
paddle
.
CPUPlace
()
def
test_api_static
(
self
):
paddle
.
enable_static
()
with
paddle
.
static
.
program_guard
(
paddle
.
static
.
Program
()):
x
=
paddle
.
fluid
.
data
(
'X'
,
self
.
shape
)
out
=
paddle
.
linalg
.
det
(
x
)
exe
=
paddle
.
static
.
Executor
(
self
.
place
)
res
=
exe
.
run
(
feed
=
{
'X'
:
self
.
x
},
fetch_list
=
[
out
])
out_ref
=
np
.
linalg
.
det
(
self
.
x
)
for
out
in
res
:
self
.
assertEqual
(
np
.
allclose
(
out
,
out_ref
,
rtol
=
1e-03
),
True
)
def
test_api_dygraph
(
self
):
paddle
.
disable_static
(
self
.
place
)
x_tensor
=
paddle
.
to_tensor
(
self
.
x
)
out
=
paddle
.
linalg
.
det
(
x_tensor
)
out_ref
=
np
.
linalg
.
det
(
self
.
x
)
self
.
assertEqual
(
np
.
allclose
(
out
.
numpy
(),
out_ref
,
rtol
=
1e-03
),
True
)
paddle
.
enable_static
()
@
skip_check_grad_ci
(
reason
=
"slogdeterminant grad is in progress."
)
class
TestSlogDeterminantOp
(
OpTest
):
def
setUp
(
self
):
self
.
op_type
=
"slogdeterminant"
self
.
init_data
()
self
.
outputs
=
{
'Out'
:
self
.
target
}
def
test_check_output
(
self
):
self
.
check_output
()
def
test_check_grad
(
self
):
pass
def
init_data
(
self
):
np
.
random
.
seed
(
0
)
self
.
case
=
np
.
random
.
rand
(
3
,
3
,
3
,
3
).
astype
(
'float64'
)
self
.
inputs
=
{
'Input'
:
self
.
case
}
self
.
target
=
np
.
array
(
np
.
linalg
.
slogdet
(
self
.
case
))
class
TestSlogDeterminantOpCase1
(
TestSlogDeterminantOp
):
def
init_data
(
self
):
np
.
random
.
seed
(
0
)
self
.
case
=
np
.
random
.
rand
(
2
,
2
,
5
,
5
).
astype
(
np
.
float32
)
self
.
inputs
=
{
'Input'
:
self
.
case
}
self
.
target
=
np
.
array
(
np
.
linalg
.
slogdet
(
self
.
case
))
class
TestSlogDeterminantAPI
(
unittest
.
TestCase
):
def
setUp
(
self
):
self
.
shape
=
[
3
,
3
,
3
,
3
]
np
.
random
.
seed
(
0
)
self
.
x
=
np
.
random
.
rand
(
3
,
3
,
3
,
3
).
astype
(
np
.
float32
)
self
.
place
=
paddle
.
CPUPlace
()
def
test_api_static
(
self
):
paddle
.
enable_static
()
with
paddle
.
static
.
program_guard
(
paddle
.
static
.
Program
()):
x
=
paddle
.
fluid
.
data
(
'X'
,
self
.
shape
)
out
=
paddle
.
linalg
.
slogdet
(
x
)
exe
=
paddle
.
static
.
Executor
(
self
.
place
)
res
=
exe
.
run
(
feed
=
{
'X'
:
self
.
x
},
fetch_list
=
[
out
])
out_ref
=
np
.
array
(
np
.
linalg
.
slogdet
(
self
.
x
))
for
out
in
res
:
self
.
assertEqual
(
np
.
allclose
(
out
,
out_ref
,
rtol
=
1e-03
),
True
)
def
test_api_dygraph
(
self
):
paddle
.
disable_static
(
self
.
place
)
x_tensor
=
paddle
.
to_tensor
(
self
.
x
)
out
=
paddle
.
linalg
.
slogdet
(
x_tensor
)
out_ref
=
np
.
array
(
np
.
linalg
.
slogdet
(
self
.
x
))
self
.
assertEqual
(
np
.
allclose
(
out
.
numpy
(),
out_ref
,
rtol
=
1e-03
),
True
)
paddle
.
enable_static
()
if
__name__
==
'__main__'
:
unittest
.
main
()
python/paddle/linalg.py
浏览文件 @
9ce45ddd
...
...
@@ -22,6 +22,8 @@ from .tensor.linalg import multi_dot # noqa: F401
from
.tensor.linalg
import
matrix_rank
from
.tensor.linalg
import
svd
from
.tensor.linalg
import
eigh
# noqa: F401
from
.tensor.linalg
import
det
from
.tensor.linalg
import
slogdet
from
.tensor.linalg
import
pinv
__all__
=
[
...
...
@@ -34,6 +36,8 @@ __all__ = [
'matrix_rank'
,
'svd'
,
'matrix_power'
,
'det'
,
'slogdet'
,
'eigh'
,
'pinv'
]
python/paddle/tensor/linalg.py
浏览文件 @
9ce45ddd
...
...
@@ -14,7 +14,7 @@
import
numpy
as
np
from
..fluid.layer_helper
import
LayerHelper
from
..fluid.data_feeder
import
check_variable_and_dtype
,
check_type
from
..fluid.data_feeder
import
check_variable_and_dtype
,
check_type
,
check_dtype
from
..fluid.framework
import
in_dygraph_mode
,
_varbase_creator
,
Variable
from
..fluid.layers
import
transpose
,
cast
# noqa: F401
...
...
@@ -1351,6 +1351,109 @@ def mv(x, vec, name=None):
return
out
def
det
(
x
):
"""
Calculates determinant value of a square matrix or batches of square matrices.
Args:
x (Tensor): input (Tensor): the input matrix of size `(n, n)` or the batch of matrices of size
`(*, n, n)` where `*` is one or more batch dimensions.
Returns:
y (Tensor):the determinant value of a square matrix or batches of square matrices.
Example:
.. code-block:: python
import paddle
x = paddle.randn([3,3,3])
A = paddle.det(x)
print(A)
# [ 0.02547996, 2.52317095, -6.15900707])
"""
if
in_dygraph_mode
():
return
core
.
ops
.
determinant
(
x
)
check_dtype
(
x
.
dtype
,
'Input'
,
[
'float32'
,
'float64'
],
'det'
)
input_shape
=
list
(
x
.
shape
)
assert
len
(
input_shape
)
>=
2
,
\
"The x must be at least 2-dimensional, "
\
"but received Input x's dimensional: %s.
\n
"
%
\
len
(
input_shape
)
assert
(
input_shape
[
-
1
]
==
input_shape
[
-
2
]),
\
"Expect squared input,"
\
"but received %s by %s matrix.
\n
"
\
%
(
input_shape
[
-
2
],
input_shape
[
-
1
])
\
helper
=
LayerHelper
(
'determinant'
,
**
locals
())
out
=
helper
.
create_variable_for_type_inference
(
dtype
=
x
.
dtype
)
helper
.
append_op
(
type
=
'determinant'
,
inputs
=
{
'Input'
:
[
x
]},
outputs
=
{
'Out'
:
[
out
]})
return
out
def
slogdet
(
x
):
"""
Calculates the sign and natural logarithm of the absolute value of a square matrix's or batches square matrices' determinant.
The determinant can be computed with ``sign * exp(logabsdet)
Supports input of float, double
Note that for matrices that have zero determinant, this returns ``(0, -inf)``
Args:
x (Tensor): the batch of matrices of size :math:`(*, n, n)`
where math:`*` is one or more batch dimensions.
Returns:
y (Tensor): A tensor containing the sign of the determinant and the natural logarithm
of the absolute value of determinant, respectively.
Example:
.. code-block:: python
import paddle
x = paddle.randn([3,3,3])
A = paddle.slogdet(x)
print(A)
# [[ 1. , 1. , -1. ],
# [-0.98610914, -0.43010661, -0.10872950]])
"""
if
in_dygraph_mode
():
return
core
.
ops
.
slogdeterminant
(
x
)
check_dtype
(
x
.
dtype
,
'Input'
,
[
'float32'
,
'float64'
],
'slogdet'
)
input_shape
=
list
(
x
.
shape
)
assert
len
(
input_shape
)
>=
2
,
\
"The x must be at least 2-dimensional, "
\
"but received Input x's dimensional: %s.
\n
"
%
\
len
(
input_shape
)
assert
(
input_shape
[
-
1
]
==
input_shape
[
-
2
]),
\
"Expect squared input,"
\
"but received %s by %s matrix.
\n
"
\
%
(
input_shape
[
-
2
],
input_shape
[
-
1
])
\
helper
=
LayerHelper
(
'slogdeterminant'
,
**
locals
())
out
=
helper
.
create_variable_for_type_inference
(
dtype
=
x
.
dtype
)
helper
.
append_op
(
type
=
'slogdeterminant'
,
inputs
=
{
'Input'
:
[
x
]},
outputs
=
{
'Out'
:
[
out
]})
return
out
def
svd
(
x
,
full_matrices
=
False
,
name
=
None
):
r
"""
Computes the singular value decomposition of one matrix or a batch of regular matrices.
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录