Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
机器未来
Paddle
提交
63d322f0
P
Paddle
项目概览
机器未来
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1
Issue
1
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
63d322f0
编写于
2月 21, 2019
作者:
D
dengkaipeng
提交者:
ceci3
3月 06, 2019
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
fix attr dim calc. test=develop
上级
ca1502c7
变更
4
隐藏空白更改
内联
并排
Showing
4 changed file
with
238 addition
and
43 deletion
+238
-43
paddle/fluid/operators/spectral_norm_op.cc
paddle/fluid/operators/spectral_norm_op.cc
+21
-6
paddle/fluid/operators/spectral_norm_op.h
paddle/fluid/operators/spectral_norm_op.h
+126
-25
python/paddle/fluid/layers/nn.py
python/paddle/fluid/layers/nn.py
+75
-0
python/paddle/fluid/tests/unittests/test_spectral_norm_op.py
python/paddle/fluid/tests/unittests/test_spectral_norm_op.py
+16
-12
未找到文件。
paddle/fluid/operators/spectral_norm_op.cc
浏览文件 @
63d322f0
...
@@ -33,19 +33,34 @@ class SpectralNormOp : public framework::OperatorWithKernel {
...
@@ -33,19 +33,34 @@ class SpectralNormOp : public framework::OperatorWithKernel {
"Output(Out) of SpectralNormOp should not be null."
);
"Output(Out) of SpectralNormOp should not be null."
);
auto
dim_weight
=
ctx
->
GetInputDim
(
"Weight"
);
auto
dim_weight
=
ctx
->
GetInputDim
(
"Weight"
);
auto
weight_dimsize
=
dim_weight
.
size
();
auto
rank_weight
=
dim_weight
.
size
();
PADDLE_ENFORCE
(
weight_dimsize
>=
2
&&
weight_dimsize
<=
5
,
PADDLE_ENFORCE
(
rank_weight
>=
2
&&
rank_weight
<=
5
,
"The
size of dims
of Input(Weights) can only be 2, 3,"
"The
rank
of Input(Weights) can only be 2, 3,"
"4, 5 for fc, conv1d, conv2d, conv3d layers."
);
"4, 5 for fc, conv1d, conv2d, conv3d layers."
);
int
dim
=
ctx
->
Attrs
().
Get
<
int
>
(
"dim"
);
int
dim
=
ctx
->
Attrs
().
Get
<
int
>
(
"dim"
);
int
power_iters
=
ctx
->
Attrs
().
Get
<
int
>
(
"power_iters"
);
int
power_iters
=
ctx
->
Attrs
().
Get
<
int
>
(
"power_iters"
);
PADDLE_ENFORCE
(
dim
>=
0
&&
dim
<
weight_dimsize
-
1
,
PADDLE_ENFORCE
(
dim
==
0
||
dim
==
1
,
"Attr(dim) can only be 0 or 1"
);
"Attr(dim) should be larger equal 0 and less then the"
"size of dims of Input(Weights) - 1,"
);
PADDLE_ENFORCE
(
power_iters
>=
0
,
PADDLE_ENFORCE
(
power_iters
>=
0
,
"Attr(power_iters) should be larger equal then 0"
);
"Attr(power_iters) should be larger equal then 0"
);
int
h
=
dim_weight
[
dim
];
int
w
=
1
;
for
(
int
i
=
0
;
i
<
rank_weight
;
i
++
)
{
if
(
i
!=
dim
)
{
w
*=
dim_weight
[
i
];
}
}
auto
dim_u
=
ctx
->
GetInputDim
(
"U"
);
auto
dim_v
=
ctx
->
GetInputDim
(
"V"
);
PADDLE_ENFORCE_EQ
(
dim_u
[
0
],
h
,
"Input(U) dims[0] should be equal to "
"Input(Weight) dims[Attr(dim)]"
);
PADDLE_ENFORCE_EQ
(
dim_v
[
0
],
w
,
"Input(V) dims[0] should be equal to "
"the product of Input(Weight) dims except dims[Attr(dim)]"
);
ctx
->
SetOutputDim
(
"Out"
,
dim_weight
);
ctx
->
SetOutputDim
(
"Out"
,
dim_weight
);
ctx
->
ShareLoD
(
"Weight"
,
/*->*/
"Out"
);
ctx
->
ShareLoD
(
"Weight"
,
/*->*/
"Out"
);
}
}
...
...
paddle/fluid/operators/spectral_norm_op.h
浏览文件 @
63d322f0
...
@@ -10,6 +10,7 @@
...
@@ -10,6 +10,7 @@
limitations under the License. */
limitations under the License. */
#pragma once
#pragma once
#include <vector>
#include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/math/blas.h"
#include "paddle/fluid/operators/math/blas.h"
...
@@ -27,17 +28,33 @@ using Array1 = Eigen::DSizes<int64_t, 1>;
...
@@ -27,17 +28,33 @@ using Array1 = Eigen::DSizes<int64_t, 1>;
using
Array2
=
Eigen
::
DSizes
<
int64_t
,
2
>
;
using
Array2
=
Eigen
::
DSizes
<
int64_t
,
2
>
;
using
IndexPair
=
Eigen
::
IndexPair
<
int
>
;
using
IndexPair
=
Eigen
::
IndexPair
<
int
>
;
static
inline
void
CalcMatrixShape
(
const
Tensor
&
weight
,
const
int
dim
,
int
*
h
,
template
<
typename
DeviceContext
,
typename
T
>
int
*
w
)
{
static
inline
void
TransCompute
(
const
int
rank
,
const
Tensor
&
in
,
Tensor
*
out
,
auto
weight_dims
=
weight
.
dims
();
const
std
::
vector
<
int
>&
perm
,
*
h
=
1
;
const
DeviceContext
&
dev_ctx
)
{
*
w
=
1
;
if
(
rank
<=
1
||
rank
>
5
)
{
for
(
int
i
=
0
;
i
<
weight_dims
.
size
();
i
++
)
{
PADDLE_THROW
(
"Invalid weight rank."
);
if
(
i
<=
dim
)
{
}
*
h
*=
weight_dims
[
i
];
}
else
{
switch
(
rank
)
{
*
w
*=
weight_dims
[
i
];
case
2
:
}
math
::
Transpose
<
DeviceContext
,
T
,
2
>
trans2
;
trans2
(
dev_ctx
,
in
,
out
,
perm
);
break
;
case
3
:
math
::
Transpose
<
DeviceContext
,
T
,
3
>
trans3
;
trans3
(
dev_ctx
,
in
,
out
,
perm
);
break
;
case
4
:
math
::
Transpose
<
DeviceContext
,
T
,
4
>
trans4
;
trans4
(
dev_ctx
,
in
,
out
,
perm
);
break
;
case
5
:
math
::
Transpose
<
DeviceContext
,
T
,
5
>
trans5
;
trans5
(
dev_ctx
,
in
,
out
,
perm
);
break
;
default:
break
;
}
}
}
}
...
@@ -83,6 +100,7 @@ template <typename DeviceContext, typename T>
...
@@ -83,6 +100,7 @@ template <typename DeviceContext, typename T>
class
SpectralNormKernel
:
public
framework
::
OpKernel
<
T
>
{
class
SpectralNormKernel
:
public
framework
::
OpKernel
<
T
>
{
public:
public:
void
Compute
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
void
Compute
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
auto
&
dev_ctx
=
ctx
.
template
device_context
<
DeviceContext
>();
auto
weight
=
ctx
.
Input
<
Tensor
>
(
"Weight"
);
auto
weight
=
ctx
.
Input
<
Tensor
>
(
"Weight"
);
auto
u
=
ctx
.
Input
<
Tensor
>
(
"U"
);
auto
u
=
ctx
.
Input
<
Tensor
>
(
"U"
);
auto
v
=
ctx
.
Input
<
Tensor
>
(
"V"
);
auto
v
=
ctx
.
Input
<
Tensor
>
(
"V"
);
...
@@ -92,10 +110,32 @@ class SpectralNormKernel : public framework::OpKernel<T> {
...
@@ -92,10 +110,32 @@ class SpectralNormKernel : public framework::OpKernel<T> {
int
power_iters
=
ctx
.
Attr
<
int
>
(
"power_iters"
);
int
power_iters
=
ctx
.
Attr
<
int
>
(
"power_iters"
);
float
eps
=
ctx
.
Attr
<
float
>
(
"eps"
);
float
eps
=
ctx
.
Attr
<
float
>
(
"eps"
);
const
int
h
=
u
->
dims
()[
0
];
const
int
w
=
v
->
dims
()[
0
];
Tensor
weight_mat
;
Tensor
weight_mat
;
int
h
,
w
;
auto
dims
=
weight
->
dims
();
CalcMatrixShape
(
*
weight
,
dim
,
&
h
,
&
w
);
const
int
rank
=
dims
.
size
();
TensorCopySync
(
*
weight
,
ctx
.
GetPlace
(),
&
weight_mat
);
std
::
vector
<
int
>
real_dims
;
if
(
dim
!=
0
)
{
std
::
vector
<
int
>
perm
;
perm
.
push_back
(
dim
);
real_dims
.
push_back
(
dims
[
dim
]);
for
(
int
i
=
0
;
i
<
rank
;
i
++
)
{
if
(
i
!=
dim
)
{
perm
.
push_back
(
i
);
real_dims
.
push_back
(
dims
[
i
]);
}
}
weight_mat
.
mutable_data
<
T
>
(
framework
::
make_ddim
(
real_dims
),
ctx
.
GetPlace
());
TransCompute
<
DeviceContext
,
T
>
(
rank
,
*
weight
,
&
weight_mat
,
perm
,
dev_ctx
);
}
else
{
for
(
int
i
=
0
;
i
<
rank
;
i
++
)
{
real_dims
.
push_back
(
i
);
}
TensorCopySync
(
*
weight
,
ctx
.
GetPlace
(),
&
weight_mat
);
}
weight_mat
=
weight_mat
.
Resize
({
h
,
w
});
weight_mat
=
weight_mat
.
Resize
({
h
,
w
});
Tensor
sigma
;
Tensor
sigma
;
...
@@ -106,7 +146,25 @@ class SpectralNormKernel : public framework::OpKernel<T> {
...
@@ -106,7 +146,25 @@ class SpectralNormKernel : public framework::OpKernel<T> {
CalcMatrixSigmaAndNormWeight
<
DeviceContext
,
T
>
(
CalcMatrixSigmaAndNormWeight
<
DeviceContext
,
T
>
(
&
sigma
,
&
(
uu
.
Resize
({
h
,
1
})),
&
(
vv
.
Resize
({
w
,
1
})),
&
weight_mat
,
&
sigma
,
&
(
uu
.
Resize
({
h
,
1
})),
&
(
vv
.
Resize
({
w
,
1
})),
&
weight_mat
,
power_iters
,
eps
,
ctx
);
power_iters
,
eps
,
ctx
);
TensorCopySync
(
weight_mat
.
Resize
(
out
->
dims
()),
ctx
.
GetPlace
(),
out
);
if
(
dim
!=
0
)
{
std
::
vector
<
int
>
perm
;
for
(
int
i
=
0
;
i
<
rank
;
i
++
)
{
if
(
i
<
dim
)
{
perm
.
push_back
(
i
+
1
);
}
else
if
(
i
==
dim
)
{
perm
.
push_back
(
0
);
}
else
{
perm
.
push_back
(
i
);
}
}
out
->
mutable_data
<
T
>
(
dims
,
ctx
.
GetPlace
());
TransCompute
<
DeviceContext
,
T
>
(
rank
,
weight_mat
.
Resize
(
framework
::
make_ddim
(
real_dims
)),
out
,
perm
,
dev_ctx
);
}
else
{
TensorCopySync
(
weight_mat
.
Resize
(
dims
),
ctx
.
GetPlace
(),
out
);
}
}
}
};
};
...
@@ -115,6 +173,7 @@ class SpectralNormGradKernel : public framework::OpKernel<T> {
...
@@ -115,6 +173,7 @@ class SpectralNormGradKernel : public framework::OpKernel<T> {
public:
public:
void
Compute
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
void
Compute
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
auto
&
place
=
*
ctx
.
template
device_context
<
DeviceContext
>().
eigen_device
();
auto
&
place
=
*
ctx
.
template
device_context
<
DeviceContext
>().
eigen_device
();
auto
&
dev_ctx
=
ctx
.
template
device_context
<
DeviceContext
>();
auto
blas
=
math
::
GetBlas
<
DeviceContext
,
T
>
(
ctx
);
auto
blas
=
math
::
GetBlas
<
DeviceContext
,
T
>
(
ctx
);
auto
weight
=
ctx
.
Input
<
Tensor
>
(
"Weight"
);
auto
weight
=
ctx
.
Input
<
Tensor
>
(
"Weight"
);
auto
u
=
ctx
.
Input
<
Tensor
>
(
"U"
);
auto
u
=
ctx
.
Input
<
Tensor
>
(
"U"
);
...
@@ -126,11 +185,37 @@ class SpectralNormGradKernel : public framework::OpKernel<T> {
...
@@ -126,11 +185,37 @@ class SpectralNormGradKernel : public framework::OpKernel<T> {
int
power_iters
=
ctx
.
Attr
<
int
>
(
"power_iters"
);
int
power_iters
=
ctx
.
Attr
<
int
>
(
"power_iters"
);
float
eps
=
ctx
.
Attr
<
float
>
(
"eps"
);
float
eps
=
ctx
.
Attr
<
float
>
(
"eps"
);
const
int
h
=
u
->
dims
()[
0
];
const
int
w
=
v
->
dims
()[
0
];
Tensor
weight_mat
,
out_grad_mat
;
Tensor
weight_mat
,
out_grad_mat
;
int
h
,
w
;
auto
dims
=
weight
->
dims
();
CalcMatrixShape
(
*
weight
,
dim
,
&
h
,
&
w
);
const
int
rank
=
dims
.
size
();
TensorCopySync
(
*
weight
,
ctx
.
GetPlace
(),
&
weight_mat
);
std
::
vector
<
int
>
real_dims
;
TensorCopySync
(
*
out_grad
,
ctx
.
GetPlace
(),
&
out_grad_mat
);
if
(
dim
!=
0
)
{
std
::
vector
<
int
>
perm
;
perm
.
push_back
(
dim
);
real_dims
.
push_back
(
dims
[
dim
]);
for
(
int
i
=
0
;
i
<
rank
;
i
++
)
{
if
(
i
!=
dim
)
{
perm
.
push_back
(
i
);
real_dims
.
push_back
(
dims
[
i
]);
}
}
weight_mat
.
mutable_data
<
T
>
(
framework
::
make_ddim
(
real_dims
),
ctx
.
GetPlace
());
out_grad_mat
.
mutable_data
<
T
>
(
framework
::
make_ddim
(
real_dims
),
ctx
.
GetPlace
());
TransCompute
<
DeviceContext
,
T
>
(
rank
,
*
weight
,
&
weight_mat
,
perm
,
dev_ctx
);
TransCompute
<
DeviceContext
,
T
>
(
rank
,
*
out_grad
,
&
out_grad_mat
,
perm
,
dev_ctx
);
}
else
{
for
(
int
i
=
0
;
i
<
rank
;
i
++
)
{
real_dims
.
push_back
(
i
);
}
TensorCopySync
(
*
weight
,
ctx
.
GetPlace
(),
&
weight_mat
);
TensorCopySync
(
*
out_grad
,
ctx
.
GetPlace
(),
&
out_grad_mat
);
}
weight_mat
=
weight_mat
.
Resize
({
h
,
w
});
weight_mat
=
weight_mat
.
Resize
({
h
,
w
});
out_grad_mat
=
out_grad_mat
.
Resize
({
h
,
w
});
out_grad_mat
=
out_grad_mat
.
Resize
({
h
,
w
});
...
@@ -148,21 +233,37 @@ class SpectralNormGradKernel : public framework::OpKernel<T> {
...
@@ -148,21 +233,37 @@ class SpectralNormGradKernel : public framework::OpKernel<T> {
blas
.
MatMul
(
uu
.
Resize
({
h
,
1
}),
false
,
vv
.
Resize
({
w
,
1
}),
false
,
T
(
1
),
&
uv
,
blas
.
MatMul
(
uu
.
Resize
({
h
,
1
}),
false
,
vv
.
Resize
({
w
,
1
}),
false
,
T
(
1
),
&
uv
,
T
(
0
));
T
(
0
));
Tensor
weight_grad_mat
,
ones
;
Tensor
weight_grad_mat
;
weight_grad_mat
.
mutable_data
<
T
>
({
h
,
w
},
ctx
.
GetPlace
());
weight_grad_mat
.
mutable_data
<
T
>
({
h
,
w
},
ctx
.
GetPlace
());
ones
.
mutable_data
<
T
>
({
h
,
w
},
ctx
.
GetPlace
());
auto
weight_grad_mat_t
=
EigenTensor
<
T
,
2
>::
From
(
weight_grad_mat
);
auto
weight_grad_mat_t
=
EigenTensor
<
T
,
2
>::
From
(
weight_grad_mat
);
auto
weight_mat_t
=
EigenTensor
<
T
,
2
>::
From
(
weight_mat
);
auto
weight_mat_t
=
EigenTensor
<
T
,
2
>::
From
(
weight_mat
);
auto
out_grad_mat_t
=
EigenTensor
<
T
,
2
>::
From
(
out_grad_mat
);
auto
out_grad_mat_t
=
EigenTensor
<
T
,
2
>::
From
(
out_grad_mat
);
auto
sigma_t
=
EigenTensor
<
T
,
2
>::
From
(
sigma
);
auto
sigma_t
=
EigenTensor
<
T
,
2
>::
From
(
sigma
);
auto
uv_t
=
EigenTensor
<
T
,
2
>::
From
(
uv
);
auto
uv_t
=
EigenTensor
<
T
,
2
>::
From
(
uv
);
auto
ones_t
=
EigenTensor
<
T
,
2
>::
From
(
ones
).
setConstant
((
T
)
1
);
weight_mat_t
.
device
(
place
)
=
weight_mat_t
.
device
(
place
)
=
weight_mat_t
.
sum
().
eval
().
reshape
(
Array2
(
1
,
1
)).
broadcast
(
Array2
(
h
,
w
));
weight_mat_t
.
sum
().
eval
().
reshape
(
Array2
(
1
,
1
)).
broadcast
(
Array2
(
h
,
w
));
weight_grad_mat_t
.
device
(
place
)
=
weight_grad_mat_t
.
device
(
place
)
=
out_grad_mat_t
*
(
ones_t
-
uv_t
*
weight_mat_t
)
/
sigma_t
;
out_grad_mat_t
*
(
out_grad_mat_t
.
constant
(
1.0
)
-
uv_t
*
weight_mat_t
)
/
TensorCopySync
(
weight_grad_mat
.
Resize
(
weight_grad
->
dims
()),
ctx
.
GetPlace
(),
sigma_t
;
weight_grad
);
if
(
dim
!=
0
)
{
std
::
vector
<
int
>
perm
;
for
(
int
i
=
0
;
i
<
rank
;
i
++
)
{
if
(
i
<
dim
)
{
perm
.
push_back
(
i
+
1
);
}
else
if
(
i
==
dim
)
{
perm
.
push_back
(
0
);
}
else
{
perm
.
push_back
(
i
);
}
}
weight_grad
->
mutable_data
<
T
>
(
dims
,
ctx
.
GetPlace
());
TransCompute
<
DeviceContext
,
T
>
(
rank
,
weight_grad_mat
.
Resize
(
framework
::
make_ddim
(
real_dims
)),
weight_grad
,
perm
,
dev_ctx
);
}
else
{
TensorCopySync
(
weight_grad_mat
.
Resize
(
dims
),
ctx
.
GetPlace
(),
weight_grad
);
}
}
}
};
};
...
...
python/paddle/fluid/layers/nn.py
浏览文件 @
63d322f0
...
@@ -94,6 +94,7 @@ __all__ = [
...
@@ -94,6 +94,7 @@ __all__ = [
'multiplex'
,
'multiplex'
,
'layer_norm'
,
'layer_norm'
,
'group_norm'
,
'group_norm'
,
'spectral_norm'
,
'softmax_with_cross_entropy'
,
'softmax_with_cross_entropy'
,
'smooth_l1'
,
'smooth_l1'
,
'one_hot'
,
'one_hot'
,
...
@@ -3347,6 +3348,80 @@ def group_norm(input,
...
@@ -3347,6 +3348,80 @@ def group_norm(input,
return
helper
.
append_activation
(
group_norm_out
)
return
helper
.
append_activation
(
group_norm_out
)
@
templatedoc
()
def
spectral_norm
(
weight
,
dim
=
0
,
power_iters
=
1
,
eps
=
1e-12
,
u_attr
=
None
,
v_attr
=
None
,
name
=
None
):
"""
**Spectral Normalization Layer**
Refer to `Spectral Normalization <https://arxiv.org/abs/1802.05957>`_ .
Args:
weight(${weight_type}): ${weight_comment}
dim(${dim_type}): ${dim_comment}
eps(${eps_type}): ${eps_comment}
u_attr(ParamAttr|None): The parameter attribute for vector u in
spectral calculatings, set None to use default attribute, which
generates random values in normal distribution N(0, 1). Default: None.
v_attr(ParamAttr|None): The parameter attribute for vector v in
spectral calculatings, set None to use default attribute, which
generates random values in normal distribution N(0, 1). Default: None.
name (str): The name of this layer. It is optional.
Returns:
Variable: A tensor variable of weight after spetral normalization.
Examples:
>>> weight = fluid.layers.data(name='weight', shape=[8, 32, 32],
>>> dtype='float32')
>>> x = fluid.layers.spectral_norm(weight=data, dim=1, power_iters=2)
"""
helper
=
LayerHelper
(
'spectral_norm'
,
**
locals
())
dtype
=
helper
.
input_dtype
()
# create intput and parameters
inputs
=
{
'Weight'
:
weight
}
input_shape
=
input
.
shape
if
data_layout
!=
'NCHW'
:
raise
ValueError
(
"unsupported data layout:"
+
data_layout
)
param_shape
=
[
input_shape
[
1
]]
if
param_attr
:
scale
=
helper
.
create_parameter
(
attr
=
helper
.
param_attr
,
shape
=
param_shape
,
dtype
=
dtype
,
default_initializer
=
Constant
(
1.0
))
inputs
[
'Scale'
]
=
scale
if
bias_attr
:
bias
=
helper
.
create_parameter
(
attr
=
helper
.
bias_attr
,
shape
=
param_shape
,
dtype
=
dtype
,
is_bias
=
True
)
inputs
[
'Bias'
]
=
bias
# create output
mean_out
=
helper
.
create_variable
(
dtype
=
dtype
,
stop_gradient
=
True
)
variance_out
=
helper
.
create_variable
(
dtype
=
dtype
,
stop_gradient
=
True
)
group_norm_out
=
helper
.
create_variable
(
dtype
=
dtype
)
helper
.
append_op
(
type
=
"group_norm"
,
inputs
=
inputs
,
outputs
=
{
"Y"
:
group_norm_out
,
"Mean"
:
mean_out
,
"Variance"
:
variance_out
,
},
attrs
=
{
"epsilon"
:
epsilon
,
"groups"
:
groups
})
return
helper
.
append_activation
(
group_norm_out
)
def
conv2d_transpose
(
input
,
def
conv2d_transpose
(
input
,
num_filters
,
num_filters
,
output_size
=
None
,
output_size
=
None
,
...
...
python/paddle/fluid/tests/unittests/test_spectral_norm_op.py
浏览文件 @
63d322f0
...
@@ -22,13 +22,17 @@ from paddle.fluid import core
...
@@ -22,13 +22,17 @@ from paddle.fluid import core
def
spectral_norm
(
weight
,
u
,
v
,
dim
,
power_iters
,
eps
):
def
spectral_norm
(
weight
,
u
,
v
,
dim
,
power_iters
,
eps
):
h
=
w
=
1
shape
=
weight
.
shape
for
i
,
d
in
enumerate
(
weight
.
shape
):
weight_mat
=
weight
.
copy
()
if
i
<=
dim
:
h
=
shape
[
dim
]
h
*=
d
w
=
np
.
prod
(
shape
)
//
h
else
:
if
dim
!=
0
:
w
*=
d
perm
=
[
dim
]
+
[
d
for
d
in
range
(
len
(
shape
))
if
d
!=
dim
]
weight_mat
=
weight
.
reshape
((
h
,
w
))
weight_mat
=
weight_mat
.
transpose
(
perm
)
real_shape
=
weight_mat
.
shape
else
:
real_shape
=
shape
weight_mat
=
weight_mat
.
reshape
((
h
,
w
))
u
=
u
.
reshape
((
h
,
1
))
u
=
u
.
reshape
((
h
,
1
))
v
=
v
.
reshape
((
w
,
1
))
v
=
v
.
reshape
((
w
,
1
))
...
@@ -41,7 +45,7 @@ def spectral_norm(weight, u, v, dim, power_iters, eps):
...
@@ -41,7 +45,7 @@ def spectral_norm(weight, u, v, dim, power_iters, eps):
u
=
u
/
(
u_norm
+
eps
)
u
=
u
/
(
u_norm
+
eps
)
sigma
=
(
u
*
np
.
matmul
(
weight_mat
,
v
)).
sum
()
sigma
=
(
u
*
np
.
matmul
(
weight_mat
,
v
)).
sum
()
return
(
weight_mat
/
sigma
).
reshape
(
weight
.
shape
)
return
weight
/
sigma
class
TestSpectralNormOpNoGrad
(
OpTest
):
class
TestSpectralNormOpNoGrad
(
OpTest
):
...
@@ -83,8 +87,8 @@ class TestSpectralNormOpNoGrad(OpTest):
...
@@ -83,8 +87,8 @@ class TestSpectralNormOpNoGrad(OpTest):
class
TestSpectralNormOpNoGrad2
(
TestSpectralNormOpNoGrad
):
class
TestSpectralNormOpNoGrad2
(
TestSpectralNormOpNoGrad
):
def
initTestCase
(
self
):
def
initTestCase
(
self
):
self
.
weight_shape
=
(
2
,
3
,
3
,
3
)
self
.
weight_shape
=
(
2
,
3
,
3
,
3
)
self
.
u_shape
=
(
6
,
)
self
.
u_shape
=
(
3
,
)
self
.
v_shape
=
(
9
,
)
self
.
v_shape
=
(
18
,
)
self
.
dim
=
1
self
.
dim
=
1
self
.
power_iters
=
10
self
.
power_iters
=
10
self
.
eps
=
1e-12
self
.
eps
=
1e-12
...
@@ -110,8 +114,8 @@ class TestSpectralNormOp(TestSpectralNormOpNoGrad):
...
@@ -110,8 +114,8 @@ class TestSpectralNormOp(TestSpectralNormOpNoGrad):
class
TestSpectralNormOp2
(
TestSpectralNormOp
):
class
TestSpectralNormOp2
(
TestSpectralNormOp
):
def
initTestCase
(
self
):
def
initTestCase
(
self
):
self
.
weight_shape
=
(
2
,
3
,
3
,
3
)
self
.
weight_shape
=
(
2
,
3
,
3
,
3
)
self
.
u_shape
=
(
6
,
)
self
.
u_shape
=
(
3
,
)
self
.
v_shape
=
(
9
,
)
self
.
v_shape
=
(
18
,
)
self
.
dim
=
1
self
.
dim
=
1
self
.
power_iters
=
0
self
.
power_iters
=
0
self
.
eps
=
1e-12
self
.
eps
=
1e-12
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录