model.py 5.3 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

X
xiaoting 已提交
15 16 17 18 19 20 21 22
from layers import *
import paddle.fluid as fluid


class build_resnet_block(fluid.dygraph.Layer):
    """Residual block: two reflect-padded 3x3 convolutions plus a skip add.

    Args:
        dim: Channel count passed as both ``num_channels`` and
            ``num_filters`` to each convolution, so the block keeps the
            number of channels unchanged.
        use_bias: Forwarded as ``use_bias`` to both ``conv2d`` layers.
            Defaults to ``False``.
    """

    def __init__(self, dim, use_bias=False):
        super(build_resnet_block, self).__init__()

        # The caller's ``use_bias`` is honoured here; previously it was
        # accepted but silently replaced by a hard-coded ``False``.
        self.conv0 = conv2d(
            num_channels=dim,
            num_filters=dim,
            filter_size=3,
            stride=1,
            stddev=0.02,
            use_bias=use_bias)
        # NOTE(review): ``relu=False`` presumably turns off the helper's
        # activation for the second convolution -- confirm in layers.conv2d.
        self.conv1 = conv2d(
            num_channels=dim,
            num_filters=dim,
            filter_size=3,
            stride=1,
            stddev=0.02,
            relu=False,
            use_bias=use_bias)
        self.dim = dim

    def forward(self, inputs):
        """Return ``inputs`` plus the output of the two-convolution branch.

        Each convolution is preceded by a one-pixel reflect pad on every
        side. NOTE(review): this presumably keeps the spatial size unchanged
        for the 3x3/stride-1 convolutions so the final addition is
        shape-compatible -- confirm conv2d's default padding.
        """
        out_res = fluid.layers.pad2d(inputs, [1, 1, 1, 1], mode="reflect")
        out_res = self.conv0(out_res)

        # No dropout is applied between the two convolutions.
        out_res = fluid.layers.pad2d(out_res, [1, 1, 1, 1], mode="reflect")
        out_res = self.conv1(out_res)
        return out_res + inputs

51

X
xiaoting 已提交
52
class build_generator_resnet_9blocks(fluid.dygraph.Layer):
    """Generator: three convolutions, nine residual blocks, two transposed
    convolutions and a final convolution followed by ``tanh``.

    Args:
        input_channel: Channel count of the input; also used as the
            ``num_filters`` of the last convolution, so the output has the
            same number of channels as the input.
    """

    def __init__(self, input_channel):
        super(build_generator_resnet_9blocks, self).__init__()

        # Stem: one 7x7 stride-1 convolution, then two 3x3 stride-2 ones.
        self.conv0 = conv2d(
            num_channels=input_channel,
            num_filters=32,
            filter_size=7,
            stride=1,
            padding=0,
            stddev=0.02)
        self.conv1 = conv2d(
            num_channels=32,
            num_filters=64,
            filter_size=3,
            stride=2,
            padding=1,
            stddev=0.02)
        self.conv2 = conv2d(
            num_channels=64,
            num_filters=128,
            filter_size=3,
            stride=2,
            padding=1,
            stddev=0.02)

        # Nine residual blocks, registered as sublayers "generator_1" ..
        # "generator_9" and also kept in a list for ordered iteration.
        dim = 128
        self.build_resnet_block_list = [
            self.add_sublayer("generator_%d" % index, build_resnet_block(dim))
            for index in range(1, 10)
        ]

        # Two stride-2 transposed convolutions: dim -> 64 -> 32 channels.
        self.deconv0 = DeConv2D(
            num_channels=dim,
            num_filters=32 * 2,
            filter_size=3,
            stride=2,
            stddev=0.02,
            padding=[1, 1],
            outpadding=[0, 1, 0, 1])
        self.deconv1 = DeConv2D(
            num_channels=32 * 2,
            num_filters=32,
            filter_size=3,
            stride=2,
            stddev=0.02,
            padding=[1, 1],
            outpadding=[0, 1, 0, 1])

        # Output projection back to ``input_channel`` channels.
        # NOTE(review): relu=False / norm=False presumably disable the
        # helper's activation and normalisation -- confirm in layers.conv2d.
        self.conv3 = conv2d(
            num_channels=32,
            num_filters=input_channel,
            filter_size=7,
            stride=1,
            stddev=0.02,
            padding=0,
            relu=False,
            norm=False,
            use_bias=True)

    def forward(self, inputs):
        """Run ``inputs`` through the generator and return the tanh output."""
        # Reflect-pad by 3 on every side ahead of the 7x7, padding=0 conv.
        hidden = fluid.layers.pad2d(inputs, [3, 3, 3, 3], mode="reflect")
        hidden = self.conv0(hidden)
        hidden = self.conv1(hidden)
        hidden = self.conv2(hidden)

        for res_block in self.build_resnet_block_list:
            hidden = res_block(hidden)

        hidden = self.deconv0(hidden)
        hidden = self.deconv1(hidden)

        # Same reflect pad ahead of the final 7x7, padding=0 conv.
        hidden = fluid.layers.pad2d(hidden, [3, 3, 3, 3], mode="reflect")
        hidden = self.conv3(hidden)
        return fluid.layers.tanh(hidden)

126

X
xiaoting 已提交
127
class build_gen_discriminator(fluid.dygraph.Layer):
    """Discriminator made of five stacked 4x4 ``conv2d`` layers.

    The channel progression is ``input_channel -> 64 -> 128 -> 256 -> 512
    -> 1``; the first three layers use stride 2 and the last two stride 1.

    Args:
        input_channel: Channel count of the input passed to ``forward``.
    """

    def __init__(self, input_channel):
        super(build_gen_discriminator, self).__init__()

        # Keyword arguments shared by every layer in the stack.
        shared = dict(filter_size=4, stddev=0.02, padding=1)

        # NOTE(review): norm/relu/relufactor are interpreted by the
        # project-local conv2d helper; relufactor=0.2 is presumably a
        # leaky-ReLU slope -- confirm in layers.conv2d.
        self.conv0 = conv2d(
            num_channels=input_channel,
            num_filters=64,
            stride=2,
            norm=False,
            use_bias=True,
            relufactor=0.2,
            **shared)
        self.conv1 = conv2d(
            num_channels=64,
            num_filters=128,
            stride=2,
            relufactor=0.2,
            **shared)
        self.conv2 = conv2d(
            num_channels=128,
            num_filters=256,
            stride=2,
            relufactor=0.2,
            **shared)
        self.conv3 = conv2d(
            num_channels=256,
            num_filters=512,
            stride=1,
            relufactor=0.2,
            **shared)
        # Final single-filter layer, built with norm=False and relu=False.
        self.conv4 = conv2d(
            num_channels=512,
            num_filters=1,
            stride=1,
            norm=False,
            relu=False,
            use_bias=True,
            **shared)

    def forward(self, inputs):
        """Apply conv0 through conv4 in order and return the last output."""
        stages = (self.conv0, self.conv1, self.conv2, self.conv3, self.conv4)
        features = inputs
        for stage in stages:
            features = stage(features)
        return features