From a7aa87a6aeec113d2fb19afc72270967c73efd17 Mon Sep 17 00:00:00 2001 From: seven Date: Tue, 2 Jun 2020 19:41:34 +0800 Subject: [PATCH] remove unused args --- paddlex/interpret/core/interpretation.py | 10 +---- .../core/interpretation_algorithms.py | 40 ++++++------------- paddlex/interpret/visualize.py | 2 +- 3 files changed, 15 insertions(+), 37 deletions(-) diff --git a/paddlex/interpret/core/interpretation.py b/paddlex/interpret/core/interpretation.py index 6000677..5b1a5e4 100644 --- a/paddlex/interpret/core/interpretation.py +++ b/paddlex/interpret/core/interpretation.py @@ -33,21 +33,15 @@ class Interpretation(object): self.algorithm = supported_algorithms[self.algorithm_name]( self.predict_fn, label_names, **kwargs) - def interpret(self, - data_, - visualization=True, - save_to_disk=True, - save_dir='./tmp'): + def interpret(self, data_, visualization=True, save_dir='./'): """ Args: data_: data_ can be a path or numpy.ndarray. visualization: whether to show using matplotlib. - save_to_disk: whether to save the figure in local disk. save_dir: dir to save figure if save_to_disk is True. Returns: """ - return self.algorithm.interpret(data_, visualization, save_to_disk, - save_dir) + return self.algorithm.interpret(data_, visualization, save_dir) diff --git a/paddlex/interpret/core/interpretation_algorithms.py b/paddlex/interpret/core/interpretation_algorithms.py index 3174428..476ece1 100644 --- a/paddlex/interpret/core/interpretation_algorithms.py +++ b/paddlex/interpret/core/interpretation_algorithms.py @@ -76,16 +76,12 @@ class CAM(object): ln, prob_str)) return feature_maps, fc_weights - def interpret(self, - data_, - visualization=True, - save_to_disk=True, - save_outdir=None): + def interpret(self, data_, visualization=True, save_outdir=None): feature_maps, fc_weights = self.preparation_cam(data_) cam = get_cam(self.image, feature_maps, fc_weights, self.predicted_label) - if visualization or save_to_disk: + if visualization or save_outdir is not None: import matplotlib.pyplot as plt from skimage.segmentation import mark_boundaries l = self.labels[0] @@ -110,7 +106,7 @@ class CAM(object): axes[1].imshow(cam) axes[1].set_title("CAM") - if save_to_disk and save_outdir is not None: + if save_outdir is not None: os.makedirs(save_outdir, exist_ok=True) save_fig(data_, save_outdir, 'cam') @@ -186,15 +182,11 @@ class LIME(object): self.lime_interpreter = interpreter logging.info('lime time: ' + str(time.time() - end) + 's.') - def interpret(self, - data_, - visualization=True, - save_to_disk=True, - save_outdir=None): + def interpret(self, data_, visualization=True, save_outdir=None): if self.lime_interpreter is None: self.preparation_lime(data_) - if visualization or save_to_disk: + if visualization or save_outdir is not None: import matplotlib.pyplot as plt from skimage.segmentation import mark_boundaries l = self.labels[0] @@ -234,7 +226,7 @@ class LIME(object): axes[ncols + i].set_title( "label {}, first {} superpixels".format(ln, num_to_show)) - if save_to_disk and save_outdir is not None: + if save_outdir is not None: os.makedirs(save_outdir, exist_ok=True) save_fig(data_, save_outdir, 'lime', self.num_samples) @@ -337,11 +329,7 @@ class NormLIMEStandard(object): return g_weights - def interpret(self, - data_, - visualization=True, - save_to_disk=True, - save_outdir=None): + def interpret(self, data_, visualization=True, save_outdir=None): if self.normlime_weights is None: raise ValueError( "Not find the correct precomputed NormLIME result. \n" @@ -351,7 +339,7 @@ class NormLIMEStandard(object): g_weights = self.preparation_normlime(data_) lime_weights = self._lime.lime_interpreter.local_weights - if visualization or save_to_disk: + if visualization or save_outdir is not None: import matplotlib.pyplot as plt from skimage.segmentation import mark_boundaries l = self.labels[0] @@ -423,7 +411,7 @@ class NormLIMEStandard(object): self._lime.lime_interpreter.local_weights = lime_weights - if save_to_disk and save_outdir is not None: + if save_outdir is not None: os.makedirs(save_outdir, exist_ok=True) save_fig(data_, save_outdir, 'normlime', self.num_samples) @@ -524,11 +512,7 @@ class NormLIME(object): return g_weights - def interpret(self, - data_, - visualization=True, - save_to_disk=True, - save_outdir=None): + def interpret(self, data_, visualization=True, save_outdir=None): if self.normlime_weights is None: raise ValueError( "Not find the correct precomputed NormLIME result. \n" @@ -538,7 +522,7 @@ class NormLIME(object): g_weights = self.preparation_normlime(data_) lime_weights = self._lime.lime_interpreter.local_weights - if visualization or save_to_disk: + if visualization or save_outdir is not None: import matplotlib.pyplot as plt from skimage.segmentation import mark_boundaries l = self.labels[0] @@ -611,7 +595,7 @@ class NormLIME(object): self._lime.lime_interpreter.local_weights = lime_weights - if save_to_disk and save_outdir is not None: + if save_outdir is not None: os.makedirs(save_outdir, exist_ok=True) save_fig(data_, save_outdir, 'normlime', self.num_samples) diff --git a/paddlex/interpret/visualize.py b/paddlex/interpret/visualize.py index 084c377..d6785c6 100644 --- a/paddlex/interpret/visualize.py +++ b/paddlex/interpret/visualize.py @@ -168,7 +168,7 @@ def get_normlime_interpreter(img, normlime_weights_file = precompute_global_classifier( dataset, predict_func, - save_path=normlime_weights_file, + save_path=osp.join(save_dir, normlime_weights_file), batch_size=batch_size) interpreter = Interpretation( -- GitLab