提交 b885a6d1 · 作者: jiangjiajun

revert base.py

上级 a532ecc2
...@@ -79,9 +79,9 @@ class BaseAPI: ...@@ -79,9 +79,9 @@ class BaseAPI:
return int(batch_size // len(self.places)) return int(batch_size // len(self.places))
else: else:
raise Exception("Please support correct batch_size, \ raise Exception("Please support correct batch_size, \
which can be divided by available cards({}) in {}" which can be divided by available cards({}) in {}".
.format(paddlex.env_info['num'], paddlex.env_info[ format(paddlex.env_info['num'],
'place'])) paddlex.env_info['place']))
def build_program(self): def build_program(self):
# 构建训练网络 # 构建训练网络
...@@ -141,7 +141,7 @@ class BaseAPI: ...@@ -141,7 +141,7 @@ class BaseAPI:
from .slim.post_quantization import PaddleXPostTrainingQuantization from .slim.post_quantization import PaddleXPostTrainingQuantization
except: except:
raise Exception( raise Exception(
"Model Quantization is not available, try to upgrade your paddlepaddle>=1.8.0" "Model Quantization is not available, try to upgrade your paddlepaddle>=1.7.0"
) )
is_use_cache_file = True is_use_cache_file = True
if cache_dir is None: if cache_dir is None:
...@@ -209,8 +209,8 @@ class BaseAPI: ...@@ -209,8 +209,8 @@ class BaseAPI:
paddlex.utils.utils.load_pretrain_weights( paddlex.utils.utils.load_pretrain_weights(
self.exe, self.train_prog, resume_checkpoint, resume=True) self.exe, self.train_prog, resume_checkpoint, resume=True)
if not osp.exists(osp.join(resume_checkpoint, "model.yml")): if not osp.exists(osp.join(resume_checkpoint, "model.yml")):
raise Exception("There's not model.yml in {}".format( raise Exception(
resume_checkpoint)) "There's not model.yml in {}".format(resume_checkpoint))
with open(osp.join(resume_checkpoint, "model.yml")) as f: with open(osp.join(resume_checkpoint, "model.yml")) as f:
info = yaml.load(f.read(), Loader=yaml.Loader) info = yaml.load(f.read(), Loader=yaml.Loader)
self.completed_epochs = info['completed_epochs'] self.completed_epochs = info['completed_epochs']
...@@ -361,8 +361,8 @@ class BaseAPI: ...@@ -361,8 +361,8 @@ class BaseAPI:
# 模型保存成功的标志 # 模型保存成功的标志
open(osp.join(save_dir, '.success'), 'w').close() open(osp.join(save_dir, '.success'), 'w').close()
logging.info("Model for inference deploy saved in {}.".format( logging.info(
save_dir)) "Model for inference deploy saved in {}.".format(save_dir))
def train_loop(self, def train_loop(self,
num_epochs, num_epochs,
...@@ -376,8 +376,7 @@ class BaseAPI: ...@@ -376,8 +376,7 @@ class BaseAPI:
early_stop=False, early_stop=False,
early_stop_patience=5): early_stop_patience=5):
if train_dataset.num_samples < train_batch_size: if train_dataset.num_samples < train_batch_size:
raise Exception( raise Exception('The amount of training datset must be larger than batch size.')
'The amount of training datset must be larger than batch size.')
if not osp.isdir(save_dir): if not osp.isdir(save_dir):
if osp.exists(save_dir): if osp.exists(save_dir):
os.remove(save_dir) os.remove(save_dir)
...@@ -415,8 +414,8 @@ class BaseAPI: ...@@ -415,8 +414,8 @@ class BaseAPI:
build_strategy=build_strategy, build_strategy=build_strategy,
exec_strategy=exec_strategy) exec_strategy=exec_strategy)
total_num_steps = math.floor(train_dataset.num_samples / total_num_steps = math.floor(
train_batch_size) train_dataset.num_samples / train_batch_size)
num_steps = 0 num_steps = 0
time_stat = list() time_stat = list()
time_train_one_epoch = None time_train_one_epoch = None
...@@ -430,8 +429,8 @@ class BaseAPI: ...@@ -430,8 +429,8 @@ class BaseAPI:
if self.model_type == 'detector': if self.model_type == 'detector':
eval_batch_size = self._get_single_card_bs(train_batch_size) eval_batch_size = self._get_single_card_bs(train_batch_size)
if eval_dataset is not None: if eval_dataset is not None:
total_num_steps_eval = math.ceil(eval_dataset.num_samples / total_num_steps_eval = math.ceil(
eval_batch_size) eval_dataset.num_samples / eval_batch_size)
if use_vdl: if use_vdl:
# VisualDL component # VisualDL component
...@@ -473,9 +472,7 @@ class BaseAPI: ...@@ -473,9 +472,7 @@ class BaseAPI:
if use_vdl: if use_vdl:
for k, v in step_metrics.items(): for k, v in step_metrics.items():
log_writer.add_scalar( log_writer.add_scalar('Metrics/Training(Step): {}'.format(k), v, num_steps)
'Metrics/Training(Step): {}'.format(k), v,
num_steps)
# 估算剩余时间 # 估算剩余时间
avg_step_time = np.mean(time_stat) avg_step_time = np.mean(time_stat)
...@@ -483,12 +480,11 @@ class BaseAPI: ...@@ -483,12 +480,11 @@ class BaseAPI:
eta = (num_epochs - i - 1) * time_train_one_epoch + ( eta = (num_epochs - i - 1) * time_train_one_epoch + (
total_num_steps - step - 1) * avg_step_time total_num_steps - step - 1) * avg_step_time
else: else:
eta = ((num_epochs - i) * total_num_steps - step - 1 eta = ((num_epochs - i) * total_num_steps - step -
) * avg_step_time 1) * avg_step_time
if time_eval_one_epoch is not None: if time_eval_one_epoch is not None:
eval_eta = ( eval_eta = (total_eval_times - i //
total_eval_times - i // save_interval_epochs save_interval_epochs) * time_eval_one_epoch
) * time_eval_one_epoch
else: else:
eval_eta = ( eval_eta = (
total_eval_times - i // save_interval_epochs total_eval_times - i // save_interval_epochs
...@@ -498,11 +494,10 @@ class BaseAPI: ...@@ -498,11 +494,10 @@ class BaseAPI:
logging.info( logging.info(
"[TRAIN] Epoch={}/{}, Step={}/{}, {}, time_each_step={}s, eta={}" "[TRAIN] Epoch={}/{}, Step={}/{}, {}, time_each_step={}s, eta={}"
.format(i + 1, num_epochs, step + 1, total_num_steps, .format(i + 1, num_epochs, step + 1, total_num_steps,
dict2str(step_metrics), dict2str(step_metrics), round(
round(avg_step_time, 2), eta_str)) avg_step_time, 2), eta_str))
train_metrics = OrderedDict( train_metrics = OrderedDict(
zip(list(self.train_outputs.keys()), np.mean( zip(list(self.train_outputs.keys()), np.mean(records, axis=0)))
records, axis=0)))
logging.info('[TRAIN] Epoch {} finished, {} .'.format( logging.info('[TRAIN] Epoch {} finished, {} .'.format(
i + 1, dict2str(train_metrics))) i + 1, dict2str(train_metrics)))
time_train_one_epoch = time.time() - epoch_start_time time_train_one_epoch = time.time() - epoch_start_time
...@@ -538,8 +533,7 @@ class BaseAPI: ...@@ -538,8 +533,7 @@ class BaseAPI:
if isinstance(v, np.ndarray): if isinstance(v, np.ndarray):
if v.size > 1: if v.size > 1:
continue continue
log_writer.add_scalar( log_writer.add_scalar("Metrics/Eval(Epoch): {}".format(k), v, i+1)
"Metrics/Eval(Epoch): {}".format(k), v, i + 1)
self.save_model(save_dir=current_save_dir) self.save_model(save_dir=current_save_dir)
time_eval_one_epoch = time.time() - eval_epoch_start_time time_eval_one_epoch = time.time() - eval_epoch_start_time
eval_epoch_start_time = time.time() eval_epoch_start_time = time.time()
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请注册