提交 1ee6d162 编写于 作者: S sunyanfang01

fis the post quant

上级 0b9a4c4c
...@@ -15,6 +15,7 @@ ...@@ -15,6 +15,7 @@
from __future__ import absolute_import from __future__ import absolute_import
import paddle.fluid as fluid import paddle.fluid as fluid
import os import os
import sys
import numpy as np import numpy as np
import time import time
import math import math
...@@ -252,6 +253,9 @@ class BaseAPI: ...@@ -252,6 +253,9 @@ class BaseAPI:
del self.init_params['self'] del self.init_params['self']
if '__class__' in self.init_params: if '__class__' in self.init_params:
del self.init_params['__class__'] del self.init_params['__class__']
if 'model_name' in self.init_params:
del self.init_params['model_name']
info['_init_params'] = self.init_params info['_init_params'] = self.init_params
info['_Attributes']['num_classes'] = self.num_classes info['_Attributes']['num_classes'] = self.num_classes
...@@ -372,6 +376,8 @@ class BaseAPI: ...@@ -372,6 +376,8 @@ class BaseAPI:
use_vdl=False, use_vdl=False,
early_stop=False, early_stop=False,
early_stop_patience=5): early_stop_patience=5):
if train_dataset.num_samples < train_batch_size:
raise Exception('The amount of training datset must be larger than batch size.')
if not osp.isdir(save_dir): if not osp.isdir(save_dir):
if osp.exists(save_dir): if osp.exists(save_dir):
os.remove(save_dir) os.remove(save_dir)
...@@ -429,9 +435,7 @@ class BaseAPI: ...@@ -429,9 +435,7 @@ class BaseAPI:
if use_vdl: if use_vdl:
# VisualDL component # VisualDL component
log_writer = LogWriter(vdl_logdir, sync_cycle=20) log_writer = LogWriter(vdl_logdir)
train_step_component = OrderedDict()
eval_component = OrderedDict()
thresh = 0.0001 thresh = 0.0001
if early_stop: if early_stop:
...@@ -469,13 +473,7 @@ class BaseAPI: ...@@ -469,13 +473,7 @@ class BaseAPI:
if use_vdl: if use_vdl:
for k, v in step_metrics.items(): for k, v in step_metrics.items():
if k not in train_step_component.keys(): log_writer.add_scalar('Metrics/Training(Step): {}'.format(k), v, num_steps)
with log_writer.mode('Each_Step_while_Training'
) as step_logger:
train_step_component[
k] = step_logger.scalar(
'Training: {}'.format(k))
train_step_component[k].add_record(num_steps, v)
# 估算剩余时间 # 估算剩余时间
avg_step_time = np.mean(time_stat) avg_step_time = np.mean(time_stat)
...@@ -536,12 +534,7 @@ class BaseAPI: ...@@ -536,12 +534,7 @@ class BaseAPI:
if isinstance(v, np.ndarray): if isinstance(v, np.ndarray):
if v.size > 1: if v.size > 1:
continue continue
if k not in eval_component: log_writer.add_scalar("Metrics/Eval(Epoch): {}".format(k), v, i+1)
with log_writer.mode('Each_Epoch_on_Eval_Data'
) as eval_logger:
eval_component[k] = eval_logger.scalar(
'Evaluation: {}'.format(k))
eval_component[k].add_record(i + 1, v)
self.save_model(save_dir=current_save_dir) self.save_model(save_dir=current_save_dir)
time_eval_one_epoch = time.time() - eval_epoch_start_time time_eval_one_epoch = time.time() - eval_epoch_start_time
eval_epoch_start_time = time.time() eval_epoch_start_time = time.time()
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册