未验证 提交 01cb542f 编写于 作者: W wangna11BD 提交者: GitHub

Support @to_static training for edvr, pix2pix and esrgan (#750)

上级 461bc8cd
...@@ -24,6 +24,8 @@ model: ...@@ -24,6 +24,8 @@ model:
w_TSA: False w_TSA: False
pixel_criterion: pixel_criterion:
name: CharbonnierLoss name: CharbonnierLoss
# whether to train the model under @to_static
to_static: False
export_model: export_model:
- {name: 'generator', inputs_num: 1} - {name: 'generator', inputs_num: 1}
......
...@@ -14,6 +14,8 @@ model: ...@@ -14,6 +14,8 @@ model:
nb: 23 nb: 23
pixel_criterion: pixel_criterion:
name: L1Loss name: L1Loss
# whether to train the model under @to_static
to_static: False
export_model: export_model:
- {name: 'generator', inputs_num: 1} - {name: 'generator', inputs_num: 1}
......
...@@ -24,6 +24,8 @@ model: ...@@ -24,6 +24,8 @@ model:
gan_criterion: gan_criterion:
name: GANLoss name: GANLoss
gan_mode: vanilla gan_mode: vanilla
# whether to train the model under @to_static
to_static: False
dataset: dataset:
train: train:
......
...@@ -15,6 +15,7 @@ ...@@ -15,6 +15,7 @@
import paddle import paddle
import paddle.nn as nn import paddle.nn as nn
from .base_model import apply_to_static
from .builder import MODELS from .builder import MODELS
from .sr_model import BaseSRModel from .sr_model import BaseSRModel
from .generators.edvr import ResidualBlockNoBN, DCNPack from .generators.edvr import ResidualBlockNoBN, DCNPack
...@@ -28,7 +29,8 @@ class EDVRModel(BaseSRModel): ...@@ -28,7 +29,8 @@ class EDVRModel(BaseSRModel):
Paper: EDVR: Video Restoration with Enhanced Deformable Convolutional Networks. Paper: EDVR: Video Restoration with Enhanced Deformable Convolutional Networks.
""" """
def __init__(self, generator, tsa_iter, pixel_criterion=None): def __init__(self, generator, tsa_iter, pixel_criterion=None, to_static=False,
image_shape=None):
"""Initialize the EDVR class. """Initialize the EDVR class.
Args: Args:
...@@ -36,7 +38,9 @@ class EDVRModel(BaseSRModel): ...@@ -36,7 +38,9 @@ class EDVRModel(BaseSRModel):
tsa_iter (dict): config of tsa_iter. tsa_iter (dict): config of tsa_iter.
pixel_criterion (dict): config of pixel criterion. pixel_criterion (dict): config of pixel criterion.
""" """
super(EDVRModel, self).__init__(generator, pixel_criterion) super(EDVRModel, self).__init__(generator, pixel_criterion,
to_static=to_static,
image_shape=image_shape)
self.tsa_iter = tsa_iter self.tsa_iter = tsa_iter
self.current_iter = 1 self.current_iter = 1
init_edvr_weight(self.nets['generator']) init_edvr_weight(self.nets['generator'])
......
...@@ -13,7 +13,7 @@ ...@@ -13,7 +13,7 @@
# limitations under the License. # limitations under the License.
import paddle import paddle
from .base_model import BaseModel from .base_model import BaseModel, apply_to_static
from .builder import MODELS from .builder import MODELS
from .generators.builder import build_generator from .generators.builder import build_generator
...@@ -36,7 +36,9 @@ class Pix2PixModel(BaseModel): ...@@ -36,7 +36,9 @@ class Pix2PixModel(BaseModel):
discriminator=None, discriminator=None,
pixel_criterion=None, pixel_criterion=None,
gan_criterion=None, gan_criterion=None,
direction='a2b'): direction='a2b',
to_static=False,
image_shape=None):
"""Initialize the pix2pix class. """Initialize the pix2pix class.
Args: Args:
...@@ -51,11 +53,15 @@ class Pix2PixModel(BaseModel): ...@@ -51,11 +53,15 @@ class Pix2PixModel(BaseModel):
# define networks (both generator and discriminator) # define networks (both generator and discriminator)
self.nets['netG'] = build_generator(generator) self.nets['netG'] = build_generator(generator)
init_weights(self.nets['netG']) init_weights(self.nets['netG'])
# set @to_static for benchmark, skip this by default.
apply_to_static(to_static, image_shape, self.nets['netG'])
# define a discriminator; conditional GANs need to take both input and output images; Therefore, #channels for D is input_nc + output_nc # define a discriminator; conditional GANs need to take both input and output images; Therefore, #channels for D is input_nc + output_nc
if discriminator: if discriminator:
self.nets['netD'] = build_discriminator(discriminator) self.nets['netD'] = build_discriminator(discriminator)
init_weights(self.nets['netD']) init_weights(self.nets['netD'])
# set @to_static for benchmark, skip this by default.
apply_to_static(to_static, image_shape, self.nets['netD'])
if pixel_criterion: if pixel_criterion:
self.pixel_criterion = build_criterion(pixel_criterion) self.pixel_criterion = build_criterion(pixel_criterion)
......
...@@ -17,7 +17,7 @@ import paddle.nn as nn ...@@ -17,7 +17,7 @@ import paddle.nn as nn
from .generators.builder import build_generator from .generators.builder import build_generator
from .criterions.builder import build_criterion from .criterions.builder import build_criterion
from .base_model import BaseModel from .base_model import BaseModel, apply_to_static
from .builder import MODELS from .builder import MODELS
from ..utils.visual import tensor2img from ..utils.visual import tensor2img
from ..modules.init import reset_parameters from ..modules.init import reset_parameters
...@@ -28,7 +28,8 @@ class BaseSRModel(BaseModel): ...@@ -28,7 +28,8 @@ class BaseSRModel(BaseModel):
"""Base SR model for single image super-resolution. """Base SR model for single image super-resolution.
""" """
def __init__(self, generator, pixel_criterion=None, use_init_weight=False): def __init__(self, generator, pixel_criterion=None, use_init_weight=False, to_static=False,
image_shape=None):
""" """
Args: Args:
generator (dict): config of generator. generator (dict): config of generator.
...@@ -37,6 +38,8 @@ class BaseSRModel(BaseModel): ...@@ -37,6 +38,8 @@ class BaseSRModel(BaseModel):
super(BaseSRModel, self).__init__() super(BaseSRModel, self).__init__()
self.nets['generator'] = build_generator(generator) self.nets['generator'] = build_generator(generator)
# set @to_static for benchmark, skip this by default.
apply_to_static(to_static, image_shape, self.nets['generator'])
if pixel_criterion: if pixel_criterion:
self.pixel_criterion = build_criterion(pixel_criterion) self.pixel_criterion = build_criterion(pixel_criterion)
......
...@@ -17,7 +17,7 @@ norm_train:tools/main.py -c configs/pix2pix_facades.yaml --seed 123 -o log_confi ...@@ -17,7 +17,7 @@ norm_train:tools/main.py -c configs/pix2pix_facades.yaml --seed 123 -o log_confi
pact_train:null pact_train:null
fpgm_train:null fpgm_train:null
distill_train:null distill_train:null
null:null to_static_train:model.to_static=True
null:null null:null
## ##
===========================eval_params=========================== ===========================eval_params===========================
......
...@@ -17,7 +17,7 @@ norm_train:tools/main.py -c configs/edvr_m_wo_tsa.yaml --seed 123 -o log_config. ...@@ -17,7 +17,7 @@ norm_train:tools/main.py -c configs/edvr_m_wo_tsa.yaml --seed 123 -o log_config.
pact_train:null pact_train:null
fpgm_train:null fpgm_train:null
distill_train:null distill_train:null
null:null to_static_train:model.to_static=True
null:null null:null
## ##
===========================eval_params=========================== ===========================eval_params===========================
......
...@@ -17,7 +17,7 @@ norm_train:tools/main.py -c configs/esrgan_psnr_x4_div2k.yaml --seed 123 -o log_ ...@@ -17,7 +17,7 @@ norm_train:tools/main.py -c configs/esrgan_psnr_x4_div2k.yaml --seed 123 -o log_
pact_train:null pact_train:null
fpgm_train:null fpgm_train:null
distill_train:null distill_train:null
null:null to_static_train:model.to_static=True
null:null null:null
## ##
===========================eval_params=========================== ===========================eval_params===========================
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册