提交 43046bfa 编写于 作者: S seven

modify save file name

上级 3fef7f9a
...@@ -107,7 +107,6 @@ class CAM(object): ...@@ -107,7 +107,6 @@ class CAM(object):
axes[1].set_title("CAM") axes[1].set_title("CAM")
if save_outdir is not None: if save_outdir is not None:
os.makedirs(save_outdir, exist_ok=True)
save_fig(data_, save_outdir, 'cam') save_fig(data_, save_outdir, 'cam')
if visualization: if visualization:
...@@ -219,7 +218,7 @@ class LIME(object): ...@@ -219,7 +218,7 @@ class LIME(object):
self.lime_interpreter, l, w) self.lime_interpreter, l, w)
temp, mask = self.lime_interpreter.get_image_and_mask( temp, mask = self.lime_interpreter.get_image_and_mask(
l, l,
positive_only=False, positive_only=True,
hide_rest=False, hide_rest=False,
num_features=num_to_show) num_features=num_to_show)
axes[ncols + i].imshow(mark_boundaries(temp, mask)) axes[ncols + i].imshow(mark_boundaries(temp, mask))
...@@ -227,7 +226,6 @@ class LIME(object): ...@@ -227,7 +226,6 @@ class LIME(object):
"label {}, first {} superpixels".format(ln, num_to_show)) "label {}, first {} superpixels".format(ln, num_to_show))
if save_outdir is not None: if save_outdir is not None:
os.makedirs(save_outdir, exist_ok=True)
save_fig(data_, save_outdir, 'lime', self.num_samples) save_fig(data_, save_outdir, 'lime', self.num_samples)
if visualization: if visualization:
...@@ -412,7 +410,6 @@ class NormLIMEStandard(object): ...@@ -412,7 +410,6 @@ class NormLIMEStandard(object):
self._lime.lime_interpreter.local_weights = lime_weights self._lime.lime_interpreter.local_weights = lime_weights
if save_outdir is not None: if save_outdir is not None:
os.makedirs(save_outdir, exist_ok=True)
save_fig(data_, save_outdir, 'normlime', self.num_samples) save_fig(data_, save_outdir, 'normlime', self.num_samples)
if visualization: if visualization:
...@@ -596,7 +593,6 @@ class NormLIME(object): ...@@ -596,7 +593,6 @@ class NormLIME(object):
self._lime.lime_interpreter.local_weights = lime_weights self._lime.lime_interpreter.local_weights = lime_weights
if save_outdir is not None: if save_outdir is not None:
os.makedirs(save_outdir, exist_ok=True)
save_fig(data_, save_outdir, 'normlime', self.num_samples) save_fig(data_, save_outdir, 'normlime', self.num_samples)
if visualization: if visualization:
...@@ -674,26 +670,11 @@ def get_cam(image_show, ...@@ -674,26 +670,11 @@ def get_cam(image_show,
def save_fig(data_, save_outdir, algorithm_name, num_samples=3000): def save_fig(data_, save_outdir, algorithm_name, num_samples=3000):
import matplotlib.pyplot as plt import matplotlib.pyplot as plt
if isinstance(data_, str): if algorithm_name == 'cam':
if algorithm_name == 'cam': f_out = "{}_{}.png".format(algorithm_name, data_.split('/')[-1])
f_out = "{}_{}.png".format(algorithm_name, data_.split('/')[-1])
else:
f_out = "{}_{}_s{}.png".format(algorithm_name,
data_.split('/')[-1], num_samples)
plt.savefig(os.path.join(save_outdir, f_out))
else: else:
n = 0 f_out = "{}_{}_s{}.png".format(save_outdir, algorithm_name,
if algorithm_name == 'cam': num_samples)
f_out = 'cam-{}.png'.format(n)
else: plt.savefig(f_out)
f_out = '{}_s{}-{}.png'.format(algorithm_name, num_samples, n) logging.info('The image of intrepretation result save in {}'.format(f_out))
while os.path.exists(os.path.join(save_outdir, f_out)):
n += 1
if algorithm_name == 'cam':
f_out = 'cam-{}.png'.format(n)
else:
f_out = '{}_s{}-{}.png'.format(algorithm_name, num_samples, n)
continue
plt.savefig(os.path.join(save_outdir, f_out))
logging.info('The image of intrepretation result save in {}'.format(
os.path.join(save_outdir, f_out)))
...@@ -58,7 +58,7 @@ def lime(img_file, model, num_samples=3000, batch_size=50, save_dir='./'): ...@@ -58,7 +58,7 @@ def lime(img_file, model, num_samples=3000, batch_size=50, save_dir='./'):
interpreter = get_lime_interpreter( interpreter = get_lime_interpreter(
img, model, num_samples=num_samples, batch_size=batch_size) img, model, num_samples=num_samples, batch_size=batch_size)
img_name = osp.splitext(osp.split(img_file)[-1])[0] img_name = osp.splitext(osp.split(img_file)[-1])[0]
interpreter.interpret(img, save_dir=save_dir) interpreter.interpret(img, save_dir=osp.join(save_dir, img_name))
def normlime(img_file, def normlime(img_file,
...@@ -111,7 +111,7 @@ def normlime(img_file, ...@@ -111,7 +111,7 @@ def normlime(img_file,
save_dir=save_dir, save_dir=save_dir,
normlime_weights_file=normlime_weights_file) normlime_weights_file=normlime_weights_file)
img_name = osp.splitext(osp.split(img_file)[-1])[0] img_name = osp.splitext(osp.split(img_file)[-1])[0]
interpreter.interpret(img, save_dir=save_dir) interpreter.interpret(img, save_dir=osp.join(save_dir, img_name))
def get_lime_interpreter(img, model, num_samples=3000, batch_size=50): def get_lime_interpreter(img, model, num_samples=3000, batch_size=50):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册