From fca1fe354d2bdc64c20a0d4b4d088ee06914f62e Mon Sep 17 00:00:00 2001 From: wangna11BD <79366697+wangna11BD@users.noreply.github.com> Date: Fri, 24 Feb 2023 11:59:35 +0800 Subject: [PATCH] Support @to_static traing for msvsr (#753) * Support @to_static traing for msvsr * fix error for TIPC --- configs/msvsr_reds.yaml | 2 ++ ppgan/models/msvsr_model.py | 7 +++++-- test_tipc/configs/edvr/train_infer_python.txt | 2 +- test_tipc/configs/esrgan/train_infer_python.txt | 2 +- test_tipc/configs/msvsr/train_infer_python.txt | 4 ++-- 5 files changed, 11 insertions(+), 6 deletions(-) diff --git a/configs/msvsr_reds.yaml b/configs/msvsr_reds.yaml index 2efb181..78e0f58 100644 --- a/configs/msvsr_reds.yaml +++ b/configs/msvsr_reds.yaml @@ -27,6 +27,8 @@ model: pixel_criterion: name: CharbonnierLoss reduction: mean + # training model under @to_static + to_static: False dataset: train: diff --git a/ppgan/models/msvsr_model.py b/ppgan/models/msvsr_model.py index 04bdfb4..a7fe914 100644 --- a/ppgan/models/msvsr_model.py +++ b/ppgan/models/msvsr_model.py @@ -31,7 +31,8 @@ class MultiStageVSRModel(BaseSRModel): PP-MSVSR: Multi-Stage Video Super-Resolution, 2021 """ - def __init__(self, generator, fix_iter, pixel_criterion=None): + def __init__(self, generator, fix_iter, pixel_criterion=None, to_static=False, + image_shape=None): """Initialize the PP-MSVSR class. Args: @@ -39,7 +40,9 @@ class MultiStageVSRModel(BaseSRModel): fix_iter (dict): config of fix_iter. pixel_criterion (dict): config of pixel criterion. """ - super(MultiStageVSRModel, self).__init__(generator, pixel_criterion) + super(MultiStageVSRModel, self).__init__(generator, pixel_criterion, + to_static=to_static, + image_shape=image_shape) self.fix_iter = fix_iter self.current_iter = 1 self.flag = True diff --git a/test_tipc/configs/edvr/train_infer_python.txt b/test_tipc/configs/edvr/train_infer_python.txt index 569a46d..9379f9c 100644 --- a/test_tipc/configs/edvr/train_infer_python.txt +++ b/test_tipc/configs/edvr/train_infer_python.txt @@ -54,4 +54,4 @@ batch_size:64 fp_items:fp32|fp16 total_iters:100 --profiler_options:batch_range=[10,20];state=GPU;tracer_option=Default;profile_path=model.profile -flags:FLAGS_cudnn_exhaustive_search=1 +flags:FLAGS_cudnn_exhaustive_search=0 diff --git a/test_tipc/configs/esrgan/train_infer_python.txt b/test_tipc/configs/esrgan/train_infer_python.txt index 4e8a5f3..c70b93d 100644 --- a/test_tipc/configs/esrgan/train_infer_python.txt +++ b/test_tipc/configs/esrgan/train_infer_python.txt @@ -54,4 +54,4 @@ batch_size:32|64 fp_items:fp32 total_iters:500 --profiler_options:batch_range=[10,20];state=GPU;tracer_option=Default;profile_path=model.profile -flags:FLAGS_cudnn_exhaustive_search=1 +flags:FLAGS_cudnn_exhaustive_search=0 diff --git a/test_tipc/configs/msvsr/train_infer_python.txt b/test_tipc/configs/msvsr/train_infer_python.txt index 7a0b1b8..2153983 100644 --- a/test_tipc/configs/msvsr/train_infer_python.txt +++ b/test_tipc/configs/msvsr/train_infer_python.txt @@ -17,7 +17,7 @@ norm_train:tools/main.py -c configs/msvsr_reds.yaml --seed 123 -o log_config.int pact_train:null fpgm_train:null distill_train:null -null:null +to_static_train:model.to_static=True null:null ## ===========================eval_params=========================== @@ -54,6 +54,6 @@ batch_size:2|4 fp_items:fp32|fp16 total_iters:60 --profiler_options:batch_range=[10,20];state=GPU;tracer_option=Default;profile_path=model.profile -flags:FLAGS_cudnn_exhaustive_search=1 +flags:FLAGS_cudnn_exhaustive_search=0 ===========================infer_benchmark_params========================== random_infer_input:[{float32,[2,3,180,320]}] -- GitLab