diff --git a/paddlex/cv/models/base.py b/paddlex/cv/models/base.py index af999a2b419792d7f11e97d655f36feafc71450f..3264c8d3b0d9830c679bb02cf2124de8e67affa2 100644 --- a/paddlex/cv/models/base.py +++ b/paddlex/cv/models/base.py @@ -375,6 +375,8 @@ class BaseAPI: use_vdl=False, early_stop=False, early_stop_patience=5): + if train_dataset.num_samples < train_batch_size: + raise Exception('The amount of training datset must be larger than batch size.') if not osp.isdir(save_dir): if osp.exists(save_dir): os.remove(save_dir) diff --git a/paddlex/cv/models/slim/visualize.py b/paddlex/cv/models/slim/visualize.py index d9380abb2f1184cfe59d77b84d6841b5c4fd7288..083b177ea8c8878562070df3b617b32248046fea 100644 --- a/paddlex/cv/models/slim/visualize.py +++ b/paddlex/cv/models/slim/visualize.py @@ -30,7 +30,6 @@ def visualize(model, sensitivities_file, save_dir='./'): import matplotlib matplotlib.use('Agg') import matplotlib.pyplot as plt - program = model.test_prog place = model.places[0] fig = plt.figure() @@ -51,15 +50,21 @@ def visualize(model, sensitivities_file, save_dir='./'): min(np.array(x)) - 0.01, max(np.array(x)) + 0.01, 0.05) my_y_ticks = np.arange(0.05, 1, 0.05) - plt.xticks(my_x_ticks, fontsize=3) - plt.yticks(my_y_ticks, fontsize=3) + plt.xticks(my_x_ticks, rotation=30, fontsize=8) + plt.yticks(my_y_ticks, fontsize=8) for a, b in zip(x, y): plt.text( a, - b, (float('%0.4f' % a), float('%0.3f' % b)), + b, (float('%0.3f' % a), float('%0.3f' % b)), ha='center', va='bottom', - fontsize=3) + fontsize=8) + plt.rcParams['savefig.dpi'] = 120 + plt.rcParams['figure.dpi'] = 150 suffix = osp.splitext(sensitivities_file)[-1] - plt.savefig('sensitivities.png', dpi=800) + plt.savefig(osp.join(save_dir, 'sensitivities.png')) plt.close() + import pickle + coor = dict(zip(x, y)) + output = open(osp.join(save_dir, 'sensitivities_xy.pkl'), 'wb') + pickle.dump(coor, output)