提交 9c47f36d 编写于 作者: D dengkaipeng 提交者: ceci3

fix spectral_norm doc. test=develop

上级 91f85315
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserve. /* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License"); Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License. you may not use this file except in compliance with the License.
You may obtain a copy of the License at You may obtain a copy of the License at
...@@ -84,20 +84,28 @@ class SpectralNormOpMaker : public framework::OpProtoAndCheckerMaker { ...@@ -84,20 +84,28 @@ class SpectralNormOpMaker : public framework::OpProtoAndCheckerMaker {
"The weight_u tensor of spectral_norm operator, " "The weight_u tensor of spectral_norm operator, "
"This can be a 1-D tensor in shape [H, 1]," "This can be a 1-D tensor in shape [H, 1],"
"H is the 1st dimentions of Weight after reshape" "H is the 1st dimentions of Weight after reshape"
"corresponding by Attr(dim)."); "corresponding by Attr(dim). As for Attr(dim) = 1"
" in conv2d layer with weight shape [M, C, K1, K2],"
" Weight will be reshaped to [C, M*K1*K2], U will"
" be in shape [C, 1].");
AddInput("V", AddInput("V",
"The weight_u tensor of spectral_norm operator, " "The weight_v tensor of spectral_norm operator, "
"This can be a 1-D tensor in shape [W, 1]," "This can be a 1-D tensor in shape [W, 1],"
"W is the 2nd dimentions of Weight after reshape" "W is the 2nd dimentions of Weight after reshape"
"corresponding by Attr(dim)."); "corresponding by Attr(dim). As for Attr(dim) = 1"
" in conv2d layer with weight shape [M, C, K1, K2],"
" Weight will be reshaped to [C, M*K1*K2], V will"
" be in shape [M*K1*K2, 1].");
AddOutput("Out", AddOutput("Out",
"The output weight tensor of spectral_norm operator, " "The output weight tensor of spectral_norm operator, "
"This tensor is in same shape with Input(Weight)."); "This tensor is in same shape with Input(Weight).");
AddAttr<int>("dim", AddAttr<int>("dim",
"dimension corresponding to number of outputs," "dimension corresponding to number of outputs,"
"default 0 for fc layer, and 1 for conv1d, conv2d, conv3d" "it should be set as 0 if Input(Weight) is the"
"layers") "weight of fc layer, and should be set as 1 if"
"Input(Weight) is the weight of conv layer,"
"default is 0.")
.SetDefault(0); .SetDefault(0);
AddAttr<int>("power_iters", AddAttr<int>("power_iters",
"number of power iterations to calculate" "number of power iterations to calculate"
...@@ -109,13 +117,13 @@ class SpectralNormOpMaker : public framework::OpProtoAndCheckerMaker { ...@@ -109,13 +117,13 @@ class SpectralNormOpMaker : public framework::OpProtoAndCheckerMaker {
.SetDefault(1e-12); .SetDefault(1e-12);
AddComment(R"DOC( AddComment(R"DOC(
This layer calculate the spectral normalize value of weight of This layer calculates the spectral normalize value of weight of
fc, conv1d, conv2d, conv3d layers which should be 2-D, 3-D, 4-D, 5-D fc, conv1d, conv2d, conv3d layers which should be 2-D, 3-D, 4-D, 5-D
tensor. tensor.
Spectral normalization stabilizes the training of critis in GANs Spectral normalization stabilizes the training of critic in GANs
(Generative Adversarial Networks). This layers rescaling weight tensor (Generative Adversarial Networks). This layer rescaling weight tensor
wiht spectral normalize value. with spectral normalize value.
For spectral normalization calculations, we rescaling weight For spectral normalization calculations, we rescaling weight
tensor with \sigma, while \sigma{\mathbf{W}} is tensor with \sigma, while \sigma{\mathbf{W}} is
......
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserve. /* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License"); Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License. you may not use this file except in compliance with the License.
You may obtain a copy of the License at You may obtain a copy of the License at
......
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserve. /* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License"); Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License. you may not use this file except in compliance with the License.
You may obtain a copy of the License at You may obtain a copy of the License at
...@@ -73,11 +73,13 @@ static inline void CalcMatrixSigmaAndNormWeight( ...@@ -73,11 +73,13 @@ static inline void CalcMatrixSigmaAndNormWeight(
const int w = weight->dims()[1]; const int w = weight->dims()[1];
for (int i = 0; i < power_iters; i++) { for (int i = 0; i < power_iters; i++) {
// V = W^T * U / ||W^T * U||_2
blas.MatMul(*weight, true, *u, false, T(1), v, T(0)); blas.MatMul(*weight, true, *u, false, T(1), v, T(0));
auto v_t_norm = auto v_t_norm =
v_t.square().sum().sqrt().eval().reshape(Array1(1)).broadcast( v_t.square().sum().sqrt().eval().reshape(Array1(1)).broadcast(
Array1(w)); Array1(w));
v_t.device(place) = v_t / (v_t_norm + v_t_norm.constant(eps)); v_t.device(place) = v_t / (v_t_norm + v_t_norm.constant(eps));
// U = W * V / ||W * V||_2
blas.MatMul(*weight, false, *v, false, T(1), u, T(0)); blas.MatMul(*weight, false, *v, false, T(1), u, T(0));
auto u_t_norm = auto u_t_norm =
u_t.square().sum().sqrt().eval().reshape(Array1(1)).broadcast( u_t.square().sum().sqrt().eval().reshape(Array1(1)).broadcast(
......
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. # Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册