默认初始化效果很差
Created by: shengkelong
默认初始化效果(效果极差)
自己抄的初始化(效果一般) `class relu(fluid.dygraph.Layer): def init(self): super(relu,self).init() def forward(self, x): prelu = fluid.dygraph.PRelu("all") x = prelu(x) return x
class upsample(fluid.dygraph.Layer):
    """Upsampling layer: a 3x3 convolution followed by ``pixel_shuffle``.
    Args:
        channel: number of input channels of the convolution.
        scale: upscale factor handed to ``fluid.layers.pixel_shuffle``.
    """
    def __init__(self, channel=64, scale=2):
        super(upsample, self).__init__()
        self.scale = scale
        # The convolution emits channel * scale^2 feature maps, which
        # pixel_shuffle consumes in forward().
        # NOTE(review): pixel_shuffle is expected to fold the scale^2 factor
        # into the spatial dims, leaving ``channel`` maps - confirm against
        # the Paddle documentation.
        self.n1 = Conv2D(channel, channel * self.scale * self.scale, 3, padding=1)
    def forward(self, x):
        """Apply the convolution, then pixel-shuffle by ``self.scale``."""
        x = self.n1(x)
        x = fluid.layers.pixel_shuffle(x, upscale_factor=self.scale)
        return x
class MSRB(fluid.dygraph.Layer):
    """Two-branch (3x3 and 5x5) residual block.
    Stage 1 runs a 3x3 and a 5x5 convolution on the input and concatenates
    the results on axis 1. Stage 2 runs a 3x3 and a 5x5 convolution on that
    concatenation and concatenates again. A 1x1 convolution then reduces the
    ``channels * 4`` maps back to ``channels`` and the block input is added.
    Args:
        channels: number of input (and output) channels of the block.
    """
    def __init__(self, channels=64):
        super(MSRB, self).__init__()
        self.conv_3_1 = Conv2D(channels, channels, 3, padding=1, act='relu')
        # Stage-2 convolutions see the concat of both stage-1 branches,
        # hence channels * 2 in and out.
        self.conv_3_2 = Conv2D(channels * 2, channels * 2, 3, padding=1, act='relu')
        self.conv_5_1 = Conv2D(channels, channels, 5, padding=2, act='relu')
        self.conv_5_2 = Conv2D(channels * 2, channels * 2, 5, padding=2, act='relu')
        # 1x1 fusion of the two stage-2 outputs (2 * channels each).
        self.confusion = Conv2D(channels * 4, channels, 1, padding=0)
    def forward(self, x):
        """Return the fused two-stage features plus the residual input ``x``."""
        input_1 = x
        output_3_1 = self.conv_3_1(input_1)
        output_5_1 = self.conv_5_1(input_1)
        input_2 = fluid.layers.concat([output_3_1, output_5_1], axis=1)
        output_3_2 = self.conv_3_2(input_2)
        output_5_2 = self.conv_5_2(input_2)
        input_3 = fluid.layers.concat([output_3_2, output_5_2], axis=1)
        output = self.confusion(input_3)
        # Residual connection back to the block input.
        output += x
        return output
class MSRN(fluid.dygraph.Layer):
def init(self):
super(MSRN, self).init()
self.n_blocks = 1
# define head module
modules_head = [Conv2D(3, 64, 3,padding = 1)]
# define body module
modules_body = paddle.fluid.dygraph.LayerList([MSRB()for i in range(self.n_blocks)])
# define tail module
modules_tail = [
Conv2D(64 * (self.n_blocks + 1), 64, 1),
Conv2D(64, 64, 3,padding = 1),
upsample(),
Conv2D(64, 3, 3,padding = 1)]
self.head = fluid.dygraph.Sequential(*modules_head)
self.body = fluid.dygraph.Sequential(*modules_body)
self.tail = fluid.dygraph.Sequential(*modules_tail)
def forward(self, x):
x = self.head(x)
res = x
MSRB_out = []
for i in range(self.n_blocks):
x = self.body[i](x)
MSRB_out.append(x)
MSRB_out.append(res)
res = fluid.layers.concat(MSRB_out,axis=1)
x = self.tail(res)
return x `