# sr_model.py
# Recovered from a GitLab blame view; lines committed by LielinJiang.
from collections import OrderedDict
import paddle
import paddle.nn as nn
from .generators.builder import build_generator
from .discriminators.builder import build_discriminator
from ..solver import build_optimizer
from .base_model import BaseModel
from .losses import GANLoss
from .builder import MODELS

import importlib
from collections import OrderedDict
from copy import deepcopy
from os import path as osp
from .builder import MODELS


@MODELS.register()
class SRModel(BaseModel):
    """Base SR model for single image super-resolution.

    Wraps a single generator network (``netG``). In training mode it is
    optimized with a pixel-wise L1 loss between its output and the
    ground-truth image.
    """
    def __init__(self, cfg):
        """Build the generator and, in training mode, its loss and optimizer.

        Args:
            cfg: Config object. ``cfg.model.generator`` is passed to
                ``build_generator``; when training, ``cfg.optimizer`` is
                passed to ``build_optimizer``.
        """
        super(SRModel, self).__init__(cfg)

        self.model_names = ['G']

        self.netG = build_generator(cfg.model.generator)
        self.visual_names = ['lq', 'output', 'gt']

        self.loss_names = ['l_total']

        self.optimizers = []
        # NOTE(review): ``isTrain`` is not set in this class; presumably it
        # comes from BaseModel.__init__ — confirm.
        if self.isTrain:
            self.criterionL1 = paddle.nn.L1Loss()

            # ``build_lr_scheduler`` is not defined here (presumably BaseModel);
            # it must populate ``self.lr_scheduler`` used just below.
            self.build_lr_scheduler()
            self.optimizer_G = build_optimizer(
                cfg.optimizer,
                self.lr_scheduler,
                parameter_list=self.netG.parameters())
            self.optimizers.append(self.optimizer_G)

    def set_input(self, input):
        """Store one batch of inputs on the model as paddle tensors.

        Args:
            input: Mapping with ``'lq'`` (low-quality input) and ``'lq_path'``.
                ``'gt'`` (ground truth) is optional, so inference batches
                without targets are accepted.
        """
        self.lq = paddle.to_tensor(input['lq'])
        if 'gt' in input:
            self.gt = paddle.to_tensor(input['gt'])
        self.image_paths = input['lq_path']

    def forward(self):
        # Intentionally a no-op: ``test`` and ``optimize_parameters`` call the
        # generator directly.
        pass

    def test(self):
        """Run the generator on ``self.lq`` without tracking gradients.

        The result is stored in ``self.output``.
        """
        with paddle.no_grad():
            self.output = self.netG(self.lq)

    def optimize_parameters(self):
        """Run one training step: forward, L1 pixel loss, backward, update.

        Training-only: ``criterionL1`` and ``optimizer_G`` exist only when
        the model was built with ``isTrain`` set, and ``self.gt`` must have
        been provided via ``set_input``.
        """
        self.optimizer_G.clear_grad()
        self.output = self.netG(self.lq)

        # Pixel loss. The criterion is always created in training mode, so it
        # is applied unconditionally. This avoids ever calling ``backward`` on
        # a plain int accumulator.
        l_pix = self.criterionL1(self.output, self.gt)
        l_total = l_pix

        l_total.backward()
        # Exposed as ``loss_l_total`` to match the ``'l_total'`` entry in
        # ``self.loss_names``.
        self.loss_l_total = l_total
        self.optimizer_G.step()