提交 f1d4dc6d 编写于 作者: H haoyuying

revise style transfer second time

上级 991f2a7e
......@@ -187,8 +187,7 @@ class Inspiration(nn.Layer):
return x
def __repr__(self):
return self.__class__.__name__ + '(' \
+ 'N x ' + str(self.C) + ')'
return self.__class__.__name__ + '(' + 'N x ' + str(self.C) + ')'
class Vgg16(nn.Layer):
......@@ -282,6 +281,7 @@ class MSGNet(nn.Layer):
block = Bottleneck
upblock = UpBottleneck
expansion = 4
model = []
model1 = [
ConvLayer(input_nc, 64, kernel_size=7, stride=1),
......@@ -290,14 +290,12 @@ class MSGNet(nn.Layer):
block(64, 32, 2, 1, norm_layer),
block(32 * expansion, ngf, 2, 1, norm_layer)
]
self.model1 = nn.Sequential(*tuple(model1))
model = []
model += model1
self.ins = Inspiration(ngf * expansion)
model.append(self.ins)
for i in range(n_blocks):
model += [block(ngf * expansion, ngf, 1, None, norm_layer)]
......@@ -308,6 +306,7 @@ class MSGNet(nn.Layer):
nn.ReLU(),
ConvLayer(16 * expansion, output_nc, kernel_size=7, stride=1)
]
model = tuple(model)
self.model = nn.Sequential(*model)
......@@ -330,7 +329,6 @@ class MSGNet(nn.Layer):
model_dict[key] = paddle.ones(shape=model_dict[key].shape, dtype='float32')
self.set_dict(model_dict)
print("load pretrained checkpoint success")
self._vgg = None
def transform(self, path: str):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册