From aaae49584f7d336c80a714e067d870b2a6f69493 Mon Sep 17 00:00:00 2001
From: OneYearIsEnough <81819512+OneYearIsEnough@users.noreply.github.com>
Date: Tue, 1 Feb 2022 17:46:42 +0800
Subject: [PATCH] [Feature] Add eta function in model's training stage (#5380)

* [Feature] Add eta function in model's training stage

* [Feature] Add eta function in model's training stage

* [Feature] Add eta function in model's training stage

* [Feature] Adjust the strategy of ETA function according to Donkey's smart proposals.

* [Feature] Adjust the strategy of ETA function according to Donkey's smart proposals.

* [Feature] Adjust the strategy of ETA function according to Donkey's smart proposals.

* [Feature] Adjust the strategy of ETA function according to Donkey's smart proposals.

* [Feature] Adjust the strategy of ETA function according to Donkey's smart proposals.

* [Feature] Adjust the strategy of ETA function according to Donkey's smart proposals.

* [BugFix] Fix offset bug, residual idxes should -1
---
 ppocr/utils/utility.py | 19 ++++++++++++++++
 tools/program.py       | 51 +++++++++++++++++++++++++-----------------
 2 files changed, 50 insertions(+), 20 deletions(-)

diff --git a/ppocr/utils/utility.py b/ppocr/utils/utility.py
index 76484dfd..dc2a6e74 100755
--- a/ppocr/utils/utility.py
+++ b/ppocr/utils/utility.py
@@ -105,3 +105,22 @@ def set_seed(seed=1024):
     random.seed(seed)
     np.random.seed(seed)
     paddle.seed(seed)
+
+
+class AverageMeter:
+    def __init__(self):
+        self.reset()
+
+    def reset(self):
+        """reset"""
+        self.val = 0
+        self.avg = 0
+        self.sum = 0
+        self.count = 0
+
+    def update(self, val, n=1):
+        """update"""
+        self.val = val
+        self.sum += val * n
+        self.count += n
+        self.avg = self.sum / self.count

diff --git a/tools/program.py b/tools/program.py
index 10299940..5ffb93d1 100755
--- a/tools/program.py
+++ b/tools/program.py
@@ -21,7 +21,7 @@ import sys
 import platform
 import yaml
 import time
-import shutil
+import datetime
 import paddle
 import paddle.distributed as dist
 from tqdm import tqdm
@@ -29,11 +29,10 @@ from argparse import ArgumentParser, RawDescriptionHelpFormatter
 
 from ppocr.utils.stats import TrainingStats
 from ppocr.utils.save_load import save_model
-from ppocr.utils.utility import print_dict
+from ppocr.utils.utility import print_dict, AverageMeter
 from ppocr.utils.logging import get_logger
 from ppocr.utils import profiler
 from ppocr.data import build_dataloader
-import numpy as np
 
 
 class ArgsParser(ArgumentParser):
@@ -48,7 +47,8 @@ class ArgsParser(ArgumentParser):
             '--profiler_options',
             type=str,
             default=None,
-            help='The option of profiler, which should be in format \"key1=value1;key2=value2;key3=value3\".'
+            help='The option of profiler, which should be in format ' \
+                '\"key1=value1;key2=value2;key3=value3\".'
         )
 
     def parse_args(self, argv=None):
@@ -99,7 +99,8 @@ def merge_config(config, opts):
         sub_keys = key.split('.')
         assert (
             sub_keys[0] in config
-        ), "the sub_keys can only be one of global_config: {}, but get: {}, please check your running command".format(
+        ), "the sub_keys can only be one of global_config: {}, but get: " \
+           "{}, please check your running command".format(
             config.keys(), sub_keys[0])
         cur = config[sub_keys[0]]
         for idx, sub_key in enumerate(sub_keys[1:]):
@@ -160,11 +161,13 @@ def train(config,
         eval_batch_step = eval_batch_step[1]
         if len(valid_dataloader) == 0:
             logger.info(
-                'No Images in eval dataset, evaluation during training will be disabled'
+                'No Images in eval dataset, evaluation during training ' \
+                'will be disabled'
             )
             start_eval_step = 1e111
         logger.info(
-            "During the training process, after the {}th iteration, an evaluation is run every {} iterations".
+            "During the training process, after the {}th iteration, " \
+            "an evaluation is run every {} iterations".
             format(start_eval_step, eval_batch_step))
     save_epoch_step = config['Global']['save_epoch_step']
     save_model_dir = config['Global']['save_model_dir']
@@ -189,10 +192,11 @@ def train(config,
     start_epoch = best_model_dict[
         'start_epoch'] if 'start_epoch' in best_model_dict else 1
 
-    train_reader_cost = 0.0
-    train_run_cost = 0.0
     total_samples = 0
+    train_reader_cost = 0.0
+    train_batch_cost = 0.0
     reader_start = time.time()
+    eta_meter = AverageMeter()
 
     max_iter = len(train_dataloader) - 1 if platform.system(
     ) == "Windows" else len(train_dataloader)
@@ -203,7 +207,6 @@ def train(config,
                 config, 'Train', device, logger, seed=epoch)
             max_iter = len(train_dataloader) - 1 if platform.system(
             ) == "Windows" else len(train_dataloader)
-
         for idx, batch in enumerate(train_dataloader):
             profiler.add_profiler_step(profiler_options)
             train_reader_cost += time.time() - reader_start
@@ -214,7 +217,6 @@ def train(config,
             if use_srn:
                 model_average = True
 
-            train_start = time.time()
             # use amp
             if scaler:
                 with paddle.amp.auto_cast():
@@ -242,7 +244,9 @@ def train(config,
                 optimizer.step()
                 optimizer.clear_grad()
 
-            train_run_cost += time.time() - train_start
+            train_batch_time = time.time() - reader_start
+            train_batch_cost += train_batch_time
+            eta_meter.update(train_batch_time)
             global_step += 1
             total_samples += len(images)
@@ -273,19 +277,26 @@ def train(config,
                 (global_step > 0 and global_step % print_batch_step == 0) or
                 (idx >= len(train_dataloader) - 1)):
                 logs = train_stats.log()
-                strs = 'epoch: [{}/{}], global_step: {}, {}, avg_reader_cost: {:.5f} s, avg_batch_cost: {:.5f} s, avg_samples: {}, ips: {:.5f}'.format(
-                    epoch, epoch_num, global_step, logs, train_reader_cost /
-                    print_batch_step, (train_reader_cost + train_run_cost) /
-                    print_batch_step, total_samples / print_batch_step,
-                    total_samples / (train_reader_cost + train_run_cost))
+
+                eta_sec = ((epoch_num + 1 - epoch) * \
+                    len(train_dataloader) - idx - 1) * eta_meter.avg
+                eta_sec_format = str(datetime.timedelta(seconds=int(eta_sec)))
+                strs = 'epoch: [{}/{}], global_step: {}, {}, avg_reader_cost: ' \
+                       '{:.5f} s, avg_batch_cost: {:.5f} s, avg_samples: {}, ' \
+                       'ips: {:.5f}, eta: {}'.format(
+                           epoch, epoch_num, global_step, logs,
+                           train_reader_cost / print_batch_step,
+                           train_batch_cost / print_batch_step,
+                           total_samples / print_batch_step,
+                           total_samples / train_batch_cost, eta_sec_format)
                 logger.info(strs)
-                train_reader_cost = 0.0
-                train_run_cost = 0.0
                 total_samples = 0
+                train_reader_cost = 0.0
+                train_batch_cost = 0.0
             # eval
             if global_step > start_eval_step and \
-                (global_step - start_eval_step) % eval_batch_step == 0 and dist.get_rank() == 0:
+                (global_step - start_eval_step) % eval_batch_step == 0 \
+                and dist.get_rank() == 0:
                 if model_average:
                     Model_Average = paddle.incubate.optimizer.ModelAverage(
                         0.15,
-- 
GitLab
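
For reference, the ETA this patch logs combines the new AverageMeter with a
remaining-batch count. Below is a minimal standalone sketch of that
arithmetic; the AverageMeter mirrors the class added to ppocr/utils/utility.py
above, while epoch_num, epoch, steps_per_epoch, idx, and the batch times are
hypothetical numbers chosen only for illustration.

import datetime


class AverageMeter:
    """Running average of per-batch wall time (as added to
    ppocr/utils/utility.py by this patch)."""

    def __init__(self):
        self.reset()

    def reset(self):
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count


# Hypothetical training dimensions, for illustration only.
epoch_num = 10         # total epochs; the training loop counts epochs from 1
epoch = 3              # current epoch
steps_per_epoch = 500  # stands in for len(train_dataloader)
idx = 99               # 0-based index of the batch just finished

eta_meter = AverageMeter()
for batch_time in (0.21, 0.19, 0.20):  # made-up seconds per batch
    eta_meter.update(batch_time)

# Batches left = all batches of the remaining epochs (current epoch
# included) minus the idx + 1 batches already finished this epoch; the
# "- idx - 1" term is the off-by-one fix from the last commit above.
eta_sec = ((epoch_num + 1 - epoch) * steps_per_epoch - idx - 1) * eta_meter.avg
print(str(datetime.timedelta(seconds=int(eta_sec))))  # 0:13:00 here

Because the meter averages every batch since training started, the printed
ETA smooths out per-batch jitter, at the cost of adapting slowly if the
per-batch cost drifts over the course of the run.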