Unverified commit aaae4958, authored by OneYearIsEnough, committed by GitHub

[Feature] Add eta function in model's training stage (#5380)

* [Feature] Add eta function in model's training stage

* [Feature] Adjust the strategy of ETA function according to Donkey's smart proposals.

* [BugFix] Fix offset bug: the residual batch count should subtract 1
Parent b53483db
ppocr/utils/utility.py

@@ -105,3 +105,22 @@ def set_seed(seed=1024):
     random.seed(seed)
     np.random.seed(seed)
     paddle.seed(seed)
+
+
+class AverageMeter:
+    def __init__(self):
+        self.reset()
+
+    def reset(self):
+        """reset"""
+        self.val = 0
+        self.avg = 0
+        self.sum = 0
+        self.count = 0
+
+    def update(self, val, n=1):
+        """update"""
+        self.val = val
+        self.sum += val * n
+        self.count += n
+        self.avg = self.sum / self.count
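For reference, a minimal sketch of how this meter behaves on its own (the timings below are made up for illustration, not part of the commit): each `update(val, n)` folds `n` observations of `val` into a running mean, which the training loop later reads back as `eta_meter.avg`.

    meter = AverageMeter()
    for batch_time in [0.21, 0.19, 0.20]:  # fake per-batch seconds
        meter.update(batch_time)           # n defaults to 1
    print(meter.val)  # -> 0.20, the most recent observation
    print(meter.avg)  # -> ~0.20, sum / count over all updates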
tools/program.py

@@ -21,7 +21,7 @@ import sys
 import platform
 import yaml
 import time
-import shutil
+import datetime
 import paddle
 import paddle.distributed as dist
 from tqdm import tqdm
@@ -29,11 +29,10 @@ from argparse import ArgumentParser, RawDescriptionHelpFormatter
 from ppocr.utils.stats import TrainingStats
 from ppocr.utils.save_load import save_model
-from ppocr.utils.utility import print_dict
+from ppocr.utils.utility import print_dict, AverageMeter
 from ppocr.utils.logging import get_logger
 from ppocr.utils import profiler
 from ppocr.data import build_dataloader
-import numpy as np


 class ArgsParser(ArgumentParser):
@@ -48,7 +47,8 @@ class ArgsParser(ArgumentParser):
             '--profiler_options',
             type=str,
             default=None,
-            help='The option of profiler, which should be in format \"key1=value1;key2=value2;key3=value3\".'
+            help='The option of profiler, which should be in format ' \
+            '\"key1=value1;key2=value2;key3=value3\".'
         )

     def parse_args(self, argv=None):
@@ -99,7 +99,8 @@ def merge_config(config, opts):
             sub_keys = key.split('.')
             assert (
                 sub_keys[0] in config
-            ), "the sub_keys can only be one of global_config: {}, but get: {}, please check your running command".format(
+            ), "the sub_keys can only be one of global_config: {}, but get: " \
+               "{}, please check your running command".format(
                 config.keys(), sub_keys[0])
             cur = config[sub_keys[0]]
             for idx, sub_key in enumerate(sub_keys[1:]):
@@ -160,11 +161,13 @@ def train(config,
         eval_batch_step = eval_batch_step[1]
         if len(valid_dataloader) == 0:
             logger.info(
-                'No Images in eval dataset, evaluation during training will be disabled'
+                'No Images in eval dataset, evaluation during training ' \
+                'will be disabled'
             )
             start_eval_step = 1e111
         logger.info(
-            "During the training process, after the {}th iteration, an evaluation is run every {} iterations".
+            "During the training process, after the {}th iteration, " \
+            "an evaluation is run every {} iterations".
             format(start_eval_step, eval_batch_step))
     save_epoch_step = config['Global']['save_epoch_step']
     save_model_dir = config['Global']['save_model_dir']
@@ -189,10 +192,11 @@ def train(config,
     start_epoch = best_model_dict[
         'start_epoch'] if 'start_epoch' in best_model_dict else 1

-    train_reader_cost = 0.0
-    train_run_cost = 0.0
     total_samples = 0
+    train_reader_cost = 0.0
+    train_batch_cost = 0.0
     reader_start = time.time()
+    eta_meter = AverageMeter()

     max_iter = len(train_dataloader) - 1 if platform.system(
     ) == "Windows" else len(train_dataloader)
@@ -203,7 +207,6 @@ def train(config,
                 config, 'Train', device, logger, seed=epoch)
             max_iter = len(train_dataloader) - 1 if platform.system(
             ) == "Windows" else len(train_dataloader)
-
         for idx, batch in enumerate(train_dataloader):
             profiler.add_profiler_step(profiler_options)
             train_reader_cost += time.time() - reader_start
@@ -214,7 +217,6 @@ def train(config,
             if use_srn:
                 model_average = True

-            train_start = time.time()
             # use amp
             if scaler:
                 with paddle.amp.auto_cast():
@@ -242,7 +244,9 @@ def train(config,
                 optimizer.step()
             optimizer.clear_grad()

-            train_run_cost += time.time() - train_start
+            train_batch_time = time.time() - reader_start
+            train_batch_cost += train_batch_time
+            eta_meter.update(train_batch_time)
             global_step += 1
             total_samples += len(images)
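The timing change above is worth spelling out: the old code timed model execution with a second clock (`train_start`/`train_run_cost`), while the new code measures the whole iteration from the single `reader_start` stamp, so `train_batch_time` covers data loading plus forward/backward/update and is what feeds the ETA meter. A minimal skeleton of the pattern (the `loader` iterable and `run_step` helper are hypothetical stand-ins; in the real loop `reader_start` is reset at the end of each iteration, outside this excerpt):

    import time

    train_reader_cost = train_batch_cost = 0.0
    reader_start = time.time()
    for batch in loader:
        train_reader_cost += time.time() - reader_start  # data-loading share
        run_step(batch)                                  # fwd/bwd/optimizer step
        train_batch_time = time.time() - reader_start    # loading + compute
        train_batch_cost += train_batch_time
        eta_meter.update(train_batch_time)               # feeds the running mean
        reader_start = time.time()                       # restart clock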
...@@ -273,19 +277,26 @@ def train(config, ...@@ -273,19 +277,26 @@ def train(config,
(global_step > 0 and global_step % print_batch_step == 0) or (global_step > 0 and global_step % print_batch_step == 0) or
(idx >= len(train_dataloader) - 1)): (idx >= len(train_dataloader) - 1)):
logs = train_stats.log() logs = train_stats.log()
strs = 'epoch: [{}/{}], global_step: {}, {}, avg_reader_cost: {:.5f} s, avg_batch_cost: {:.5f} s, avg_samples: {}, ips: {:.5f}'.format( eta_sec = ((epoch_num + 1 - epoch) * \
epoch, epoch_num, global_step, logs, train_reader_cost / len(train_dataloader) - idx - 1) * eta_meter.avg
print_batch_step, (train_reader_cost + train_run_cost) / eta_sec_format = str(datetime.timedelta(seconds=int(eta_sec)))
print_batch_step, total_samples / print_batch_step, strs = 'epoch: [{}/{}], global_step: {}, {}, avg_reader_cost: ' \
total_samples / (train_reader_cost + train_run_cost)) '{:.5f} s, avg_batch_cost: {:.5f} s, avg_samples: {}, ' \
'ips: {:.5f}, eta: {}'.format(
epoch, epoch_num, global_step, logs,
train_reader_cost / print_batch_step,
train_batch_cost / print_batch_step,
total_samples / print_batch_step,
total_samples / train_batch_cost, eta_sec_format)
logger.info(strs) logger.info(strs)
train_reader_cost = 0.0
train_run_cost = 0.0
total_samples = 0 total_samples = 0
train_reader_cost = 0.0
train_batch_cost = 0.0
# eval # eval
if global_step > start_eval_step and \ if global_step > start_eval_step and \
(global_step - start_eval_step) % eval_batch_step == 0 and dist.get_rank() == 0: (global_step - start_eval_step) % eval_batch_step == 0 \
and dist.get_rank() == 0:
if model_average: if model_average:
Model_Average = paddle.incubate.optimizer.ModelAverage( Model_Average = paddle.incubate.optimizer.ModelAverage(
0.15, 0.15,
......
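The ETA arithmetic in the logging hunk is a plain remaining-batches estimate: with `epoch` counted from 1 to `epoch_num`, the current epoch still has `len(train_dataloader) - idx - 1` batches after 0-based index `idx`, and `epoch_num - epoch` full epochs follow, which adds up to `(epoch_num + 1 - epoch) * len(train_dataloader) - idx - 1` remaining batches; the trailing `- 1` is the offset fix called out in the commit message. Multiplying by the smoothed per-batch time and passing through `datetime.timedelta` yields the logged string. A worked example with made-up numbers:

    import datetime

    epoch, epoch_num = 3, 10  # 1-based current epoch / total epochs
    idx = 49                  # 0-based batch index within the epoch
    steps_per_epoch = 200     # stand-in for len(train_dataloader)
    avg_batch_time = 0.5      # stand-in for eta_meter.avg, in seconds

    remaining = (epoch_num + 1 - epoch) * steps_per_epoch - idx - 1  # 1550
    eta_sec = remaining * avg_batch_time                             # 775.0
    print(str(datetime.timedelta(seconds=int(eta_sec))))             # 0:12:55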