diff --git a/paddlex/cv/models/base.py b/paddlex/cv/models/base.py index 938df3404156d5b1fd44c20517eb996bbdf4cf27..9d66df7009ff4daf09112b4709e30c39eb38ab67 100644 --- a/paddlex/cv/models/base.py +++ b/paddlex/cv/models/base.py @@ -15,6 +15,7 @@ from __future__ import absolute_import import paddle.fluid as fluid import os +import sys import numpy as np import time import math @@ -252,6 +253,9 @@ class BaseAPI: del self.init_params['self'] if '__class__' in self.init_params: del self.init_params['__class__'] + if 'model_name' in self.init_params: + del self.init_params['model_name'] + info['_init_params'] = self.init_params info['_Attributes']['num_classes'] = self.num_classes @@ -372,6 +376,8 @@ class BaseAPI: use_vdl=False, early_stop=False, early_stop_patience=5): + if train_dataset.num_samples < train_batch_size: + raise Exception('The amount of training datset must be larger than batch size.') if not osp.isdir(save_dir): if osp.exists(save_dir): os.remove(save_dir) @@ -429,9 +435,7 @@ class BaseAPI: if use_vdl: # VisualDL component - log_writer = LogWriter(vdl_logdir, sync_cycle=20) - train_step_component = OrderedDict() - eval_component = OrderedDict() + log_writer = LogWriter(vdl_logdir) thresh = 0.0001 if early_stop: @@ -469,13 +473,7 @@ class BaseAPI: if use_vdl: for k, v in step_metrics.items(): - if k not in train_step_component.keys(): - with log_writer.mode('Each_Step_while_Training' - ) as step_logger: - train_step_component[ - k] = step_logger.scalar( - 'Training: {}'.format(k)) - train_step_component[k].add_record(num_steps, v) + log_writer.add_scalar('Metrics/Training(Step): {}'.format(k), v, num_steps) # 估算剩余时间 avg_step_time = np.mean(time_stat) @@ -536,12 +534,7 @@ class BaseAPI: if isinstance(v, np.ndarray): if v.size > 1: continue - if k not in eval_component: - with log_writer.mode('Each_Epoch_on_Eval_Data' - ) as eval_logger: - eval_component[k] = eval_logger.scalar( - 'Evaluation: {}'.format(k)) - eval_component[k].add_record(i + 1, v) + log_writer.add_scalar("Metrics/Eval(Epoch): {}".format(k), v, i+1) self.save_model(save_dir=current_save_dir) time_eval_one_epoch = time.time() - eval_epoch_start_time eval_epoch_start_time = time.time() @@ -552,4 +545,4 @@ class BaseAPI: best_accuracy)) if eval_dataset is not None and early_stop: if earlystop(current_accuracy): - break + break \ No newline at end of file