提交 f1d4dc6d 编写于 作者: H haoyuying

revise style transfer second time

上级 991f2a7e
...@@ -187,8 +187,7 @@ class Inspiration(nn.Layer): ...@@ -187,8 +187,7 @@ class Inspiration(nn.Layer):
return x return x
def __repr__(self): def __repr__(self):
return self.__class__.__name__ + '(' \ return self.__class__.__name__ + '(' + 'N x ' + str(self.C) + ')'
+ 'N x ' + str(self.C) + ')'
class Vgg16(nn.Layer): class Vgg16(nn.Layer):
...@@ -282,6 +281,7 @@ class MSGNet(nn.Layer): ...@@ -282,6 +281,7 @@ class MSGNet(nn.Layer):
block = Bottleneck block = Bottleneck
upblock = UpBottleneck upblock = UpBottleneck
expansion = 4 expansion = 4
model = []
model1 = [ model1 = [
ConvLayer(input_nc, 64, kernel_size=7, stride=1), ConvLayer(input_nc, 64, kernel_size=7, stride=1),
...@@ -290,14 +290,12 @@ class MSGNet(nn.Layer): ...@@ -290,14 +290,12 @@ class MSGNet(nn.Layer):
block(64, 32, 2, 1, norm_layer), block(64, 32, 2, 1, norm_layer),
block(32 * expansion, ngf, 2, 1, norm_layer) block(32 * expansion, ngf, 2, 1, norm_layer)
] ]
self.model1 = nn.Sequential(*tuple(model1)) self.model1 = nn.Sequential(*tuple(model1))
model = []
model += model1 model += model1
self.ins = Inspiration(ngf * expansion) self.ins = Inspiration(ngf * expansion)
model.append(self.ins) model.append(self.ins)
for i in range(n_blocks): for i in range(n_blocks):
model += [block(ngf * expansion, ngf, 1, None, norm_layer)] model += [block(ngf * expansion, ngf, 1, None, norm_layer)]
...@@ -308,6 +306,7 @@ class MSGNet(nn.Layer): ...@@ -308,6 +306,7 @@ class MSGNet(nn.Layer):
nn.ReLU(), nn.ReLU(),
ConvLayer(16 * expansion, output_nc, kernel_size=7, stride=1) ConvLayer(16 * expansion, output_nc, kernel_size=7, stride=1)
] ]
model = tuple(model) model = tuple(model)
self.model = nn.Sequential(*model) self.model = nn.Sequential(*model)
...@@ -330,7 +329,6 @@ class MSGNet(nn.Layer): ...@@ -330,7 +329,6 @@ class MSGNet(nn.Layer):
model_dict[key] = paddle.ones(shape=model_dict[key].shape, dtype='float32') model_dict[key] = paddle.ones(shape=model_dict[key].shape, dtype='float32')
self.set_dict(model_dict) self.set_dict(model_dict)
print("load pretrained checkpoint success") print("load pretrained checkpoint success")
self._vgg = None self._vgg = None
def transform(self, path: str): def transform(self, path: str):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册