Commit 65d375a0 authored by: D dengkaipeng

fix format. test=develop

Parent 82d51434
......@@ -94,7 +94,7 @@ class SpectralNormOpMaker : public framework::OpProtoAndCheckerMaker {
"W is the 2nd dimentions of Weight after reshape"
"corresponding by Attr(dim). As for Attr(dim) = 1"
"in conv2d layer with weight shape [M, C, K1, K2]"
"Weight will be reshape to [C, M*K1*Kw], V will"
"Weight will be reshape to [C, M*K1*K2], V will"
"be in shape [M*K1*K2, 1].");
AddOutput("Out",
"The output weight tensor of spectral_norm operator, "
......@@ -105,7 +105,7 @@ class SpectralNormOpMaker : public framework::OpProtoAndCheckerMaker {
"it should be set as 0 if Input(Weight) is the"
"weight of fc layer, and should be set as 1 if"
"Input(Weight) is the weight of conv layer,"
"default is 0."
"default is 0.")
.SetDefault(0);
AddAttr<int>("power_iters",
"number of power iterations to calculate"
......
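To make the reshape described in the hunk above concrete (the numbers here are chosen purely for illustration and do not come from the patch): for a conv2d weight of shape [M, C, K1, K2] = [64, 3, 3, 3] with Attr(dim) = 1, the dimension selected by dim becomes the rows of the reshaped matrix, so H = C = 3 and W = M*K1*K2 = 64*3*3 = 576. Weight is then viewed as a 3 x 576 matrix, V has shape [576, 1], and the corresponding U vector has shape [3, 1].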
......@@ -73,13 +73,13 @@ static inline void CalcMatrixSigmaAndNormWeight(
const int w = weight->dims()[1];
for (int i = 0; i < power_iters; i++) {
- // V = W^T * U / ||W^T * U||_2
+ // V = W^T * U / ||W^T * U||_2
blas.MatMul(*weight, true, *u, false, T(1), v, T(0));
auto v_t_norm =
v_t.square().sum().sqrt().eval().reshape(Array1(1)).broadcast(
Array1(w));
v_t.device(place) = v_t / (v_t_norm + v_t_norm.constant(eps));
- // U = W^T * V / ||W^T * V||_2
+ // U = W^T * V / ||W^T * V||_2
blas.MatMul(*weight, false, *v, false, T(1), u, T(0));
auto u_t_norm =
u_t.square().sum().sqrt().eval().reshape(Array1(1)).broadcast(
......
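For context, the hunk above is the power iteration that estimates the largest singular value (the spectral norm) of the reshaped weight: v <- W^T u / ||W^T u||_2, then u <- W v / ||W v||_2, repeated power_iters times. Below is a minimal standalone C++ sketch of the same idea; it is not the operator's code. MatVec and Normalize are hypothetical helpers, the matrix is stored as a plain row-major std::vector, and Paddle's Blas/Eigen machinery is deliberately omitted.

#include <cmath>
#include <cstdio>
#include <vector>

// y = A * x for a rows x cols row-major matrix A; transA selects y = A^T * x.
static std::vector<double> MatVec(const std::vector<double>& A, int rows,
                                  int cols, const std::vector<double>& x,
                                  bool transA) {
  std::vector<double> y(transA ? cols : rows, 0.0);
  for (int i = 0; i < rows; ++i) {
    for (int j = 0; j < cols; ++j) {
      if (transA) {
        y[j] += A[i * cols + j] * x[i];
      } else {
        y[i] += A[i * cols + j] * x[j];
      }
    }
  }
  return y;
}

// x <- x / (||x||_2 + eps), mirroring the eps guard used in the kernel.
static void Normalize(std::vector<double>* x, double eps = 1e-12) {
  double norm = 0.0;
  for (double v : *x) norm += v * v;
  norm = std::sqrt(norm) + eps;
  for (double& v : *x) v /= norm;
}

int main() {
  const int h = 2, w = 3;             // reshaped weight is h x w
  std::vector<double> W = {1, 0, 0,
                           0, 2, 0};  // largest singular value is 2
  std::vector<double> u(h, 1.0), v(w, 0.0);
  Normalize(&u);
  for (int iter = 0; iter < 20; ++iter) {      // power_iters
    v = MatVec(W, h, w, u, /*transA=*/true);   // V = W^T * U
    Normalize(&v);                             // V /= ||V||_2
    u = MatVec(W, h, w, v, /*transA=*/false);  // U = W * V
    Normalize(&u);                             // U /= ||U||_2
  }
  std::vector<double> Wv = MatVec(W, h, w, v, false);
  double sigma = 0.0;
  for (int i = 0; i < h; ++i) sigma += u[i] * Wv[i];  // sigma = U^T * W * V
  std::printf("estimated sigma = %f\n", sigma);       // ~2.0
  return 0;
}

After the loop, sigma = u^T W v approximates the spectral norm, and the operator divides Weight by this value to produce Out. The eps handling here is simplified relative to the Eigen expression in the hunk (which adds eps to a broadcast norm tensor), but it serves the same purpose of guarding against division by zero.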