提交 b885a6d1 编写于 作者: jiangjiajun

revert base.py

上级 a532ecc2
......@@ -79,9 +79,9 @@ class BaseAPI:
return int(batch_size // len(self.places))
else:
raise Exception("Please support correct batch_size, \
which can be divided by available cards({}) in {}"
.format(paddlex.env_info['num'], paddlex.env_info[
'place']))
which can be divided by available cards({}) in {}".
format(paddlex.env_info['num'],
paddlex.env_info['place']))
def build_program(self):
# 构建训练网络
......@@ -141,7 +141,7 @@ class BaseAPI:
from .slim.post_quantization import PaddleXPostTrainingQuantization
except:
raise Exception(
"Model Quantization is not available, try to upgrade your paddlepaddle>=1.8.0"
"Model Quantization is not available, try to upgrade your paddlepaddle>=1.7.0"
)
is_use_cache_file = True
if cache_dir is None:
......@@ -209,8 +209,8 @@ class BaseAPI:
paddlex.utils.utils.load_pretrain_weights(
self.exe, self.train_prog, resume_checkpoint, resume=True)
if not osp.exists(osp.join(resume_checkpoint, "model.yml")):
raise Exception("There's not model.yml in {}".format(
resume_checkpoint))
raise Exception(
"There's not model.yml in {}".format(resume_checkpoint))
with open(osp.join(resume_checkpoint, "model.yml")) as f:
info = yaml.load(f.read(), Loader=yaml.Loader)
self.completed_epochs = info['completed_epochs']
......@@ -361,8 +361,8 @@ class BaseAPI:
# 模型保存成功的标志
open(osp.join(save_dir, '.success'), 'w').close()
logging.info("Model for inference deploy saved in {}.".format(
save_dir))
logging.info(
"Model for inference deploy saved in {}.".format(save_dir))
def train_loop(self,
num_epochs,
......@@ -376,8 +376,7 @@ class BaseAPI:
early_stop=False,
early_stop_patience=5):
if train_dataset.num_samples < train_batch_size:
raise Exception(
'The amount of training datset must be larger than batch size.')
raise Exception('The amount of training datset must be larger than batch size.')
if not osp.isdir(save_dir):
if osp.exists(save_dir):
os.remove(save_dir)
......@@ -415,8 +414,8 @@ class BaseAPI:
build_strategy=build_strategy,
exec_strategy=exec_strategy)
total_num_steps = math.floor(train_dataset.num_samples /
train_batch_size)
total_num_steps = math.floor(
train_dataset.num_samples / train_batch_size)
num_steps = 0
time_stat = list()
time_train_one_epoch = None
......@@ -430,8 +429,8 @@ class BaseAPI:
if self.model_type == 'detector':
eval_batch_size = self._get_single_card_bs(train_batch_size)
if eval_dataset is not None:
total_num_steps_eval = math.ceil(eval_dataset.num_samples /
eval_batch_size)
total_num_steps_eval = math.ceil(
eval_dataset.num_samples / eval_batch_size)
if use_vdl:
# VisualDL component
......@@ -473,9 +472,7 @@ class BaseAPI:
if use_vdl:
for k, v in step_metrics.items():
log_writer.add_scalar(
'Metrics/Training(Step): {}'.format(k), v,
num_steps)
log_writer.add_scalar('Metrics/Training(Step): {}'.format(k), v, num_steps)
# 估算剩余时间
avg_step_time = np.mean(time_stat)
......@@ -483,12 +480,11 @@ class BaseAPI:
eta = (num_epochs - i - 1) * time_train_one_epoch + (
total_num_steps - step - 1) * avg_step_time
else:
eta = ((num_epochs - i) * total_num_steps - step - 1
) * avg_step_time
eta = ((num_epochs - i) * total_num_steps - step -
1) * avg_step_time
if time_eval_one_epoch is not None:
eval_eta = (
total_eval_times - i // save_interval_epochs
) * time_eval_one_epoch
eval_eta = (total_eval_times - i //
save_interval_epochs) * time_eval_one_epoch
else:
eval_eta = (
total_eval_times - i // save_interval_epochs
......@@ -498,11 +494,10 @@ class BaseAPI:
logging.info(
"[TRAIN] Epoch={}/{}, Step={}/{}, {}, time_each_step={}s, eta={}"
.format(i + 1, num_epochs, step + 1, total_num_steps,
dict2str(step_metrics),
round(avg_step_time, 2), eta_str))
dict2str(step_metrics), round(
avg_step_time, 2), eta_str))
train_metrics = OrderedDict(
zip(list(self.train_outputs.keys()), np.mean(
records, axis=0)))
zip(list(self.train_outputs.keys()), np.mean(records, axis=0)))
logging.info('[TRAIN] Epoch {} finished, {} .'.format(
i + 1, dict2str(train_metrics)))
time_train_one_epoch = time.time() - epoch_start_time
......@@ -538,8 +533,7 @@ class BaseAPI:
if isinstance(v, np.ndarray):
if v.size > 1:
continue
log_writer.add_scalar(
"Metrics/Eval(Epoch): {}".format(k), v, i + 1)
log_writer.add_scalar("Metrics/Eval(Epoch): {}".format(k), v, i+1)
self.save_model(save_dir=current_save_dir)
time_eval_one_epoch = time.time() - eval_epoch_start_time
eval_epoch_start_time = time.time()
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请先注册。