未验证 提交 fca1fe35 编写于 作者: W wangna11BD 提交者: GitHub

Support @to_static traing for msvsr (#753)

* Support @to_static traing for msvsr

* fix error for TIPC
上级 01cb542f
...@@ -27,6 +27,8 @@ model: ...@@ -27,6 +27,8 @@ model:
pixel_criterion: pixel_criterion:
name: CharbonnierLoss name: CharbonnierLoss
reduction: mean reduction: mean
# training model under @to_static
to_static: False
dataset: dataset:
train: train:
......
...@@ -31,7 +31,8 @@ class MultiStageVSRModel(BaseSRModel): ...@@ -31,7 +31,8 @@ class MultiStageVSRModel(BaseSRModel):
PP-MSVSR: Multi-Stage Video Super-Resolution, 2021 PP-MSVSR: Multi-Stage Video Super-Resolution, 2021
""" """
def __init__(self, generator, fix_iter, pixel_criterion=None): def __init__(self, generator, fix_iter, pixel_criterion=None, to_static=False,
image_shape=None):
"""Initialize the PP-MSVSR class. """Initialize the PP-MSVSR class.
Args: Args:
...@@ -39,7 +40,9 @@ class MultiStageVSRModel(BaseSRModel): ...@@ -39,7 +40,9 @@ class MultiStageVSRModel(BaseSRModel):
fix_iter (dict): config of fix_iter. fix_iter (dict): config of fix_iter.
pixel_criterion (dict): config of pixel criterion. pixel_criterion (dict): config of pixel criterion.
""" """
super(MultiStageVSRModel, self).__init__(generator, pixel_criterion) super(MultiStageVSRModel, self).__init__(generator, pixel_criterion,
to_static=to_static,
image_shape=image_shape)
self.fix_iter = fix_iter self.fix_iter = fix_iter
self.current_iter = 1 self.current_iter = 1
self.flag = True self.flag = True
......
...@@ -54,4 +54,4 @@ batch_size:64 ...@@ -54,4 +54,4 @@ batch_size:64
fp_items:fp32|fp16 fp_items:fp32|fp16
total_iters:100 total_iters:100
--profiler_options:batch_range=[10,20];state=GPU;tracer_option=Default;profile_path=model.profile --profiler_options:batch_range=[10,20];state=GPU;tracer_option=Default;profile_path=model.profile
flags:FLAGS_cudnn_exhaustive_search=1 flags:FLAGS_cudnn_exhaustive_search=0
...@@ -54,4 +54,4 @@ batch_size:32|64 ...@@ -54,4 +54,4 @@ batch_size:32|64
fp_items:fp32 fp_items:fp32
total_iters:500 total_iters:500
--profiler_options:batch_range=[10,20];state=GPU;tracer_option=Default;profile_path=model.profile --profiler_options:batch_range=[10,20];state=GPU;tracer_option=Default;profile_path=model.profile
flags:FLAGS_cudnn_exhaustive_search=1 flags:FLAGS_cudnn_exhaustive_search=0
...@@ -17,7 +17,7 @@ norm_train:tools/main.py -c configs/msvsr_reds.yaml --seed 123 -o log_config.int ...@@ -17,7 +17,7 @@ norm_train:tools/main.py -c configs/msvsr_reds.yaml --seed 123 -o log_config.int
pact_train:null pact_train:null
fpgm_train:null fpgm_train:null
distill_train:null distill_train:null
null:null to_static_train:model.to_static=True
null:null null:null
## ##
===========================eval_params=========================== ===========================eval_params===========================
...@@ -54,6 +54,6 @@ batch_size:2|4 ...@@ -54,6 +54,6 @@ batch_size:2|4
fp_items:fp32|fp16 fp_items:fp32|fp16
total_iters:60 total_iters:60
--profiler_options:batch_range=[10,20];state=GPU;tracer_option=Default;profile_path=model.profile --profiler_options:batch_range=[10,20];state=GPU;tracer_option=Default;profile_path=model.profile
flags:FLAGS_cudnn_exhaustive_search=1 flags:FLAGS_cudnn_exhaustive_search=0
===========================infer_benchmark_params========================== ===========================infer_benchmark_params==========================
random_infer_input:[{float32,[2,3,180,320]}] random_infer_input:[{float32,[2,3,180,320]}]
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册