# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. import paddle from .base_model import BaseModel from .builder import MODELS from .generators.builder import build_generator from .criterions import build_criterion from ..modules.init import init_weights @MODELS.register() class LapStyleModel(BaseModel): def __init__(self, generator_encode, generator_decode, calc_style_emd_loss=None, calc_content_relt_loss=None, calc_content_loss=None, calc_style_loss=None, content_layers=['r11', 'r21', 'r31', 'r41', 'r51'], style_layers=['r11', 'r21', 'r31', 'r41', 'r51'], content_weight=1.0, style_weight=3.0): super(LapStyleModel, self).__init__() # define generators self.nets['net_enc'] = build_generator(generator_encode) self.nets['net_dec'] = build_generator(generator_decode) init_weights(self.nets['net_dec']) self.set_requires_grad([self.nets['net_enc']], False) # define loss functions self.calc_style_emd_loss = build_criterion(calc_style_emd_loss) self.calc_content_relt_loss = build_criterion(calc_content_relt_loss) self.calc_content_loss = build_criterion(calc_content_loss) self.calc_style_loss = build_criterion(calc_style_loss) self.content_layers = content_layers self.style_layers = style_layers self.content_weight = content_weight self.style_weight = style_weight def setup_input(self, input): self.ci = paddle.to_tensor(input['ci']) self.visual_items['ci'] = self.ci self.si = paddle.to_tensor(input['si']) self.visual_items['si'] = self.si self.image_paths = input['ci_path'] def forward(self): """Run forward pass; called by both functions and .""" self.cF = self.nets['net_enc'](self.ci) self.sF = self.nets['net_enc'](self.si) self.stylized = self.nets['net_dec'](self.cF, self.sF) self.visual_items['stylized'] = self.stylized def backward_dnc(self): self.tF = self.nets['net_enc'](self.stylized) """content loss""" self.loss_c = 0 for layer in self.content_layers: self.loss_c += self.calc_content_loss(self.tF[layer], self.cF[layer], norm=True) self.losses['loss_c'] = self.loss_c """style loss""" self.loss_s = 0 for layer in self.style_layers: self.loss_s += self.calc_style_loss(self.tF[layer], self.sF[layer]) self.losses['loss_s'] = self.loss_s """IDENTITY LOSSES""" self.Icc = self.nets['net_dec'](self.cF, self.cF) self.l_identity1 = self.calc_content_loss(self.Icc, self.ci) self.Fcc = self.nets['net_enc'](self.Icc) self.l_identity2 = 0 for layer in self.content_layers: self.l_identity2 += self.calc_content_loss(self.Fcc[layer], self.cF[layer]) self.losses['l_identity1'] = self.l_identity1 self.losses['l_identity2'] = self.l_identity2 """relative loss""" self.loss_style_remd = self.calc_style_emd_loss( self.tF['r31'], self.sF['r31']) + self.calc_style_emd_loss( self.tF['r41'], self.sF['r41']) self.loss_content_relt = self.calc_content_relt_loss( self.tF['r31'], self.cF['r31']) + self.calc_content_relt_loss( self.tF['r41'], self.cF['r41']) self.losses['loss_style_remd'] = self.loss_style_remd self.losses['loss_content_relt'] = self.loss_content_relt self.loss = self.loss_c * self.content_weight + self.loss_s * self.style_weight +\ self.l_identity1 * 50 + self.l_identity2 * 1 + self.loss_style_remd * 10 + \ self.loss_content_relt * 16 self.loss.backward() return self.loss def train_iter(self, optimizers=None): """Calculate losses, gradients, and update network weights""" self.forward() optimizers['optimG'].clear_grad() self.backward_dnc() self.optimizers['optimG'].step()