提交 1ee6d162 编写于 作者: S sunyanfang01

fis the post quant

上级 0b9a4c4c
......@@ -15,6 +15,7 @@
from __future__ import absolute_import
import paddle.fluid as fluid
import os
import sys
import numpy as np
import time
import math
......@@ -252,6 +253,9 @@ class BaseAPI:
del self.init_params['self']
if '__class__' in self.init_params:
del self.init_params['__class__']
if 'model_name' in self.init_params:
del self.init_params['model_name']
info['_init_params'] = self.init_params
info['_Attributes']['num_classes'] = self.num_classes
......@@ -372,6 +376,8 @@ class BaseAPI:
use_vdl=False,
early_stop=False,
early_stop_patience=5):
if train_dataset.num_samples < train_batch_size:
raise Exception('The amount of training datset must be larger than batch size.')
if not osp.isdir(save_dir):
if osp.exists(save_dir):
os.remove(save_dir)
......@@ -429,9 +435,7 @@ class BaseAPI:
if use_vdl:
# VisualDL component
log_writer = LogWriter(vdl_logdir, sync_cycle=20)
train_step_component = OrderedDict()
eval_component = OrderedDict()
log_writer = LogWriter(vdl_logdir)
thresh = 0.0001
if early_stop:
......@@ -469,13 +473,7 @@ class BaseAPI:
if use_vdl:
for k, v in step_metrics.items():
if k not in train_step_component.keys():
with log_writer.mode('Each_Step_while_Training'
) as step_logger:
train_step_component[
k] = step_logger.scalar(
'Training: {}'.format(k))
train_step_component[k].add_record(num_steps, v)
log_writer.add_scalar('Metrics/Training(Step): {}'.format(k), v, num_steps)
# 估算剩余时间
avg_step_time = np.mean(time_stat)
......@@ -536,12 +534,7 @@ class BaseAPI:
if isinstance(v, np.ndarray):
if v.size > 1:
continue
if k not in eval_component:
with log_writer.mode('Each_Epoch_on_Eval_Data'
) as eval_logger:
eval_component[k] = eval_logger.scalar(
'Evaluation: {}'.format(k))
eval_component[k].add_record(i + 1, v)
log_writer.add_scalar("Metrics/Eval(Epoch): {}".format(k), v, i+1)
self.save_model(save_dir=current_save_dir)
time_eval_one_epoch = time.time() - eval_epoch_start_time
eval_epoch_start_time = time.time()
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册