diff --git a/ppgan/engine/trainer.py b/ppgan/engine/trainer.py index fa2dfa0d645d95999038f61f025ad540fcaed478..eed6f89fd9b03031c5cd4e3f159d7c796d74f12e 100755 --- a/ppgan/engine/trainer.py +++ b/ppgan/engine/trainer.py @@ -32,6 +32,7 @@ from ..utils.profiler import add_profiler_step class IterLoader: + def __init__(self, dataloader): self._dataloader = dataloader self.iter_loader = iter(self._dataloader) @@ -79,6 +80,7 @@ class Trainer: # | || # save checkpoint (model.nets) \/ """ + def __init__(self, cfg): # base config self.logger = logging.getLogger(__name__) @@ -181,6 +183,22 @@ class Trainer: iter_loader = IterLoader(self.train_dataloader) + # use amp + if self.cfg.amp: + self.logger.info('use AMP to train. AMP level = {}'.format( + self.cfg.amp_level)) + assert self.cfg.model.name == 'MultiStageVSRModel', "AMP only support msvsr model" + scaler = paddle.amp.GradScaler(init_loss_scaling=1024) + # need to decorate model and optim if amp_level == 'O2' + if self.cfg.amp_level == 'O2': + # msvsr has only one generator and one optimizer + self.model.nets['generator'], self.optimizers[ + 'optim'] = paddle.amp.decorate( + models=self.model.nets['generator'], + optimizers=self.optimizers['optim'], + level='O2', + save_dtype='float32') + # set model.is_train = True self.model.setup_train_mode(is_train=True) while self.current_iter < (self.total_iters + 1): @@ -195,7 +213,12 @@ class Trainer: # unpack data from dataset and apply preprocessing # data input should be dict self.model.setup_input(data) - self.model.train_iter(self.optimizers) + + if self.cfg.amp: + self.model.train_iter_amp(self.optimizers, scaler, + self.cfg.amp_level) # amp train + else: + self.model.train_iter(self.optimizers) # norm train batch_cost_averager.record( time.time() - step_start_time, diff --git a/ppgan/models/msvsr_model.py b/ppgan/models/msvsr_model.py index 3ee6fbd3acf50983b46fa87f991a2c897875eefe..4642dded9d539b4b2616a867d9b8ba53a761be2c 100644 --- a/ppgan/models/msvsr_model.py +++ b/ppgan/models/msvsr_model.py @@ -30,6 +30,7 @@ class MultiStageVSRModel(BaseSRModel): Paper: PP-MSVSR: Multi-Stage Video Super-Resolution, 2021 """ + def __init__(self, generator, fix_iter, pixel_criterion=None): """Initialize the PP-MSVSR class. @@ -96,6 +97,48 @@ class MultiStageVSRModel(BaseSRModel): self.current_iter += 1 + # amp train with brute force implementation, maybe decorator can simplify this + def train_iter_amp(self, optims=None, scaler=None, amp_level='O1'): + optims['optim'].clear_grad() + if self.fix_iter: + if self.current_iter == 1: + print('Train MSVSR with fixed spynet for', self.fix_iter, + 'iters.') + for name, param in self.nets['generator'].named_parameters(): + if 'spynet' in name: + param.trainable = False + elif self.current_iter >= self.fix_iter + 1 and self.flag: + print('Train all the parameters.') + for name, param in self.nets['generator'].named_parameters(): + param.trainable = True + if 'spynet' in name: + param.optimize_attr['learning_rate'] = 0.25 + self.flag = False + for net in self.nets.values(): + net.find_unused_parameters = False + + # put loss computation in amp context + with paddle.amp.auto_cast(enable=True, level=amp_level): + output = self.nets['generator'](self.lq) + if isinstance(output, (list, tuple)): + out_stage2, output = output + loss_pix_stage2 = self.pixel_criterion(out_stage2, self.gt) + self.losses['loss_pix_stage2'] = loss_pix_stage2 + self.visual_items['output'] = output[:, 0, :, :, :] + # pixel loss + loss_pix = self.pixel_criterion(output, self.gt) + self.losses['loss_pix'] = loss_pix + + self.loss = sum(_value for _key, _value in self.losses.items() + if 'loss_pix' in _key) + scaled_loss = scaler.scale(self.loss) + self.losses['loss'] = scaled_loss + + scaled_loss.backward() + scaler.minimize(optims['optim'], scaled_loss) + + self.current_iter += 1 + def test_iter(self, metrics=None): self.gt = self.gt.cpu() self.nets['generator'].eval() diff --git a/ppgan/utils/options.py b/ppgan/utils/options.py index 2a0abc89db3fc69f0af27ce255156b66a654d46c..d371d63316ff28451dd506fc4b3eee4df5cd66e8 100644 --- a/ppgan/utils/options.py +++ b/ppgan/utils/options.py @@ -45,9 +45,9 @@ def parse_args(): default=False, help='skip validation during training') # config options - parser.add_argument("-o", - "--opt", - nargs='+', + parser.add_argument("-o", + "--opt", + nargs='+', help="set configuration options") #for inference @@ -60,19 +60,31 @@ def parse_args(): help="path to reference images") parser.add_argument("--model_path", default=None, help="model for loading") - # for profiler - parser.add_argument('-p', - '--profiler_options', - type=str, - default=None, - help='The option of profiler, which should be in format \"key1=value1;key2=value2;key3=value3\".' + # for profiler + parser.add_argument( + '-p', + '--profiler_options', + type=str, + default=None, + help= + 'The option of profiler, which should be in format \"key1=value1;key2=value2;key3=value3\".' ) # fix random numbers by setting seed parser.add_argument('--seed', type=int, default=None, - help='fix random numbers by setting seed\".' - ) + help='fix random numbers by setting seed\".') + + # add for amp training + parser.add_argument('--amp', + action='store_true', + default=False, + help='whether to enable amp training') + parser.add_argument('--amp_level', + type=str, + default='O1', + choices=['O1', 'O2'], + help='level of amp training; O2 represent pure fp16') args = parser.parse_args() return args diff --git a/ppgan/utils/setup.py b/ppgan/utils/setup.py index 78df10923cd905f8b175f6005efc0ecf472d3be2..5531076a9ce74a34e6ed3326ea180a3e5942733a 100644 --- a/ppgan/utils/setup.py +++ b/ppgan/utils/setup.py @@ -19,6 +19,7 @@ import numpy as np import random from .logger import setup_logger + def setup(args, cfg): if args.evaluate_only: cfg.is_train = False @@ -44,10 +45,13 @@ def setup(args, cfg): paddle.set_device('gpu') else: paddle.set_device('cpu') - + if args.seed: paddle.seed(args.seed) random.seed(args.seed) - np.random.seed(args.seed) + np.random.seed(args.seed) paddle.framework.random._manual_program_seed(args.seed) - + + # add amp and amp_level args into cfg + cfg['amp'] = args.amp + cfg['amp_level'] = args.amp_level diff --git a/test_tipc/readme.md b/test_tipc/README.md similarity index 93% rename from test_tipc/readme.md rename to test_tipc/README.md index 3d4778c32a1b4185ad495fc610b9e85e6b04b90a..4ffe955144d77734208e74fb783aaf984028b947 100644 --- a/test_tipc/readme.md +++ b/test_tipc/README.md @@ -57,9 +57,8 @@ test_tipc/ ### 测试流程 使用本工具,可以测试不同功能的支持情况,以及预测结果是否对齐,测试流程如下: -
- -
+ +![img](https://user-images.githubusercontent.com/79366697/185377097-a0f852a8-2d78-45ae-84ba-ae71b799d738.png) 1. 运行prepare.sh准备测试所需数据和模型; 2. 运行要测试的功能对应的测试脚本`test_*.sh`,产出log,由log可以看到不同配置是否运行成功; @@ -72,4 +71,4 @@ test_tipc/ #### 更多教程 各功能测试中涉及混合精度、裁剪、量化等训练相关,及mkldnn、Tensorrt等多种预测相关参数配置,请点击下方相应链接了解更多细节和使用教程: -[test_train_inference_python 使用](docs/test_train_inference_python.md) +- [test_train_inference_python 使用](docs/test_train_inference_python.md): 测试基于Python的模型训练、评估、推理等基本功能 diff --git a/test_tipc/benchmark_train.sh b/test_tipc/benchmark_train.sh index c6e9e73a550134fef9a472085103ad20452194d9..9ea230c6c4cb1c4f7d02e41b736e904673677b68 100644 --- a/test_tipc/benchmark_train.sh +++ b/test_tipc/benchmark_train.sh @@ -4,15 +4,15 @@ source test_tipc/common_func.sh # set env python=python export model_branch=`git symbolic-ref HEAD 2>/dev/null | cut -d"/" -f 3` -export model_commit=$(git log|head -n1|awk '{print $2}') +export model_commit=$(git log|head -n1|awk '{print $2}') export str_tmp=$(echo `pip list|grep paddlepaddle-gpu|awk -F ' ' '{print $2}'`) export frame_version=${str_tmp%%.post*} export frame_commit=$(echo `${python} -c "import paddle;print(paddle.version.commit)"`) -# run benchmark sh +# run benchmark sh # Usage: # bash run_benchmark_train.sh config.txt params -# or +# or # bash run_benchmark_train.sh config.txt function func_parser_params(){ @@ -100,6 +100,7 @@ for _flag in ${flags_list[*]}; do done # set log_name +BENCHMARK_ROOT=./ # self-test only repo_name=$(get_repo_name ) SAVE_LOG=${BENCHMARK_LOG_DIR:-$(pwd)} # */benchmark_log mkdir -p "${SAVE_LOG}/benchmark_log/" @@ -149,11 +150,11 @@ else fi IFS="|" -for batch_size in ${batch_size_list[*]}; do +for batch_size in ${batch_size_list[*]}; do for precision in ${fp_items_list[*]}; do for device_num in ${device_num_list[*]}; do # sed batchsize and precision - #func_sed_params "$FILENAME" "${line_precision}" "$precision" + func_sed_params "$FILENAME" "${line_precision}" "$precision" func_sed_params "$FILENAME" "${line_batchsize}" "$MODE=$batch_size" func_sed_params "$FILENAME" "${line_epoch}" "$MODE=$epoch" gpu_id=$(set_gpu_id $device_num) @@ -162,7 +163,7 @@ for batch_size in ${batch_size_list[*]}; do log_path="$SAVE_LOG/profiling_log" mkdir -p $log_path log_name="${repo_name}_${model_name}_bs${batch_size}_${precision}_${run_mode}_${device_num}_profiling" - func_sed_params "$FILENAME" "${line_gpuid}" "0" # sed used gpu_id + func_sed_params "$FILENAME" "${line_gpuid}" "0" # sed used gpu_id # set profile_option params tmp=`sed -i "${line_profile}s/.*/${profile_option}/" "${FILENAME}"` @@ -214,7 +215,7 @@ for batch_size in ${batch_size_list[*]}; do mkdir -p $speed_log_path log_name="${repo_name}_${model_name}_bs${batch_size}_${precision}_${run_mode}_${device_num}_log" speed_log_name="${repo_name}_${model_name}_bs${batch_size}_${precision}_${run_mode}_${device_num}_speed" - func_sed_params "$FILENAME" "${line_gpuid}" "$gpu_id" # sed used gpu_id + func_sed_params "$FILENAME" "${line_gpuid}" "$gpu_id" # sed used gpu_id func_sed_params "$FILENAME" "${line_profile}" "null" # sed --profile_option as null cmd="bash test_tipc/test_train_inference_python.sh ${FILENAME} benchmark_train > ${log_path}/${log_name} 2>&1 " echo $cmd @@ -244,4 +245,4 @@ for batch_size in ${batch_size_list[*]}; do fi done done -done \ No newline at end of file +done diff --git a/test_tipc/configs/msvsr/train_amp_infer_python.txt b/test_tipc/configs/msvsr/train_amp_infer_python.txt new file mode 100644 index 0000000000000000000000000000000000000000..2de77fb11980d56441a266ad001016032f8d4fa0 --- /dev/null +++ b/test_tipc/configs/msvsr/train_amp_infer_python.txt @@ -0,0 +1,53 @@ +===========================train_params=========================== +model_name:msvsr +python:python3.7 +gpu_list:0 +## +auto_cast:null +total_iters:lite_train_lite_infer=10|lite_train_whole_infer=10|whole_train_whole_infer=200 +output_dir:./output/ +dataset.train.batch_size:lite_train_lite_infer=1|whole_train_whole_infer=1 +pretrained_model:null +train_model_name:msvsr_reds*/*checkpoint.pdparams +train_infer_img_dir:./data/msvsr_reds/test +null:null +## +trainer:amp_train +amp_train:tools/main.py --amp --amp_level O1 -c configs/msvsr_reds.yaml --seed 123 -o dataset.train.num_workers=0 log_config.interval=1 snapshot_config.interval=5 dataset.train.dataset.num_frames=2 +pact_train:null +fpgm_train:null +distill_train:null +null:null +null:null +## +===========================eval_params=========================== +eval:null +null:null +## +===========================infer_params=========================== +--output_dir:./output/ +load:null +norm_export:tools/export_model.py -c configs/msvsr_reds.yaml --inputs_size="1,2,3,180,320" --model_name inference --load +quant_export:null +fpgm_export:null +distill_export:null +export1:null +export2:null +inference_dir:inference +train_model:./inference/msvsr/multistagevsrmodel_generator +infer_export:null +infer_quant:False +inference:tools/inference.py --model_type msvsr -c configs/msvsr_reds.yaml --seed 123 -o dataset.test.num_frames=2 --output_path test_tipc/output/ +--device:cpu +null:null +null:null +null:null +null:null +null:null +--model_path: +null:null +null:null +--benchmark:True +null:null +===========================infer_benchmark_params========================== +random_infer_input:[{float32,[2,3,180,320]}] diff --git a/test_tipc/configs/msvsr/train_infer_python.txt b/test_tipc/configs/msvsr/train_infer_python.txt index 288290edfaedc8b14ce5416e186cb31d995d13ee..1719ae11bc35f5aebc037f0401e044991a2772ee 100644 --- a/test_tipc/configs/msvsr/train_infer_python.txt +++ b/test_tipc/configs/msvsr/train_infer_python.txt @@ -13,22 +13,22 @@ train_infer_img_dir:./data/msvsr_reds/test null:null ## trainer:norm_train -norm_train:tools/main.py -c configs/msvsr_reds.yaml --seed 123 -o dataset.train.num_workers=0 log_config.interval=1 snapshot_config.interval=5 dataset.train.dataset.num_frames=2 +norm_train:tools/main.py -c configs/msvsr_reds.yaml --seed 123 -o log_config.interval=2 snapshot_config.interval=50 dataset.train.dataset.num_frames=15 pact_train:null fpgm_train:null distill_train:null null:null null:null ## -===========================eval_params=========================== +===========================eval_params=========================== eval:null null:null ## ===========================infer_params=========================== --output_dir:./output/ load:null -norm_export:tools/export_model.py -c configs/msvsr_reds.yaml --inputs_size="1,2,3,180,320" --model_name inference --load -quant_export:null +norm_export:tools/export_model.py -c configs/msvsr_reds.yaml --inputs_size="1,2,3,180,320" --model_name inference --load +quant_export:null fpgm_export:null distill_export:null export1:null @@ -49,5 +49,11 @@ null:null null:null --benchmark:True null:null +===========================train_benchmark_params========================== +batch_size:4 +fp_items:fp32 +total_iters:60 +--profiler_options:batch_range=[10,20];state=GPU;tracer_option=Default;profile_path=model.profile +flags:null ===========================infer_benchmark_params========================== random_infer_input:[{float32,[2,3,180,320]}] diff --git a/test_tipc/docs/benchmark_train.md b/test_tipc/docs/benchmark_train.md index 4f33d3e988741d3b4595b58a2a137603e765c2ed..acd85d66370f7273b2bd8341f3ad3b598b018529 100644 --- a/test_tipc/docs/benchmark_train.md +++ b/test_tipc/docs/benchmark_train.md @@ -9,7 +9,7 @@ ```shell # 运行格式:bash test_tipc/prepare.sh train_benchmark.txt mode -bash test_tipc/prepare.sh test_tipc/configs/basicvsr/train_benchmark.txt benchmark_train +bash test_tipc/prepare.sh test_tipc/configs/msvsr/train_infer_python.txt benchmark_train ``` ## 1.2 功能测试 @@ -17,13 +17,13 @@ bash test_tipc/prepare.sh test_tipc/configs/basicvsr/train_benchmark.txt benchma ```shell # 运行格式:bash test_tipc/benchmark_train.sh train_benchmark.txt mode -bash test_tipc/benchmark_train.sh test_tipc/configs/basicvsr/train_infer_python.txt benchmark_train +bash test_tipc/benchmark_train.sh test_tipc/configs/msvsr/train_infer_python.txt benchmark_train ``` `test_tipc/benchmark_train.sh`支持根据传入的第三个参数实现只运行某一个训练配置,如下: ```shell # 运行格式:bash test_tipc/benchmark_train.sh train_benchmark.txt mode -bash test_tipc/benchmark_train.sh test_tipc/configs/basicvsr/train_infer_python.txt benchmark_train dynamic_bs4_fp32_DP_N1C1 +bash test_tipc/benchmark_train.sh test_tipc/configs/msvsr/train_infer_python.txt benchmark_train dynamic_bs4_fp32_DP_N1C1 ``` dynamic_bs4_fp32_DP_N1C1为test_tipc/benchmark_train.sh传入的参数,格式如下: `${modeltype}_${batch_size}_${fp_item}_${run_mode}_${device_num}` @@ -42,11 +42,11 @@ dynamic_bs4_fp32_DP_N1C1为test_tipc/benchmark_train.sh传入的参数,格式 ``` train_log/ ├── index -│ ├── PaddleGAN_basicvsr_bs4_fp32_SingleP_DP_N1C1_speed -│ └── PaddleGAN_basicvsr_bs4_fp32_SingleP_DP_N1C4_speed +│ ├── PaddleGAN_msvsr_bs4_fp32_SingleP_DP_N1C1_speed +│ └── PaddleGAN_msvsr_bs4_fp32_SingleP_DP_N1C4_speed ├── profiling_log -│ └── PaddleGAN_basicvsr_bs4_fp32_SingleP_DP_N1C1_profiling +│ └── PaddleGAN_msvsr_bs4_fp32_SingleP_DP_N1C1_profiling └── train_log - ├── PaddleGAN_basicvsr_bs4_fp32_SingleP_DP_N1C1_log - └── PaddleGAN_basicvsr_bs4_fp32_MultiP_DP_N1C4_log + ├── PaddleGAN_msvsr_bs4_fp32_SingleP_DP_N1C1_log + └── PaddleGAN_msvsr_bs4_fp32_MultiP_DP_N1C4_log ``` diff --git a/test_tipc/docs/test.png b/test_tipc/docs/test.png new file mode 100644 index 0000000000000000000000000000000000000000..f99f23d7050eb61879cf317c0d7728ef14531b08 Binary files /dev/null and b/test_tipc/docs/test.png differ diff --git a/test_tipc/prepare.sh b/test_tipc/prepare.sh index 1af5e4c3549e95eefc8aadb493d7ed09b65e970d..b39b28ed46fc157b2d90f205f870f4af9694c72a 100644 --- a/test_tipc/prepare.sh +++ b/test_tipc/prepare.sh @@ -172,5 +172,10 @@ elif [ ${MODE} = "whole_infer" ];then mkdir -p ./data/singan mv ./data/SinGAN-official_images/Images/stone.png ./data/singan fi - +elif [ ${MODE} = "benchmark_train" ];then + if [ ${model_name} = "msvsr" ]; then + rm -rf ./data/reds* + wget -nc -P ./data/ https://paddlegan.bj.bcebos.com/datasets/reds_lite.tar --no-check-certificate + cd ./data/ && tar xf reds_lite.tar && cd ../ + fi fi diff --git a/test_tipc/test_train_inference_python.sh b/test_tipc/test_train_inference_python.sh index bc54fac8f856eee64058c487f603e0c9c220c45b..e270f9617c66c79a9fb4332155b036dadf9d9a13 100644 --- a/test_tipc/test_train_inference_python.sh +++ b/test_tipc/test_train_inference_python.sh @@ -48,11 +48,11 @@ norm_export=$(func_parser_value "${lines[29]}") inference_dir=$(func_parser_value "${lines[35]}") -# parser inference model +# parser inference model infer_model_dir_list=$(func_parser_value "${lines[36]}") infer_export_list=$(func_parser_value "${lines[37]}") infer_is_quant=$(func_parser_value "${lines[38]}") -# parser inference +# parser inference inference_py=$(func_parser_value "${lines[39]}") use_gpu_key=$(func_parser_key "${lines[40]}") use_gpu_list=$(func_parser_value "${lines[40]}") @@ -85,7 +85,7 @@ function func_inference(){ _log_path=$4 _img_dir=$5 _flag_quant=$6 - # inference + # inference for use_gpu in ${use_gpu_list[*]}; do if [ ${use_gpu} = "False" ] || [ ${use_gpu} = "cpu" ]; then for use_mkldnn in ${use_mkldnn_list[*]}; do @@ -96,7 +96,7 @@ function func_inference(){ for batch_size in ${batch_size_list[*]}; do for precision in ${precision_list[*]}; do set_precision=$(func_set_params "${precision_key}" "${precision}") - + _save_log_path="${_log_path}/python_infer_cpu_usemkldnn_${use_mkldnn}_threads_${threads}_precision_${precision}_batchsize_${batch_size}.log" set_infer_data=$(func_set_params "${image_dir_key}" "${_img_dir}") set_benchmark=$(func_set_params "${benchmark_key}" "${benchmark_value}") @@ -118,7 +118,7 @@ function func_inference(){ for precision in ${precision_list[*]}; do if [[ ${_flag_quant} = "False" ]] && [[ ${precision} =~ "int8" ]]; then continue - fi + fi if [[ ${precision} =~ "fp16" || ${precision} =~ "int8" ]] && [ ${use_trt} = "False" ]; then continue fi @@ -139,7 +139,7 @@ function func_inference(){ last_status=${PIPESTATUS[0]} eval "cat ${_save_log_path}" status_check $last_status "${command}" "${status_log}" - + done done done @@ -169,7 +169,7 @@ if [ ${MODE} = "whole_infer" ]; then set_export_weight=$(func_set_params "${export_weight}" "${infer_model}") set_save_infer_key="${save_infer_key} ${save_infer_dir}" export_cmd="${python} ${infer_run_exports[Count]} ${set_export_weight} ${set_save_infer_key}" - echo ${infer_run_exports[Count]} + echo ${infer_run_exports[Count]} echo $export_cmd eval $export_cmd status_export=$? @@ -207,17 +207,17 @@ else IFS="|" env=" " fi - for autocast in ${autocast_list[*]}; do - if [ ${autocast} = "amp" ]; then - set_amp_config="Global.use_amp=True Global.scale_loss=1024.0 Global.use_dynamic_loss_scaling=True" + for autocast in ${autocast_list[*]}; do + if [ ${autocast} = "fp16" ]; then + set_amp_config="--amp" else set_amp_config=" " - fi - for trainer in ${trainer_list[*]}; do + fi + for trainer in ${trainer_list[*]}; do flag_quant=False run_train=${norm_trainer} run_export=${norm_export} - + if [ ${run_train} = "null" ]; then continue fi @@ -239,11 +239,11 @@ else fi set_save_model=$(func_set_params "${save_model_key}" "${save_log}") if [ ${#gpu} -le 2 ];then # train with cpu or single gpu - cmd="${python} ${run_train} ${set_use_gpu} ${set_save_model} ${set_train_params1} ${set_epoch} ${set_pretrain} ${set_autocast} ${set_batchsize} ${set_amp_config} " + cmd="${python} ${run_train} ${set_use_gpu} ${set_save_model} ${set_train_params1} ${set_epoch} ${set_pretrain} ${set_batchsize} ${set_amp_config}" elif [ ${#ips} -le 26 ];then # train with multi-gpu - cmd="${python} -m paddle.distributed.launch --gpus=${gpu} ${run_train} ${set_use_gpu} ${set_save_model} ${set_train_params1} ${set_epoch} ${set_pretrain} ${set_autocast} ${set_batchsize} ${set_amp_config}" + cmd="${python} -m paddle.distributed.launch --gpus=${gpu} ${run_train} ${set_use_gpu} ${set_save_model} ${set_train_params1} ${set_epoch} ${set_pretrain} ${set_batchsize} ${set_amp_config}" else # train with multi-machine - cmd="${python} -m paddle.distributed.launch --ips=${ips} --gpus=${gpu} ${run_train} ${set_use_gpu} ${set_save_model} ${set_train_params1} ${set_pretrain} ${set_epoch} ${set_autocast} ${set_batchsize} ${set_amp_config}" + cmd="${python} -m paddle.distributed.launch --ips=${ips} --gpus=${gpu} ${run_train} ${set_use_gpu} ${set_save_model} ${set_train_params1} ${set_pretrain} ${set_epoch} ${set_batchsize} ${set_amp_config}" fi # run train eval "unset CUDA_VISIBLE_DEVICES" @@ -253,17 +253,17 @@ else status_check $? "${cmd}" "${status_log}" set_eval_pretrain=$(func_set_params "${pretrain_model_key}" "${save_log}/${train_model_name}") - # save norm trained models to set pretrain for pact training and fpgm training - - # run eval + # save norm trained models to set pretrain for pact training and fpgm training + + # run eval if [ ${eval_py} != "null" ]; then set_eval_params1=$(func_set_params "${eval_key1}" "${eval_value1}") - eval_cmd="${python} ${eval_py} ${set_eval_pretrain} ${set_use_gpu} ${set_eval_params1}" + eval_cmd="${python} ${eval_py} ${set_eval_pretrain} ${set_use_gpu} ${set_eval_params1}" eval $eval_cmd status_check $? "${eval_cmd}" "${status_log}" fi # run export model - if [ ${run_export} != "null" ]; then + if [ ${run_export} != "null" ]; then # run export model save_infer_path="${save_log}" set_export_weight="${save_log}/${train_model_name}" @@ -272,7 +272,7 @@ else export_cmd="${python} ${run_export} ${set_export_weight_path} ${set_save_infer_key}" eval "$export_cmd" status_check $? "${export_cmd}" "${status_log}" - + #run inference eval $env save_infer_path="${save_log}" @@ -282,11 +282,10 @@ else infer_model_dir=${save_infer_path} fi func_inference "${python}" "${inference_py}" "${infer_model_dir}" "${LOG_PATH}" "${train_infer_img_dir}" "${flag_quant}" - + eval "unset CUDA_VISIBLE_DEVICES" fi - done # done with: for trainer in ${trainer_list[*]}; do - done # done with: for autocast in ${autocast_list[*]}; do + done # done with: for trainer in ${trainer_list[*]}; do + done # done with: for autocast in ${autocast_list[*]}; do done # done with: for gpu in ${gpu_list[*]}; do fi # end if [ ${MODE} = "infer" ]; then -