未验证 提交 aaae4958 编写于 作者: O OneYearIsEnough 提交者: GitHub

[Feature] Add eta function in model's training stage (#5380)

* [Feature] Add eta function in model's training stage

* [Feature] Add eta function in model's training stage

* [Feature] Add eta function in model's training stage

* [Feature] Adjust the strategy of ETA function according to Donkey's smart proposals.

* [Feature] Adjust the strategy of ETA function according to Donkey's smart proposals.

* [Feature] Adjust the strategy of ETA function according to Donkey's smart proposals.

* [Feature] Adjust the strategy of ETA function according to Donkey's smart proposals.

* [Feature] Adjust the strategy of ETA function according to Donkey's smart proposals.

* [Feature] Adjust the strategy of ETA function according to Donkey's smart proposals.

* [BugFix] Fix offset bug, residual idxes should -1
上级 b53483db
......@@ -105,3 +105,22 @@ def set_seed(seed=1024):
random.seed(seed)
np.random.seed(seed)
paddle.seed(seed)
class AverageMeter:
    """Track the latest value and the running weighted average of a series.

    Attributes:
        val: The most recent value passed to ``update``.
        sum: Accumulated ``val * n`` over all ``update`` calls since the
            last ``reset``.
        count: Accumulated ``n`` over all ``update`` calls since the last
            ``reset``.
        avg: ``sum / count``, or 0 while ``count`` is still 0.
    """

    def __init__(self):
        self.reset()

    def reset(self):
        """Clear all statistics back to zero."""
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        """Record a new observation and refresh the running average.

        Args:
            val: The observed value.
            n: Weight of the observation, i.e. how many samples ``val``
                stands for. Defaults to 1.
        """
        self.val = val
        self.sum += val * n
        self.count += n
        # Guard the division: a zero total weight (e.g. ``update(x, n=0)`` on
        # a fresh meter) would otherwise raise ZeroDivisionError.
        self.avg = self.sum / self.count if self.count else 0
......@@ -21,7 +21,7 @@ import sys
import platform
import yaml
import time
import shutil
import datetime
import paddle
import paddle.distributed as dist
from tqdm import tqdm
......@@ -29,11 +29,10 @@ from argparse import ArgumentParser, RawDescriptionHelpFormatter
from ppocr.utils.stats import TrainingStats
from ppocr.utils.save_load import save_model
from ppocr.utils.utility import print_dict
from ppocr.utils.utility import print_dict, AverageMeter
from ppocr.utils.logging import get_logger
from ppocr.utils import profiler
from ppocr.data import build_dataloader
import numpy as np
class ArgsParser(ArgumentParser):
......@@ -48,7 +47,8 @@ class ArgsParser(ArgumentParser):
'--profiler_options',
type=str,
default=None,
help='The option of profiler, which should be in format \"key1=value1;key2=value2;key3=value3\".'
help='The option of profiler, which should be in format ' \
'\"key1=value1;key2=value2;key3=value3\".'
)
def parse_args(self, argv=None):
......@@ -99,7 +99,8 @@ def merge_config(config, opts):
sub_keys = key.split('.')
assert (
sub_keys[0] in config
), "the sub_keys can only be one of global_config: {}, but get: {}, please check your running command".format(
), "the sub_keys can only be one of global_config: {}, but get: " \
"{}, please check your running command".format(
config.keys(), sub_keys[0])
cur = config[sub_keys[0]]
for idx, sub_key in enumerate(sub_keys[1:]):
......@@ -160,11 +161,13 @@ def train(config,
eval_batch_step = eval_batch_step[1]
if len(valid_dataloader) == 0:
logger.info(
'No Images in eval dataset, evaluation during training will be disabled'
'No Images in eval dataset, evaluation during training ' \
'will be disabled'
)
start_eval_step = 1e111
logger.info(
"During the training process, after the {}th iteration, an evaluation is run every {} iterations".
"During the training process, after the {}th iteration, " \
"an evaluation is run every {} iterations".
format(start_eval_step, eval_batch_step))
save_epoch_step = config['Global']['save_epoch_step']
save_model_dir = config['Global']['save_model_dir']
......@@ -189,10 +192,11 @@ def train(config,
start_epoch = best_model_dict[
'start_epoch'] if 'start_epoch' in best_model_dict else 1
train_reader_cost = 0.0
train_run_cost = 0.0
total_samples = 0
train_reader_cost = 0.0
train_batch_cost = 0.0
reader_start = time.time()
eta_meter = AverageMeter()
max_iter = len(train_dataloader) - 1 if platform.system(
) == "Windows" else len(train_dataloader)
......@@ -203,7 +207,6 @@ def train(config,
config, 'Train', device, logger, seed=epoch)
max_iter = len(train_dataloader) - 1 if platform.system(
) == "Windows" else len(train_dataloader)
for idx, batch in enumerate(train_dataloader):
profiler.add_profiler_step(profiler_options)
train_reader_cost += time.time() - reader_start
......@@ -214,7 +217,6 @@ def train(config,
if use_srn:
model_average = True
train_start = time.time()
# use amp
if scaler:
with paddle.amp.auto_cast():
......@@ -242,7 +244,9 @@ def train(config,
optimizer.step()
optimizer.clear_grad()
train_run_cost += time.time() - train_start
train_batch_time = time.time() - reader_start
train_batch_cost += train_batch_time
eta_meter.update(train_batch_time)
global_step += 1
total_samples += len(images)
......@@ -273,19 +277,26 @@ def train(config,
(global_step > 0 and global_step % print_batch_step == 0) or
(idx >= len(train_dataloader) - 1)):
logs = train_stats.log()
strs = 'epoch: [{}/{}], global_step: {}, {}, avg_reader_cost: {:.5f} s, avg_batch_cost: {:.5f} s, avg_samples: {}, ips: {:.5f}'.format(
epoch, epoch_num, global_step, logs, train_reader_cost /
print_batch_step, (train_reader_cost + train_run_cost) /
print_batch_step, total_samples / print_batch_step,
total_samples / (train_reader_cost + train_run_cost))
eta_sec = ((epoch_num + 1 - epoch) * \
len(train_dataloader) - idx - 1) * eta_meter.avg
eta_sec_format = str(datetime.timedelta(seconds=int(eta_sec)))
strs = 'epoch: [{}/{}], global_step: {}, {}, avg_reader_cost: ' \
'{:.5f} s, avg_batch_cost: {:.5f} s, avg_samples: {}, ' \
'ips: {:.5f}, eta: {}'.format(
epoch, epoch_num, global_step, logs,
train_reader_cost / print_batch_step,
train_batch_cost / print_batch_step,
total_samples / print_batch_step,
total_samples / train_batch_cost, eta_sec_format)
logger.info(strs)
train_reader_cost = 0.0
train_run_cost = 0.0
total_samples = 0
train_reader_cost = 0.0
train_batch_cost = 0.0
# eval
if global_step > start_eval_step and \
(global_step - start_eval_step) % eval_batch_step == 0 and dist.get_rank() == 0:
(global_step - start_eval_step) % eval_batch_step == 0 \
and dist.get_rank() == 0:
if model_average:
Model_Average = paddle.incubate.optimizer.ModelAverage(
0.15,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册