使用默认初始化时损失过高,收敛很慢。
Created by: shengkelong
#MSRB
class MSRB(fluid.dygraph.Layer):
    """Multi-scale residual block.

    Two parallel convolution branches (3x3 and 5x5 kernels) are applied in
    two stages.  After each stage the branch outputs are concatenated along
    axis 1 (presumably the channel axis, i.e. NCHW input -- confirm against
    the data pipeline).  A 1x1 convolution then reduces the fused features
    back to ``channels`` maps and the block input is added as a residual.

    Args:
        channels: Number of feature maps of the block input and output.
            Defaults to 64.
    """

    def __init__(self, channels=64):
        super(MSRB, self).__init__()

        def small_normal_attr():
            # Build a fresh ParamAttr for every layer so that no attribute
            # object is shared between convolutions.  The file's own notes
            # report that default initialization led to a very high loss,
            # hence the explicit normal initializer with scale=0.001.
            return fluid.ParamAttr(
                initializer=fluid.initializer.NormalInitializer(scale=0.001)
            )

        # NOTE: layers are created in the same order as before
        # (3_1, 3_2, 5_1, 5_2, confusion) so parameter registration order
        # stays unchanged.
        self.conv_3_1 = Conv2D(
            channels, channels, 3,
            padding=1, act='relu', param_attr=small_normal_attr(),
        )
        self.conv_3_2 = Conv2D(
            channels * 2, channels * 2, 3,
            padding=1, act='relu', param_attr=small_normal_attr(),
        )
        self.conv_5_1 = Conv2D(
            channels, channels, 5,
            padding=2, act='relu', param_attr=small_normal_attr(),
        )
        self.conv_5_2 = Conv2D(
            channels * 2, channels * 2, 5,
            padding=2, act='relu', param_attr=small_normal_attr(),
        )
        # 1x1 fusion layer: no activation and no custom initializer.
        self.confusion = Conv2D(channels * 4, channels, 1, padding=0)

    def forward(self, x):
        """Run the block on ``x`` and return the fused features plus ``x``."""
        # Stage 1: both branches see the raw block input.
        branch_3 = self.conv_3_1(x)
        branch_5 = self.conv_5_1(x)
        stage_1 = fluid.layers.concat([branch_3, branch_5], axis=1)

        # Stage 2: both branches see the concatenated stage-1 features.
        branch_3 = self.conv_3_2(stage_1)
        branch_5 = self.conv_5_2(stage_1)
        stage_2 = fluid.layers.concat([branch_3, branch_5], axis=1)

        # Fuse back to ``channels`` maps, then add the residual connection.
        fused = self.confusion(stage_2)
        fused += x
        return fused
#MSRN
class MSRN(fluid.dygraph.Layer):
    """Multi-scale residual network assembled from stacked ``MSRB`` blocks.

    Structure:
        * head: one 3x3 convolution mapping 3 input channels to 64 features;
        * body: ``n_blocks`` MSRB blocks applied one after another;
        * tail: a 1x1 convolution fusing the head output together with the
          output of every MSRB block, a 3x3 convolution, ``upsample()`` and
          a final 3x3 convolution producing 3 output channels.

    ``upsample`` is defined outside this class; its scale factor and channel
    requirements are not visible here, which is why the feature width stays
    fixed at 64 instead of being a constructor argument.

    Args:
        n_blocks: Number of MSRB blocks in the body.  Defaults to 4, the
            value that was previously hard-coded.

    Raises:
        ValueError: If ``n_blocks`` is smaller than 1.
    """

    def __init__(self, n_blocks=4):
        super(MSRN, self).__init__()
        if n_blocks < 1:
            raise ValueError(
                "n_blocks must be at least 1, got {!r}".format(n_blocks)
            )
        self.n_blocks = n_blocks
        n_feats = 64  # feature width shared by head, body and tail

        # define head module
        modules_head = [Conv2D(3, n_feats, 3, padding=1)]

        # define body module
        # A plain list is enough here: Sequential registers the blocks as
        # sublayers itself, so no intermediate LayerList is needed.
        modules_body = [MSRB(n_feats) for _ in range(self.n_blocks)]

        # define tail module
        # The first convolution receives the concatenation of every MSRB
        # output plus the head output, hence (n_blocks + 1) * n_feats inputs.
        modules_tail = [
            Conv2D(n_feats * (self.n_blocks + 1), n_feats, 1),
            Conv2D(n_feats, n_feats, 3, padding=1),
            upsample(),
            Conv2D(n_feats, 3, 3, padding=1),
        ]

        self.head = fluid.dygraph.Sequential(*modules_head)
        self.body = fluid.dygraph.Sequential(*modules_body)
        self.tail = fluid.dygraph.Sequential(*modules_tail)

    def forward(self, x):
        """Run the network on ``x`` and return the output of the tail."""
        x = self.head(x)
        head_features = x

        # Collect the output of every block; the blocks are chained, so each
        # one consumes the output of the previous one.
        block_outputs = []
        for i in range(self.n_blocks):
            x = self.body[i](x)
            block_outputs.append(x)
        # The head output goes last, matching the original concat order.
        block_outputs.append(head_features)

        fused = fluid.layers.concat(block_outputs, axis=1)
        return self.tail(fused)
复现 MSRN 超分辨率网络,效果不佳:PSNR 只能达到 35;一旦使用默认初始化,损失会直逼 1e10,很难收敛,原因尚不清楚。