提交 120db901 编写于 作者: S sunyanfang01

modify the base and vis

上级 2484756a
...@@ -371,6 +371,8 @@ class BaseAPI: ...@@ -371,6 +371,8 @@ class BaseAPI:
use_vdl=False, use_vdl=False,
early_stop=False, early_stop=False,
early_stop_patience=5): early_stop_patience=5):
if train_dataset.num_samples < train_batch_size:
raise Exception('The amount of training datset must be larger than batch size.')
if not osp.isdir(save_dir): if not osp.isdir(save_dir):
if osp.exists(save_dir): if osp.exists(save_dir):
os.remove(save_dir) os.remove(save_dir)
......
...@@ -30,7 +30,6 @@ def visualize(model, sensitivities_file, save_dir='./'): ...@@ -30,7 +30,6 @@ def visualize(model, sensitivities_file, save_dir='./'):
import matplotlib import matplotlib
matplotlib.use('Agg') matplotlib.use('Agg')
import matplotlib.pyplot as plt import matplotlib.pyplot as plt
program = model.test_prog program = model.test_prog
place = model.places[0] place = model.places[0]
fig = plt.figure() fig = plt.figure()
...@@ -51,15 +50,21 @@ def visualize(model, sensitivities_file, save_dir='./'): ...@@ -51,15 +50,21 @@ def visualize(model, sensitivities_file, save_dir='./'):
min(np.array(x)) - 0.01, min(np.array(x)) - 0.01,
max(np.array(x)) + 0.01, 0.05) max(np.array(x)) + 0.01, 0.05)
my_y_ticks = np.arange(0.05, 1, 0.05) my_y_ticks = np.arange(0.05, 1, 0.05)
plt.xticks(my_x_ticks, fontsize=3) plt.xticks(my_x_ticks, rotation=30, fontsize=8)
plt.yticks(my_y_ticks, fontsize=3) plt.yticks(my_y_ticks, fontsize=8)
for a, b in zip(x, y): for a, b in zip(x, y):
plt.text( plt.text(
a, a,
b, (float('%0.4f' % a), float('%0.3f' % b)), b, (float('%0.3f' % a), float('%0.3f' % b)),
ha='center', ha='center',
va='bottom', va='bottom',
fontsize=3) fontsize=8)
plt.rcParams['savefig.dpi'] = 120
plt.rcParams['figure.dpi'] = 150
suffix = osp.splitext(sensitivities_file)[-1] suffix = osp.splitext(sensitivities_file)[-1]
plt.savefig('sensitivities.png', dpi=800) plt.savefig(osp.join(save_dir, 'sensitivities.png'))
plt.close() plt.close()
import pickle
coor = dict(zip(x, y))
output = open(osp.join(save_dir, 'sensitivities_xy.pkl'), 'wb')
pickle.dump(coor, output)
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册