提交 a7aa87a6 编写于 作者: S seven

remove unused args

上级 857300d4
...@@ -33,21 +33,15 @@ class Interpretation(object): ...@@ -33,21 +33,15 @@ class Interpretation(object):
self.algorithm = supported_algorithms[self.algorithm_name]( self.algorithm = supported_algorithms[self.algorithm_name](
self.predict_fn, label_names, **kwargs) self.predict_fn, label_names, **kwargs)
def interpret(self, def interpret(self, data_, visualization=True, save_dir='./'):
data_,
visualization=True,
save_to_disk=True,
save_dir='./tmp'):
""" """
Args: Args:
data_: data_ can be a path or numpy.ndarray. data_: data_ can be a path or numpy.ndarray.
visualization: whether to show using matplotlib. visualization: whether to show using matplotlib.
save_to_disk: whether to save the figure in local disk.
save_dir: dir to save figure if save_to_disk is True. save_dir: dir to save figure if save_to_disk is True.
Returns: Returns:
""" """
return self.algorithm.interpret(data_, visualization, save_to_disk, return self.algorithm.interpret(data_, visualization, save_dir)
save_dir)
...@@ -76,16 +76,12 @@ class CAM(object): ...@@ -76,16 +76,12 @@ class CAM(object):
ln, prob_str)) ln, prob_str))
return feature_maps, fc_weights return feature_maps, fc_weights
def interpret(self, def interpret(self, data_, visualization=True, save_outdir=None):
data_,
visualization=True,
save_to_disk=True,
save_outdir=None):
feature_maps, fc_weights = self.preparation_cam(data_) feature_maps, fc_weights = self.preparation_cam(data_)
cam = get_cam(self.image, feature_maps, fc_weights, cam = get_cam(self.image, feature_maps, fc_weights,
self.predicted_label) self.predicted_label)
if visualization or save_to_disk: if visualization or save_outdir is not None:
import matplotlib.pyplot as plt import matplotlib.pyplot as plt
from skimage.segmentation import mark_boundaries from skimage.segmentation import mark_boundaries
l = self.labels[0] l = self.labels[0]
...@@ -110,7 +106,7 @@ class CAM(object): ...@@ -110,7 +106,7 @@ class CAM(object):
axes[1].imshow(cam) axes[1].imshow(cam)
axes[1].set_title("CAM") axes[1].set_title("CAM")
if save_to_disk and save_outdir is not None: if save_outdir is not None:
os.makedirs(save_outdir, exist_ok=True) os.makedirs(save_outdir, exist_ok=True)
save_fig(data_, save_outdir, 'cam') save_fig(data_, save_outdir, 'cam')
...@@ -186,15 +182,11 @@ class LIME(object): ...@@ -186,15 +182,11 @@ class LIME(object):
self.lime_interpreter = interpreter self.lime_interpreter = interpreter
logging.info('lime time: ' + str(time.time() - end) + 's.') logging.info('lime time: ' + str(time.time() - end) + 's.')
def interpret(self, def interpret(self, data_, visualization=True, save_outdir=None):
data_,
visualization=True,
save_to_disk=True,
save_outdir=None):
if self.lime_interpreter is None: if self.lime_interpreter is None:
self.preparation_lime(data_) self.preparation_lime(data_)
if visualization or save_to_disk: if visualization or save_outdir is not None:
import matplotlib.pyplot as plt import matplotlib.pyplot as plt
from skimage.segmentation import mark_boundaries from skimage.segmentation import mark_boundaries
l = self.labels[0] l = self.labels[0]
...@@ -234,7 +226,7 @@ class LIME(object): ...@@ -234,7 +226,7 @@ class LIME(object):
axes[ncols + i].set_title( axes[ncols + i].set_title(
"label {}, first {} superpixels".format(ln, num_to_show)) "label {}, first {} superpixels".format(ln, num_to_show))
if save_to_disk and save_outdir is not None: if save_outdir is not None:
os.makedirs(save_outdir, exist_ok=True) os.makedirs(save_outdir, exist_ok=True)
save_fig(data_, save_outdir, 'lime', self.num_samples) save_fig(data_, save_outdir, 'lime', self.num_samples)
...@@ -337,11 +329,7 @@ class NormLIMEStandard(object): ...@@ -337,11 +329,7 @@ class NormLIMEStandard(object):
return g_weights return g_weights
def interpret(self, def interpret(self, data_, visualization=True, save_outdir=None):
data_,
visualization=True,
save_to_disk=True,
save_outdir=None):
if self.normlime_weights is None: if self.normlime_weights is None:
raise ValueError( raise ValueError(
"Not find the correct precomputed NormLIME result. \n" "Not find the correct precomputed NormLIME result. \n"
...@@ -351,7 +339,7 @@ class NormLIMEStandard(object): ...@@ -351,7 +339,7 @@ class NormLIMEStandard(object):
g_weights = self.preparation_normlime(data_) g_weights = self.preparation_normlime(data_)
lime_weights = self._lime.lime_interpreter.local_weights lime_weights = self._lime.lime_interpreter.local_weights
if visualization or save_to_disk: if visualization or save_outdir is not None:
import matplotlib.pyplot as plt import matplotlib.pyplot as plt
from skimage.segmentation import mark_boundaries from skimage.segmentation import mark_boundaries
l = self.labels[0] l = self.labels[0]
...@@ -423,7 +411,7 @@ class NormLIMEStandard(object): ...@@ -423,7 +411,7 @@ class NormLIMEStandard(object):
self._lime.lime_interpreter.local_weights = lime_weights self._lime.lime_interpreter.local_weights = lime_weights
if save_to_disk and save_outdir is not None: if save_outdir is not None:
os.makedirs(save_outdir, exist_ok=True) os.makedirs(save_outdir, exist_ok=True)
save_fig(data_, save_outdir, 'normlime', self.num_samples) save_fig(data_, save_outdir, 'normlime', self.num_samples)
...@@ -524,11 +512,7 @@ class NormLIME(object): ...@@ -524,11 +512,7 @@ class NormLIME(object):
return g_weights return g_weights
def interpret(self, def interpret(self, data_, visualization=True, save_outdir=None):
data_,
visualization=True,
save_to_disk=True,
save_outdir=None):
if self.normlime_weights is None: if self.normlime_weights is None:
raise ValueError( raise ValueError(
"Not find the correct precomputed NormLIME result. \n" "Not find the correct precomputed NormLIME result. \n"
...@@ -538,7 +522,7 @@ class NormLIME(object): ...@@ -538,7 +522,7 @@ class NormLIME(object):
g_weights = self.preparation_normlime(data_) g_weights = self.preparation_normlime(data_)
lime_weights = self._lime.lime_interpreter.local_weights lime_weights = self._lime.lime_interpreter.local_weights
if visualization or save_to_disk: if visualization or save_outdir is not None:
import matplotlib.pyplot as plt import matplotlib.pyplot as plt
from skimage.segmentation import mark_boundaries from skimage.segmentation import mark_boundaries
l = self.labels[0] l = self.labels[0]
...@@ -611,7 +595,7 @@ class NormLIME(object): ...@@ -611,7 +595,7 @@ class NormLIME(object):
self._lime.lime_interpreter.local_weights = lime_weights self._lime.lime_interpreter.local_weights = lime_weights
if save_to_disk and save_outdir is not None: if save_outdir is not None:
os.makedirs(save_outdir, exist_ok=True) os.makedirs(save_outdir, exist_ok=True)
save_fig(data_, save_outdir, 'normlime', self.num_samples) save_fig(data_, save_outdir, 'normlime', self.num_samples)
......
...@@ -168,7 +168,7 @@ def get_normlime_interpreter(img, ...@@ -168,7 +168,7 @@ def get_normlime_interpreter(img,
normlime_weights_file = precompute_global_classifier( normlime_weights_file = precompute_global_classifier(
dataset, dataset,
predict_func, predict_func,
save_path=normlime_weights_file, save_path=osp.join(save_dir, normlime_weights_file),
batch_size=batch_size) batch_size=batch_size)
interpreter = Interpretation( interpreter = Interpretation(
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册