diff --git a/configs/edvr_m_wo_tsa.yaml b/configs/edvr_m_wo_tsa.yaml index 3feb891cbf06508ee8c3ad8041cfa84c30953b4f..6a24f470bfdc049fa861d0cca008ac6343260ee5 100644 --- a/configs/edvr_m_wo_tsa.yaml +++ b/configs/edvr_m_wo_tsa.yaml @@ -24,6 +24,8 @@ model: w_TSA: False pixel_criterion: name: CharbonnierLoss + # training model under @to_static + to_static: False export_model: - {name: 'generator', inputs_num: 1} diff --git a/configs/esrgan_psnr_x4_div2k.yaml b/configs/esrgan_psnr_x4_div2k.yaml index 8373e62f8e55a70a92aac1a270e95b440edbd6db..2f3504a80839235df8153426a3cf5030ec104ec0 100644 --- a/configs/esrgan_psnr_x4_div2k.yaml +++ b/configs/esrgan_psnr_x4_div2k.yaml @@ -14,6 +14,8 @@ model: nb: 23 pixel_criterion: name: L1Loss + # training model under @to_static + to_static: False export_model: - {name: 'generator', inputs_num: 1} diff --git a/configs/pix2pix_facades.yaml b/configs/pix2pix_facades.yaml index ffe127f41097932b7947bd9a356b3211bb246a13..ec30b188dad8863b48461d1207d008b4ad08199c 100644 --- a/configs/pix2pix_facades.yaml +++ b/configs/pix2pix_facades.yaml @@ -24,6 +24,8 @@ model: gan_criterion: name: GANLoss gan_mode: vanilla + # training model under @to_static + to_static: False dataset: train: diff --git a/ppgan/models/edvr_model.py b/ppgan/models/edvr_model.py index 3b5c50a772d6fc645a079719bb79b9274a144399..387714275cf7cc9e42f9e3eb34bc8c74948124c3 100644 --- a/ppgan/models/edvr_model.py +++ b/ppgan/models/edvr_model.py @@ -15,6 +15,7 @@ import paddle import paddle.nn as nn +from .base_model import apply_to_static from .builder import MODELS from .sr_model import BaseSRModel from .generators.edvr import ResidualBlockNoBN, DCNPack @@ -28,7 +29,8 @@ class EDVRModel(BaseSRModel): Paper: EDVR: Video Restoration with Enhanced Deformable Convolutional Networks. """ - def __init__(self, generator, tsa_iter, pixel_criterion=None): + def __init__(self, generator, tsa_iter, pixel_criterion=None, to_static=False, + image_shape=None): """Initialize the EDVR class. Args: @@ -36,7 +38,9 @@ class EDVRModel(BaseSRModel): tsa_iter (dict): config of tsa_iter. pixel_criterion (dict): config of pixel criterion. """ - super(EDVRModel, self).__init__(generator, pixel_criterion) + super(EDVRModel, self).__init__(generator, pixel_criterion, + to_static=to_static, + image_shape=image_shape) self.tsa_iter = tsa_iter self.current_iter = 1 init_edvr_weight(self.nets['generator']) diff --git a/ppgan/models/pix2pix_model.py b/ppgan/models/pix2pix_model.py index 2c8d552378ef62c9c9371f92045ea93ae0a4d6d8..a5784e8c8b0effcc3f6834e33c193946d8bcd17a 100644 --- a/ppgan/models/pix2pix_model.py +++ b/ppgan/models/pix2pix_model.py @@ -13,7 +13,7 @@ # limitations under the License. import paddle -from .base_model import BaseModel +from .base_model import BaseModel, apply_to_static from .builder import MODELS from .generators.builder import build_generator @@ -36,7 +36,9 @@ class Pix2PixModel(BaseModel): discriminator=None, pixel_criterion=None, gan_criterion=None, - direction='a2b'): + direction='a2b', + to_static=False, + image_shape=None): """Initialize the pix2pix class. Args: @@ -51,11 +53,15 @@ class Pix2PixModel(BaseModel): # define networks (both generator and discriminator) self.nets['netG'] = build_generator(generator) init_weights(self.nets['netG']) + # set @to_static for benchmark, skip this by default. + apply_to_static(to_static, image_shape, self.nets['netG']) # define a discriminator; conditional GANs need to take both input and output images; Therefore, #channels for D is input_nc + output_nc if discriminator: self.nets['netD'] = build_discriminator(discriminator) init_weights(self.nets['netD']) + # set @to_static for benchmark, skip this by default. + apply_to_static(to_static, image_shape, self.nets['netD']) if pixel_criterion: self.pixel_criterion = build_criterion(pixel_criterion) diff --git a/ppgan/models/sr_model.py b/ppgan/models/sr_model.py index e81e1f370d1663ce522ef74c4f507f61dd50479e..7a0db5513bd52e3071ea95ed84a9e8b1c61fc40f 100644 --- a/ppgan/models/sr_model.py +++ b/ppgan/models/sr_model.py @@ -17,7 +17,7 @@ import paddle.nn as nn from .generators.builder import build_generator from .criterions.builder import build_criterion -from .base_model import BaseModel +from .base_model import BaseModel, apply_to_static from .builder import MODELS from ..utils.visual import tensor2img from ..modules.init import reset_parameters @@ -28,7 +28,8 @@ class BaseSRModel(BaseModel): """Base SR model for single image super-resolution. """ - def __init__(self, generator, pixel_criterion=None, use_init_weight=False): + def __init__(self, generator, pixel_criterion=None, use_init_weight=False, to_static=False, + image_shape=None): """ Args: generator (dict): config of generator. @@ -37,6 +38,8 @@ class BaseSRModel(BaseModel): super(BaseSRModel, self).__init__() self.nets['generator'] = build_generator(generator) + # set @to_static for benchmark, skip this by default. + apply_to_static(to_static, image_shape, self.nets['generator']) if pixel_criterion: self.pixel_criterion = build_criterion(pixel_criterion) diff --git a/test_tipc/configs/Pix2pix/train_infer_python.txt b/test_tipc/configs/Pix2pix/train_infer_python.txt index 3f4d8238a9ce66effb475bf4f52e1bf86e58db7b..e1aed5ebb886a926b60a1635c6e3df758200708d 100644 --- a/test_tipc/configs/Pix2pix/train_infer_python.txt +++ b/test_tipc/configs/Pix2pix/train_infer_python.txt @@ -17,7 +17,7 @@ norm_train:tools/main.py -c configs/pix2pix_facades.yaml --seed 123 -o log_confi pact_train:null fpgm_train:null distill_train:null -null:null +to_static_train:model.to_static=True null:null ## ===========================eval_params=========================== diff --git a/test_tipc/configs/edvr/train_infer_python.txt b/test_tipc/configs/edvr/train_infer_python.txt index acc1875c9c24772781ee976e9f7e7b2f59ef166d..569a46dfcdca030cb5040c7dd170e0de382c09cc 100644 --- a/test_tipc/configs/edvr/train_infer_python.txt +++ b/test_tipc/configs/edvr/train_infer_python.txt @@ -17,7 +17,7 @@ norm_train:tools/main.py -c configs/edvr_m_wo_tsa.yaml --seed 123 -o log_config. pact_train:null fpgm_train:null distill_train:null -null:null +to_static_train:model.to_static=True null:null ## ===========================eval_params=========================== diff --git a/test_tipc/configs/esrgan/train_infer_python.txt b/test_tipc/configs/esrgan/train_infer_python.txt index dfbb98d92725f1bc74058a206d7a7310b391a911..4e8a5f3b99231f38121c42949e1243e701322469 100644 --- a/test_tipc/configs/esrgan/train_infer_python.txt +++ b/test_tipc/configs/esrgan/train_infer_python.txt @@ -17,7 +17,7 @@ norm_train:tools/main.py -c configs/esrgan_psnr_x4_div2k.yaml --seed 123 -o log_ pact_train:null fpgm_train:null distill_train:null -null:null +to_static_train:model.to_static=True null:null ## ===========================eval_params===========================