未验证 提交 f7d9b7fb 编写于 作者: J Jason 提交者: GitHub

Merge pull request #63 from SunAhong1993/syf_slim1

修改base.py和裁剪可视化
......@@ -375,6 +375,8 @@ class BaseAPI:
use_vdl=False,
early_stop=False,
early_stop_patience=5):
if train_dataset.num_samples < train_batch_size:
raise Exception('The amount of training datset must be larger than batch size.')
if not osp.isdir(save_dir):
if osp.exists(save_dir):
os.remove(save_dir)
......
......@@ -30,7 +30,6 @@ def visualize(model, sensitivities_file, save_dir='./'):
import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt
program = model.test_prog
place = model.places[0]
fig = plt.figure()
......@@ -51,15 +50,21 @@ def visualize(model, sensitivities_file, save_dir='./'):
min(np.array(x)) - 0.01,
max(np.array(x)) + 0.01, 0.05)
my_y_ticks = np.arange(0.05, 1, 0.05)
plt.xticks(my_x_ticks, fontsize=3)
plt.yticks(my_y_ticks, fontsize=3)
plt.xticks(my_x_ticks, rotation=30, fontsize=8)
plt.yticks(my_y_ticks, fontsize=8)
for a, b in zip(x, y):
plt.text(
a,
b, (float('%0.4f' % a), float('%0.3f' % b)),
b, (float('%0.3f' % a), float('%0.3f' % b)),
ha='center',
va='bottom',
fontsize=3)
fontsize=8)
plt.rcParams['savefig.dpi'] = 120
plt.rcParams['figure.dpi'] = 150
suffix = osp.splitext(sensitivities_file)[-1]
plt.savefig('sensitivities.png', dpi=800)
plt.savefig(osp.join(save_dir, 'sensitivities.png'))
plt.close()
import pickle
coor = dict(zip(x, y))
output = open(osp.join(save_dir, 'sensitivities_xy.pkl'), 'wb')
pickle.dump(coor, output)
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册