Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
BaiXuePrincess
Paddle
提交
6d8771b5
P
Paddle
项目概览
BaiXuePrincess
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
6d8771b5
编写于
3月 05, 2019
作者:
K
Kaipeng Deng
提交者:
GitHub
3月 05, 2019
浏览文件
操作
浏览文件
下载
差异文件
Merge pull request #15864 from heavengate/spectral_norm
Add spectral norm op
上级
caadd058
3eab9e4b
变更
7
显示空白变更内容
内联
并排
Showing
7 changed file
with
721 addition
and
0 deletion
+721
-0
paddle/fluid/API.spec
paddle/fluid/API.spec
+1
-0
paddle/fluid/operators/spectral_norm_op.cc
paddle/fluid/operators/spectral_norm_op.cc
+197
-0
paddle/fluid/operators/spectral_norm_op.cu
paddle/fluid/operators/spectral_norm_op.cu
+22
-0
paddle/fluid/operators/spectral_norm_op.h
paddle/fluid/operators/spectral_norm_op.h
+273
-0
python/paddle/fluid/layers/nn.py
python/paddle/fluid/layers/nn.py
+93
-0
python/paddle/fluid/tests/unittests/test_layers.py
python/paddle/fluid/tests/unittests/test_layers.py
+13
-0
python/paddle/fluid/tests/unittests/test_spectral_norm_op.py
python/paddle/fluid/tests/unittests/test_spectral_norm_op.py
+122
-0
未找到文件。
paddle/fluid/API.spec
浏览文件 @
6d8771b5
...
@@ -128,6 +128,7 @@ paddle.fluid.layers.row_conv (ArgSpec(args=['input', 'future_context_size', 'par
...
@@ -128,6 +128,7 @@ paddle.fluid.layers.row_conv (ArgSpec(args=['input', 'future_context_size', 'par
paddle.fluid.layers.multiplex (ArgSpec(args=['inputs', 'index'], varargs=None, keywords=None, defaults=None), ('document', '013795af319e2e86d3506741941078ee'))
paddle.fluid.layers.multiplex (ArgSpec(args=['inputs', 'index'], varargs=None, keywords=None, defaults=None), ('document', '013795af319e2e86d3506741941078ee'))
paddle.fluid.layers.layer_norm (ArgSpec(args=['input', 'scale', 'shift', 'begin_norm_axis', 'epsilon', 'param_attr', 'bias_attr', 'act', 'name'], varargs=None, keywords=None, defaults=(True, True, 1, 1e-05, None, None, None, None)), ('document', 'de6a906950bae9f3c245cb744d22b94e'))
paddle.fluid.layers.layer_norm (ArgSpec(args=['input', 'scale', 'shift', 'begin_norm_axis', 'epsilon', 'param_attr', 'bias_attr', 'act', 'name'], varargs=None, keywords=None, defaults=(True, True, 1, 1e-05, None, None, None, None)), ('document', 'de6a906950bae9f3c245cb744d22b94e'))
paddle.fluid.layers.group_norm (ArgSpec(args=['input', 'groups', 'epsilon', 'param_attr', 'bias_attr', 'act', 'data_layout', 'name'], varargs=None, keywords=None, defaults=(1e-05, None, None, None, 'NCHW', None)), ('document', '419c3a24a83cc89219a029cf4092788b'))
paddle.fluid.layers.group_norm (ArgSpec(args=['input', 'groups', 'epsilon', 'param_attr', 'bias_attr', 'act', 'data_layout', 'name'], varargs=None, keywords=None, defaults=(1e-05, None, None, None, 'NCHW', None)), ('document', '419c3a24a83cc89219a029cf4092788b'))
paddle.fluid.layers.spectral_norm (ArgSpec(args=['weight', 'dim', 'power_iters', 'eps', 'name'], varargs=None, keywords=None, defaults=(0, 1, 1e-12, None)), ('document', '3f536aafba30d793287b52d231baff1b'))
paddle.fluid.layers.softmax_with_cross_entropy (ArgSpec(args=['logits', 'label', 'soft_label', 'ignore_index', 'numeric_stable_mode', 'return_softmax'], varargs=None, keywords=None, defaults=(False, -100, True, False)), ('document', 'bce1b75e3d95b75cacd1099655cbb3c3'))
paddle.fluid.layers.softmax_with_cross_entropy (ArgSpec(args=['logits', 'label', 'soft_label', 'ignore_index', 'numeric_stable_mode', 'return_softmax'], varargs=None, keywords=None, defaults=(False, -100, True, False)), ('document', 'bce1b75e3d95b75cacd1099655cbb3c3'))
paddle.fluid.layers.smooth_l1 (ArgSpec(args=['x', 'y', 'inside_weight', 'outside_weight', 'sigma'], varargs=None, keywords=None, defaults=(None, None, None)), ('document', 'c6b175d253c55baf4b9c0eca9b1dda88'))
paddle.fluid.layers.smooth_l1 (ArgSpec(args=['x', 'y', 'inside_weight', 'outside_weight', 'sigma'], varargs=None, keywords=None, defaults=(None, None, None)), ('document', 'c6b175d253c55baf4b9c0eca9b1dda88'))
paddle.fluid.layers.one_hot (ArgSpec(args=['input', 'depth'], varargs=None, keywords=None, defaults=None), ('document', '6148b6a555cbfb62fdcd030d8982c18c'))
paddle.fluid.layers.one_hot (ArgSpec(args=['input', 'depth'], varargs=None, keywords=None, defaults=None), ('document', '6148b6a555cbfb62fdcd030d8982c18c'))
...
...
paddle/fluid/operators/spectral_norm_op.cc
0 → 100644
浏览文件 @
6d8771b5
/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/spectral_norm_op.h"
#include "paddle/fluid/framework/op_registry.h"
namespace
paddle
{
namespace
operators
{
using
framework
::
Tensor
;
class
SpectralNormOp
:
public
framework
::
OperatorWithKernel
{
public:
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
protected:
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"Weight"
),
"Input(Weight) of SpectralNormOp should not be null."
);
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"U"
),
"Input(U) of SpectralNormOp should not be null."
);
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"V"
),
"Input(V) of SpectralNormOp should not be null."
);
PADDLE_ENFORCE
(
ctx
->
HasOutput
(
"Out"
),
"Output(Out) of SpectralNormOp should not be null."
);
auto
dim_weight
=
ctx
->
GetInputDim
(
"Weight"
);
auto
rank_weight
=
dim_weight
.
size
();
PADDLE_ENFORCE
(
rank_weight
>=
2
&&
rank_weight
<=
5
,
"The rank of Input(Weights) can only be 2, 3,"
"4, 5 for fc, conv1d, conv2d, conv3d layers."
);
int
dim
=
ctx
->
Attrs
().
Get
<
int
>
(
"dim"
);
int
power_iters
=
ctx
->
Attrs
().
Get
<
int
>
(
"power_iters"
);
PADDLE_ENFORCE
(
dim
==
0
||
dim
==
1
,
"Attr(dim) can only be 0 or 1"
);
PADDLE_ENFORCE
(
power_iters
>=
0
,
"Attr(power_iters) should be larger equal then 0"
);
int
h
=
dim_weight
[
dim
];
int
w
=
1
;
for
(
int
i
=
0
;
i
<
rank_weight
;
i
++
)
{
if
(
i
!=
dim
)
{
w
*=
dim_weight
[
i
];
}
}
auto
dim_u
=
ctx
->
GetInputDim
(
"U"
);
auto
dim_v
=
ctx
->
GetInputDim
(
"V"
);
PADDLE_ENFORCE_EQ
(
dim_u
[
0
],
h
,
"Input(U) dims[0] should be equal to "
"Input(Weight) dims[Attr(dim)]"
);
PADDLE_ENFORCE_EQ
(
dim_v
[
0
],
w
,
"Input(V) dims[0] should be equal to "
"the product of Input(Weight) dims except dims[Attr(dim)]"
);
ctx
->
SetOutputDim
(
"Out"
,
dim_weight
);
ctx
->
ShareLoD
(
"Weight"
,
/*->*/
"Out"
);
}
protected:
framework
::
OpKernelType
GetExpectedKernelType
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
return
framework
::
OpKernelType
(
ctx
.
Input
<
Tensor
>
(
"Weight"
)
->
type
(),
ctx
.
GetPlace
());
}
};
class
SpectralNormOpMaker
:
public
framework
::
OpProtoAndCheckerMaker
{
public:
void
Make
()
override
{
AddInput
(
"Weight"
,
"The input weight tensor of spectral_norm operator, "
"This can be a 2-D, 3-D, 4-D, 5-D tensor which is the "
"weights of fc, conv1d, conv2d, conv3d layer."
);
AddInput
(
"U"
,
"The weight_u tensor of spectral_norm operator, "
"This can be a 1-D tensor in shape [H, 1],"
"H is the 1st dimentions of Weight after reshape"
"corresponding by Attr(dim). As for Attr(dim) = 1"
"in conv2d layer with weight shape [M, C, K1, K2]"
"Weight will be reshape to [C, M*K1*K2], U will"
"be in shape [C, 1]."
);
AddInput
(
"V"
,
"The weight_v tensor of spectral_norm operator, "
"This can be a 1-D tensor in shape [W, 1], "
"W is the 2nd dimentions of Weight after reshape "
"corresponding by Attr(dim). As for Attr(dim) = 1 "
"in conv2d layer with weight shape [M, C, K1, K2] "
"Weight will be reshape to [C, M*K1*K2], V will "
"be in shape [M*K1*K2, 1]."
);
AddOutput
(
"Out"
,
"The output weight tensor of spectral_norm operator, "
"This tensor is in same shape with Input(Weight)."
);
AddAttr
<
int
>
(
"dim"
,
"The index of dimension which should be permuted "
"to the first before reshaping Input(Weight) to "
"matrix, it should be set as 0 if Input(Weight) is "
"the weight of fc layer, and should be set as 1 if "
"Input(Weight) is the weight of conv layer, "
"default 0."
)
.
SetDefault
(
0
);
AddAttr
<
int
>
(
"power_iters"
,
"number of power iterations to calculate "
"spectral norm, default 1."
)
.
SetDefault
(
1
);
AddAttr
<
float
>
(
"eps"
,
"epsilon for numerical stability in "
"calculating norms"
)
.
SetDefault
(
1e-12
);
AddComment
(
R"DOC(
This layer calculates the spectral normalization value of weight of
fc, conv1d, conv2d, conv3d layers which should be 2-D, 3-D, 4-D, 5-D
tensor.
Spectral normalization stabilizes the training of critic in GANs
(Generative Adversarial Networks). This layer rescaling weight tensor
with spectral normalize value.
For spectral normalization calculations, we rescaling weight
tensor with :math:`\sigma`, while :math:`\sigma{\mathbf{W}}` is
$$\sigma(\mathbf{W}) = \max_{\mathbf{h}: \mathbf{h} \ne 0} \\frac{\|\mathbf{W} \mathbf{h}\|_2}{\|\mathbf{h}\|_2}$$
We calculate :math:`\sigma{\mathbf{W}}` through power iterations as
$$
\mathbf{v} = \mathbf{W}^{T} \mathbf{u}
$$
$$
\mathbf{v} = \\frac{\mathbf{v}}{\|\mathbf{v}\|_2}
$$
$$
\mathbf{u} = \mathbf{W}^{T} \mathbf{v}
$$
$$
\mathbf{u} = \\frac{\mathbf{u}}{\|\mathbf{u}\|_2}
$$
And :math:`\sigma` should be
$$\sigma{\mathbf{W}} = \mathbf{u}^{T} \mathbf{W} \mathbf{v}$$
For details of spectral normalization, please refer to paper:
`Spectral Normalization <https://arxiv.org/abs/1802.05957>`_ .
)DOC"
);
}
};
class
SpectralNormOpGrad
:
public
framework
::
OperatorWithKernel
{
public:
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
protected:
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"Weight"
),
"Input(Weight) should not be null"
);
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"U"
),
"Input(U) should not be null"
);
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"V"
),
"Input(V) should not be null"
);
PADDLE_ENFORCE
(
ctx
->
HasInput
(
framework
::
GradVarName
(
"Out"
)),
"Input(Out@GRAD) should not be null"
);
auto
dim_x
=
ctx
->
GetInputDim
(
"Weight"
);
if
(
ctx
->
HasOutput
(
framework
::
GradVarName
(
"Weight"
)))
{
ctx
->
SetOutputDim
(
framework
::
GradVarName
(
"Weight"
),
dim_x
);
}
}
framework
::
OpKernelType
GetExpectedKernelType
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
return
framework
::
OpKernelType
(
ctx
.
Input
<
Tensor
>
(
"Weight"
)
->
type
(),
ctx
.
GetPlace
());
}
};
}
// namespace operators
}
// namespace paddle
namespace
ops
=
paddle
::
operators
;
REGISTER_OPERATOR
(
spectral_norm
,
ops
::
SpectralNormOp
,
ops
::
SpectralNormOpMaker
,
paddle
::
framework
::
DefaultGradOpDescMaker
<
true
>
);
REGISTER_OPERATOR
(
spectral_norm_grad
,
ops
::
SpectralNormOpGrad
);
REGISTER_OP_CPU_KERNEL
(
spectral_norm
,
ops
::
SpectralNormKernel
<
paddle
::
platform
::
CPUDeviceContext
,
float
>
,
ops
::
SpectralNormKernel
<
paddle
::
platform
::
CPUDeviceContext
,
double
>
);
REGISTER_OP_CPU_KERNEL
(
spectral_norm_grad
,
ops
::
SpectralNormGradKernel
<
paddle
::
platform
::
CPUDeviceContext
,
float
>
,
ops
::
SpectralNormGradKernel
<
paddle
::
platform
::
CPUDeviceContext
,
double
>
);
paddle/fluid/operators/spectral_norm_op.cu
0 → 100644
浏览文件 @
6d8771b5
/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/spectral_norm_op.h"
namespace
ops
=
paddle
::
operators
;
REGISTER_OP_CUDA_KERNEL
(
spectral_norm
,
ops
::
SpectralNormKernel
<
paddle
::
platform
::
CUDADeviceContext
,
float
>
,
ops
::
SpectralNormKernel
<
paddle
::
platform
::
CUDADeviceContext
,
double
>
);
REGISTER_OP_CUDA_KERNEL
(
spectral_norm_grad
,
ops
::
SpectralNormGradKernel
<
paddle
::
platform
::
CUDADeviceContext
,
float
>
,
ops
::
SpectralNormGradKernel
<
paddle
::
platform
::
CUDADeviceContext
,
double
>
);
paddle/fluid/operators/spectral_norm_op.h
0 → 100644
浏览文件 @
6d8771b5
/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <vector>
#include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/math/blas.h"
#include "paddle/fluid/operators/math/math_function.h"
namespace
paddle
{
namespace
operators
{
template
<
typename
T
,
size_t
D
,
int
MajorType
=
Eigen
::
RowMajor
,
typename
IndexType
=
Eigen
::
DenseIndex
>
using
EigenTensor
=
framework
::
EigenTensor
<
T
,
D
,
MajorType
,
IndexType
>
;
using
Tensor
=
framework
::
Tensor
;
using
Array1
=
Eigen
::
DSizes
<
int64_t
,
1
>
;
using
Array2
=
Eigen
::
DSizes
<
int64_t
,
2
>
;
using
IndexPair
=
Eigen
::
IndexPair
<
int
>
;
template
<
typename
DeviceContext
,
typename
T
>
static
inline
void
TransCompute
(
const
int
rank
,
const
Tensor
&
in
,
Tensor
*
out
,
const
std
::
vector
<
int
>&
perm
,
const
DeviceContext
&
dev_ctx
)
{
if
(
rank
<=
1
||
rank
>
5
)
{
PADDLE_THROW
(
"Invalid weight rank."
);
}
switch
(
rank
)
{
case
2
:
math
::
Transpose
<
DeviceContext
,
T
,
2
>
trans2
;
trans2
(
dev_ctx
,
in
,
out
,
perm
);
break
;
case
3
:
math
::
Transpose
<
DeviceContext
,
T
,
3
>
trans3
;
trans3
(
dev_ctx
,
in
,
out
,
perm
);
break
;
case
4
:
math
::
Transpose
<
DeviceContext
,
T
,
4
>
trans4
;
trans4
(
dev_ctx
,
in
,
out
,
perm
);
break
;
case
5
:
math
::
Transpose
<
DeviceContext
,
T
,
5
>
trans5
;
trans5
(
dev_ctx
,
in
,
out
,
perm
);
break
;
default:
break
;
}
}
template
<
typename
DeviceContext
,
typename
T
>
static
inline
void
CalcMatrixSigmaAndNormWeight
(
Tensor
*
sigma
,
Tensor
*
u
,
Tensor
*
v
,
Tensor
*
weight
,
const
int
power_iters
,
const
float
eps
,
const
framework
::
ExecutionContext
&
ctx
)
{
auto
&
place
=
*
ctx
.
template
device_context
<
DeviceContext
>().
eigen_device
();
auto
blas
=
math
::
GetBlas
<
DeviceContext
,
T
>
(
ctx
);
auto
sigma_t
=
EigenTensor
<
T
,
2
>::
From
(
*
sigma
);
auto
weight_t
=
EigenTensor
<
T
,
2
>::
From
(
*
weight
);
auto
u_t
=
EigenTensor
<
T
,
2
>::
From
(
*
u
);
auto
v_t
=
EigenTensor
<
T
,
2
>::
From
(
*
v
);
const
int
h
=
weight
->
dims
()[
0
];
const
int
w
=
weight
->
dims
()[
1
];
for
(
int
i
=
0
;
i
<
power_iters
;
i
++
)
{
// V = W^T * U / ||W^T * U||_2
blas
.
MatMul
(
*
weight
,
true
,
*
u
,
false
,
T
(
1
),
v
,
T
(
0
));
auto
v_t_norm
=
v_t
.
square
().
sum
().
sqrt
().
eval
().
reshape
(
Array1
(
1
)).
broadcast
(
Array1
(
w
));
v_t
.
device
(
place
)
=
v_t
/
(
v_t_norm
+
v_t_norm
.
constant
(
eps
));
// U = W^T * V / ||W^T * V||_2
blas
.
MatMul
(
*
weight
,
false
,
*
v
,
false
,
T
(
1
),
u
,
T
(
0
));
auto
u_t_norm
=
u_t
.
square
().
sum
().
sqrt
().
eval
().
reshape
(
Array1
(
1
)).
broadcast
(
Array1
(
h
));
u_t
.
device
(
place
)
=
u_t
/
(
u_t_norm
+
u_t_norm
.
constant
(
eps
));
}
Tensor
weight_v
;
weight_v
.
mutable_data
<
T
>
({
h
,
1
},
ctx
.
GetPlace
());
blas
.
MatMul
(
*
weight
,
false
,
*
v
,
false
,
T
(
1
),
&
weight_v
,
T
(
0
));
auto
weight_v_t
=
EigenTensor
<
T
,
2
>::
From
(
weight_v
);
sigma_t
.
device
(
place
)
=
(
u_t
*
weight_v_t
)
.
sum
()
.
eval
()
.
reshape
(
Array2
(
1
,
1
))
.
broadcast
(
Array2
(
h
,
w
));
weight_t
.
device
(
place
)
=
weight_t
/
sigma_t
;
}
template
<
typename
DeviceContext
,
typename
T
>
class
SpectralNormKernel
:
public
framework
::
OpKernel
<
T
>
{
public:
void
Compute
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
auto
&
dev_ctx
=
ctx
.
template
device_context
<
DeviceContext
>();
auto
weight
=
ctx
.
Input
<
Tensor
>
(
"Weight"
);
auto
u
=
ctx
.
Input
<
Tensor
>
(
"U"
);
auto
v
=
ctx
.
Input
<
Tensor
>
(
"V"
);
auto
out
=
ctx
.
Output
<
Tensor
>
(
"Out"
);
int
dim
=
ctx
.
Attr
<
int
>
(
"dim"
);
int
power_iters
=
ctx
.
Attr
<
int
>
(
"power_iters"
);
float
eps
=
ctx
.
Attr
<
float
>
(
"eps"
);
const
int
h
=
u
->
dims
()[
0
];
const
int
w
=
v
->
dims
()[
0
];
Tensor
weight_mat
;
auto
dims
=
weight
->
dims
();
const
int
rank
=
dims
.
size
();
std
::
vector
<
int
>
real_dims
;
if
(
dim
!=
0
)
{
std
::
vector
<
int
>
perm
;
perm
.
push_back
(
dim
);
real_dims
.
push_back
(
dims
[
dim
]);
for
(
int
i
=
0
;
i
<
rank
;
i
++
)
{
if
(
i
!=
dim
)
{
perm
.
push_back
(
i
);
real_dims
.
push_back
(
dims
[
i
]);
}
}
weight_mat
.
mutable_data
<
T
>
(
framework
::
make_ddim
(
real_dims
),
ctx
.
GetPlace
());
TransCompute
<
DeviceContext
,
T
>
(
rank
,
*
weight
,
&
weight_mat
,
perm
,
dev_ctx
);
}
else
{
for
(
int
i
=
0
;
i
<
rank
;
i
++
)
{
real_dims
.
push_back
(
i
);
}
TensorCopySync
(
*
weight
,
ctx
.
GetPlace
(),
&
weight_mat
);
}
weight_mat
=
weight_mat
.
Resize
({
h
,
w
});
Tensor
sigma
;
sigma
.
mutable_data
<
T
>
(
weight_mat
.
dims
(),
ctx
.
GetPlace
());
Tensor
uu
,
vv
;
TensorCopySync
(
*
u
,
ctx
.
GetPlace
(),
&
uu
);
TensorCopySync
(
*
v
,
ctx
.
GetPlace
(),
&
vv
);
CalcMatrixSigmaAndNormWeight
<
DeviceContext
,
T
>
(
&
sigma
,
&
(
uu
.
Resize
({
h
,
1
})),
&
(
vv
.
Resize
({
w
,
1
})),
&
weight_mat
,
power_iters
,
eps
,
ctx
);
if
(
dim
!=
0
)
{
std
::
vector
<
int
>
perm
;
for
(
int
i
=
0
;
i
<
rank
;
i
++
)
{
if
(
i
<
dim
)
{
perm
.
push_back
(
i
+
1
);
}
else
if
(
i
==
dim
)
{
perm
.
push_back
(
0
);
}
else
{
perm
.
push_back
(
i
);
}
}
out
->
mutable_data
<
T
>
(
dims
,
ctx
.
GetPlace
());
TransCompute
<
DeviceContext
,
T
>
(
rank
,
weight_mat
.
Resize
(
framework
::
make_ddim
(
real_dims
)),
out
,
perm
,
dev_ctx
);
}
else
{
TensorCopySync
(
weight_mat
.
Resize
(
dims
),
ctx
.
GetPlace
(),
out
);
}
}
};
template
<
typename
DeviceContext
,
typename
T
>
class
SpectralNormGradKernel
:
public
framework
::
OpKernel
<
T
>
{
public:
void
Compute
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
auto
&
place
=
*
ctx
.
template
device_context
<
DeviceContext
>().
eigen_device
();
auto
&
dev_ctx
=
ctx
.
template
device_context
<
DeviceContext
>();
auto
blas
=
math
::
GetBlas
<
DeviceContext
,
T
>
(
ctx
);
auto
weight
=
ctx
.
Input
<
Tensor
>
(
"Weight"
);
auto
u
=
ctx
.
Input
<
Tensor
>
(
"U"
);
auto
v
=
ctx
.
Input
<
Tensor
>
(
"V"
);
auto
out_grad
=
ctx
.
Input
<
Tensor
>
(
framework
::
GradVarName
(
"Out"
));
auto
weight_grad
=
ctx
.
Output
<
Tensor
>
(
framework
::
GradVarName
(
"Weight"
));
int
dim
=
ctx
.
Attr
<
int
>
(
"dim"
);
int
power_iters
=
ctx
.
Attr
<
int
>
(
"power_iters"
);
float
eps
=
ctx
.
Attr
<
float
>
(
"eps"
);
const
int
h
=
u
->
dims
()[
0
];
const
int
w
=
v
->
dims
()[
0
];
Tensor
weight_mat
,
out_grad_mat
;
auto
dims
=
weight
->
dims
();
const
int
rank
=
dims
.
size
();
std
::
vector
<
int
>
real_dims
;
if
(
dim
!=
0
)
{
std
::
vector
<
int
>
perm
;
perm
.
push_back
(
dim
);
real_dims
.
push_back
(
dims
[
dim
]);
for
(
int
i
=
0
;
i
<
rank
;
i
++
)
{
if
(
i
!=
dim
)
{
perm
.
push_back
(
i
);
real_dims
.
push_back
(
dims
[
i
]);
}
}
weight_mat
.
mutable_data
<
T
>
(
framework
::
make_ddim
(
real_dims
),
ctx
.
GetPlace
());
out_grad_mat
.
mutable_data
<
T
>
(
framework
::
make_ddim
(
real_dims
),
ctx
.
GetPlace
());
TransCompute
<
DeviceContext
,
T
>
(
rank
,
*
weight
,
&
weight_mat
,
perm
,
dev_ctx
);
TransCompute
<
DeviceContext
,
T
>
(
rank
,
*
out_grad
,
&
out_grad_mat
,
perm
,
dev_ctx
);
}
else
{
for
(
int
i
=
0
;
i
<
rank
;
i
++
)
{
real_dims
.
push_back
(
i
);
}
TensorCopySync
(
*
weight
,
ctx
.
GetPlace
(),
&
weight_mat
);
TensorCopySync
(
*
out_grad
,
ctx
.
GetPlace
(),
&
out_grad_mat
);
}
weight_mat
=
weight_mat
.
Resize
({
h
,
w
});
out_grad_mat
=
out_grad_mat
.
Resize
({
h
,
w
});
Tensor
sigma
;
sigma
.
mutable_data
<
T
>
(
weight_mat
.
dims
(),
ctx
.
GetPlace
());
Tensor
uu
,
vv
;
TensorCopySync
(
*
u
,
ctx
.
GetPlace
(),
&
uu
);
TensorCopySync
(
*
v
,
ctx
.
GetPlace
(),
&
vv
);
CalcMatrixSigmaAndNormWeight
<
DeviceContext
,
T
>
(
&
sigma
,
&
(
uu
.
Resize
({
h
,
1
})),
&
(
vv
.
Resize
({
w
,
1
})),
&
weight_mat
,
power_iters
,
eps
,
ctx
);
Tensor
uv
;
uv
.
mutable_data
<
T
>
({
h
,
w
},
ctx
.
GetPlace
());
blas
.
MatMul
(
uu
.
Resize
({
h
,
1
}),
false
,
vv
.
Resize
({
w
,
1
}),
false
,
T
(
1
),
&
uv
,
T
(
0
));
Tensor
weight_grad_mat
;
weight_grad_mat
.
mutable_data
<
T
>
({
h
,
w
},
ctx
.
GetPlace
());
auto
weight_grad_mat_t
=
EigenTensor
<
T
,
2
>::
From
(
weight_grad_mat
);
auto
weight_mat_t
=
EigenTensor
<
T
,
2
>::
From
(
weight_mat
);
auto
out_grad_mat_t
=
EigenTensor
<
T
,
2
>::
From
(
out_grad_mat
);
auto
sigma_t
=
EigenTensor
<
T
,
2
>::
From
(
sigma
);
auto
uv_t
=
EigenTensor
<
T
,
2
>::
From
(
uv
);
weight_mat_t
.
device
(
place
)
=
weight_mat_t
.
sum
().
eval
().
reshape
(
Array2
(
1
,
1
)).
broadcast
(
Array2
(
h
,
w
));
weight_grad_mat_t
.
device
(
place
)
=
out_grad_mat_t
*
(
out_grad_mat_t
.
constant
(
1.0
)
-
uv_t
*
weight_mat_t
)
/
sigma_t
;
if
(
dim
!=
0
)
{
std
::
vector
<
int
>
perm
;
for
(
int
i
=
0
;
i
<
rank
;
i
++
)
{
if
(
i
<
dim
)
{
perm
.
push_back
(
i
+
1
);
}
else
if
(
i
==
dim
)
{
perm
.
push_back
(
0
);
}
else
{
perm
.
push_back
(
i
);
}
}
weight_grad
->
mutable_data
<
T
>
(
dims
,
ctx
.
GetPlace
());
TransCompute
<
DeviceContext
,
T
>
(
rank
,
weight_grad_mat
.
Resize
(
framework
::
make_ddim
(
real_dims
)),
weight_grad
,
perm
,
dev_ctx
);
}
else
{
TensorCopySync
(
weight_grad_mat
.
Resize
(
dims
),
ctx
.
GetPlace
(),
weight_grad
);
}
}
};
}
// namespace operators
}
// namespace paddle
python/paddle/fluid/layers/nn.py
浏览文件 @
6d8771b5
...
@@ -94,6 +94,7 @@ __all__ = [
...
@@ -94,6 +94,7 @@ __all__ = [
'multiplex'
,
'multiplex'
,
'layer_norm'
,
'layer_norm'
,
'group_norm'
,
'group_norm'
,
'spectral_norm'
,
'softmax_with_cross_entropy'
,
'softmax_with_cross_entropy'
,
'smooth_l1'
,
'smooth_l1'
,
'one_hot'
,
'one_hot'
,
...
@@ -3346,6 +3347,98 @@ def group_norm(input,
...
@@ -3346,6 +3347,98 @@ def group_norm(input,
return
helper
.
append_activation
(
group_norm_out
)
return
helper
.
append_activation
(
group_norm_out
)
@
templatedoc
()
def
spectral_norm
(
weight
,
dim
=
0
,
power_iters
=
1
,
eps
=
1e-12
,
name
=
None
):
"""
**Spectral Normalization Layer**
This layer calculates the spectral normalization value of weight parameters of
fc, conv1d, conv2d, conv3d layers which should be 2-D, 3-D, 4-D, 5-D
Parameters. Calculations are showed as follows.
Step 1:
Generate vector U in shape of [H], and V in shape of [W].
While H is the :attr:`dim` th dimension of the input weights,
and W is the product result of remaining dimensions.
Step 2:
:attr:`power_iters` shoule be a positive interger, do following
calculations with U and V for :attr:`power_iters` rounds.
.. math::
\mathbf{v} :=
\\
frac{\mathbf{W}^{T} \mathbf{u}}{\|\mathbf{W}^{T} \mathbf{u}\|_2}
\mathbf{u} :=
\\
frac{\mathbf{W}^{T} \mathbf{v}}{\|\mathbf{W}^{T} \mathbf{v}\|_2}
Step 3:
Calculate :math:`\sigma(\mathbf{W})` and normalize weight values.
.. math::
\sigma(\mathbf{W}) = \mathbf{u}^{T} \mathbf{W} \mathbf{v}
\mathbf{W} =
\\
frac{\mathbf{W}}{\sigma(\mathbf{W})}
Refer to `Spectral Normalization <https://arxiv.org/abs/1802.05957>`_ .
Args:
weight(${weight_type}): ${weight_comment}
dim(int): ${dim_comment}
power_iters(int): ${power_iters_comment}
eps(float): ${eps_comment}
name (str): The name of this layer. It is optional.
Returns:
Variable: A tensor variable of weight parameters after spectral normalization.
Examples:
>>> weight = fluid.layers.data(name='weight', shape=[8, 32, 32],
>>> dtype='float32')
>>> x = fluid.layers.spectral_norm(weight=data, dim=1, power_iters=2)
"""
helper
=
LayerHelper
(
'spectral_norm'
,
**
locals
())
dtype
=
weight
.
dtype
# create intput and parameters
inputs
=
{
'Weight'
:
weight
}
input_shape
=
weight
.
shape
h
=
input_shape
[
dim
]
w
=
np
.
prod
(
input_shape
)
//
h
u
=
helper
.
create_parameter
(
attr
=
ParamAttr
(),
shape
=
[
h
],
dtype
=
dtype
,
default_initializer
=
Normal
(
0.
,
1.
))
u
.
stop_gradient
=
True
inputs
[
'U'
]
=
u
v
=
helper
.
create_parameter
(
attr
=
ParamAttr
(),
shape
=
[
w
],
dtype
=
dtype
,
default_initializer
=
Normal
(
0.
,
1.
))
inputs
[
'V'
]
=
v
v
.
stop_gradient
=
True
# create output
out
=
helper
.
create_variable
(
dtype
=
dtype
)
helper
.
append_op
(
type
=
"spectral_norm"
,
inputs
=
inputs
,
outputs
=
{
"Out"
:
out
,
},
attrs
=
{
"dim"
:
dim
,
"power_iters"
:
power_iters
,
"eps"
:
eps
,
})
return
out
def
conv2d_transpose
(
input
,
def
conv2d_transpose
(
input
,
num_filters
,
num_filters
,
output_size
=
None
,
output_size
=
None
,
...
...
python/paddle/fluid/tests/unittests/test_layers.py
浏览文件 @
6d8771b5
...
@@ -1035,6 +1035,19 @@ class TestBook(unittest.TestCase):
...
@@ -1035,6 +1035,19 @@ class TestBook(unittest.TestCase):
print
(
str
(
program
))
print
(
str
(
program
))
def
test_spectral_norm
(
self
):
program
=
Program
()
with
program_guard
(
program
):
weight
=
layers
.
data
(
name
=
'weight'
,
shape
=
[
2
,
3
,
32
,
32
],
dtype
=
"float32"
,
append_batch_size
=
False
)
out
=
layers
.
spectral_norm
(
weight
,
dim
=
1
,
power_iters
=
1
)
self
.
assertIsNotNone
(
out
)
print
(
str
(
program
))
def
test_shuffle_channel
(
self
):
def
test_shuffle_channel
(
self
):
program
=
Program
()
program
=
Program
()
with
program_guard
(
program
):
with
program_guard
(
program
):
...
...
python/paddle/fluid/tests/unittests/test_spectral_norm_op.py
0 → 100644
浏览文件 @
6d8771b5
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from
__future__
import
division
import
unittest
import
numpy
as
np
from
op_test
import
OpTest
from
paddle.fluid
import
core
def
spectral_norm
(
weight
,
u
,
v
,
dim
,
power_iters
,
eps
):
shape
=
weight
.
shape
weight_mat
=
weight
.
copy
()
h
=
shape
[
dim
]
w
=
np
.
prod
(
shape
)
//
h
if
dim
!=
0
:
perm
=
[
dim
]
+
[
d
for
d
in
range
(
len
(
shape
))
if
d
!=
dim
]
weight_mat
=
weight_mat
.
transpose
(
perm
)
weight_mat
=
weight_mat
.
reshape
((
h
,
w
))
u
=
u
.
reshape
((
h
,
1
))
v
=
v
.
reshape
((
w
,
1
))
for
i
in
range
(
power_iters
):
v
=
np
.
matmul
(
weight_mat
.
T
,
u
)
v_norm
=
np
.
sqrt
((
v
*
v
).
sum
())
v
=
v
/
(
v_norm
+
eps
)
u
=
np
.
matmul
(
weight_mat
,
v
)
u_norm
=
np
.
sqrt
((
u
*
u
).
sum
())
u
=
u
/
(
u_norm
+
eps
)
sigma
=
(
u
*
np
.
matmul
(
weight_mat
,
v
)).
sum
()
return
weight
/
sigma
class
TestSpectralNormOpNoGrad
(
OpTest
):
def
setUp
(
self
):
self
.
initTestCase
()
self
.
op_type
=
'spectral_norm'
weight
=
np
.
random
.
random
(
self
.
weight_shape
).
astype
(
'float32'
)
u
=
np
.
random
.
normal
(
0.
,
1.
,
self
.
u_shape
).
astype
(
'float32'
)
v
=
np
.
random
.
normal
(
0.
,
1.
,
self
.
v_shape
).
astype
(
'float32'
)
self
.
attrs
=
{
"dim"
:
self
.
dim
,
"power_iters"
:
self
.
power_iters
,
"eps"
:
self
.
eps
,
}
self
.
inputs
=
{
"Weight"
:
weight
,
"U"
:
u
,
"V"
:
v
,
}
output
=
spectral_norm
(
weight
,
u
,
v
,
self
.
dim
,
self
.
power_iters
,
self
.
eps
)
self
.
outputs
=
{
"Out"
:
output
}
def
test_check_output
(
self
):
self
.
check_output
()
def
initTestCase
(
self
):
self
.
weight_shape
=
(
2
,
3
)
self
.
u_shape
=
(
2
,
)
self
.
v_shape
=
(
3
,
)
self
.
dim
=
0
self
.
power_iters
=
5
self
.
eps
=
1e-12
class
TestSpectralNormOpNoGrad2
(
TestSpectralNormOpNoGrad
):
def
initTestCase
(
self
):
self
.
weight_shape
=
(
2
,
3
,
3
,
3
)
self
.
u_shape
=
(
3
,
)
self
.
v_shape
=
(
18
,
)
self
.
dim
=
1
self
.
power_iters
=
10
self
.
eps
=
1e-12
class
TestSpectralNormOp
(
TestSpectralNormOpNoGrad
):
def
test_check_grad_ignore_uv
(
self
):
self
.
check_grad
(
[
'Weight'
],
'Out'
,
no_grad_set
=
set
([
"U"
,
"V"
]),
max_relative_error
=
0.1
)
def
initTestCase
(
self
):
self
.
weight_shape
=
(
2
,
3
)
self
.
u_shape
=
(
2
,
)
self
.
v_shape
=
(
3
,
)
self
.
dim
=
0
self
.
power_iters
=
0
self
.
eps
=
1e-12
class
TestSpectralNormOp2
(
TestSpectralNormOp
):
def
initTestCase
(
self
):
self
.
weight_shape
=
(
2
,
3
,
3
,
3
)
self
.
u_shape
=
(
3
,
)
self
.
v_shape
=
(
18
,
)
self
.
dim
=
1
self
.
power_iters
=
0
self
.
eps
=
1e-12
if
__name__
==
"__main__"
:
unittest
.
main
()
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录