未验证 提交 d51704bd 编写于 作者: S SunAhong1993 提交者: GitHub

Update explanation_algorithms.py

上级 5dc2716d
...@@ -25,7 +25,7 @@ import cv2 ...@@ -25,7 +25,7 @@ import cv2
class CAM(object): class CAM(object):
def __init__(self, predict_fn): def __init__(self, predict_fn, label_names):
""" """
Args: Args:
...@@ -37,6 +37,7 @@ class CAM(object): ...@@ -37,6 +37,7 @@ class CAM(object):
""" """
self.predict_fn = predict_fn self.predict_fn = predict_fn
self.label_names = label_names
def preparation_cam(self, data_): def preparation_cam(self, data_):
image_show = read_image(data_) image_show = read_image(data_)
...@@ -61,8 +62,13 @@ class CAM(object): ...@@ -61,8 +62,13 @@ class CAM(object):
fc_weights = paddle_get_fc_weights() fc_weights = paddle_get_fc_weights()
feature_maps = result[1] feature_maps = result[1]
l = pred_label[0]
ln = l
if self.label_names is not None:
ln = self.label_names[l]
print(f'predicted result: {pred_label[0]} with probability {probability[pred_label[0]]:.3f}') print(f'predicted result: {ln} with probability {probability[pred_label[0]]:.3f}')
return feature_maps, fc_weights return feature_maps, fc_weights
def explain(self, data_, visualization=True, save_to_disk=True, save_outdir=None): def explain(self, data_, visualization=True, save_to_disk=True, save_outdir=None):
...@@ -73,6 +79,9 @@ class CAM(object): ...@@ -73,6 +79,9 @@ class CAM(object):
import matplotlib.pyplot as plt import matplotlib.pyplot as plt
from skimage.segmentation import mark_boundaries from skimage.segmentation import mark_boundaries
l = self.labels[0] l = self.labels[0]
ln = l
if self.label_names is not None:
ln = self.label_names[l]
psize = 5 psize = 5
nrows = 1 nrows = 1
...@@ -84,7 +93,7 @@ class CAM(object): ...@@ -84,7 +93,7 @@ class CAM(object):
ax.axis("off") ax.axis("off")
axes = axes.ravel() axes = axes.ravel()
axes[0].imshow(self.image) axes[0].imshow(self.image)
axes[0].set_title(f"label {l}, proba: {self.predicted_probability: .3f}") axes[0].set_title(f"label {ln}, proba: {self.predicted_probability: .3f}")
axes[1].imshow(cam) axes[1].imshow(cam)
axes[1].set_title("CAM") axes[1].set_title("CAM")
...@@ -100,7 +109,7 @@ class CAM(object): ...@@ -100,7 +109,7 @@ class CAM(object):
class LIME(object): class LIME(object):
def __init__(self, predict_fn, num_samples=3000, batch_size=50): def __init__(self, predict_fn, label_names, num_samples=3000, batch_size=50):
""" """
LIME wrapper. See lime_base.py for the detailed LIME implementation. LIME wrapper. See lime_base.py for the detailed LIME implementation.
Args: Args:
...@@ -115,6 +124,7 @@ class LIME(object): ...@@ -115,6 +124,7 @@ class LIME(object):
self.labels = None self.labels = None
self.image = None self.image = None
self.lime_explainer = None self.lime_explainer = None
self.label_names = label_names
def preparation_lime(self, data_): def preparation_lime(self, data_):
image_show = read_image(data_) image_show = read_image(data_)
...@@ -137,8 +147,13 @@ class LIME(object): ...@@ -137,8 +147,13 @@ class LIME(object):
self.predicted_probability = probability[pred_label[0]] self.predicted_probability = probability[pred_label[0]]
self.image = image_show[0] self.image = image_show[0]
self.labels = pred_label self.labels = pred_label
print(f'predicted result: {pred_label[0]} with probability {probability[pred_label[0]]:.3f}') l = pred_label[0]
ln = l
if self.label_names is not None:
ln = self.label_names[l]
print(f'predicted result: {ln} with probability {probability[pred_label[0]]:.3f}')
end = time.time() end = time.time()
algo = lime_base.LimeImageExplainer() algo = lime_base.LimeImageExplainer()
...@@ -155,6 +170,9 @@ class LIME(object): ...@@ -155,6 +170,9 @@ class LIME(object):
import matplotlib.pyplot as plt import matplotlib.pyplot as plt
from skimage.segmentation import mark_boundaries from skimage.segmentation import mark_boundaries
l = self.labels[0] l = self.labels[0]
ln = l
if self.label_names is not None:
ln = self.label_names[l]
psize = 5 psize = 5
nrows = 2 nrows = 2
...@@ -167,7 +185,7 @@ class LIME(object): ...@@ -167,7 +185,7 @@ class LIME(object):
ax.axis("off") ax.axis("off")
axes = axes.ravel() axes = axes.ravel()
axes[0].imshow(self.image) axes[0].imshow(self.image)
axes[0].set_title(f"label {l}, proba: {self.predicted_probability: .3f}") axes[0].set_title(f"label {ln}, proba: {self.predicted_probability: .3f}")
axes[1].imshow(mark_boundaries(self.image, self.lime_explainer.segments)) axes[1].imshow(mark_boundaries(self.image, self.lime_explainer.segments))
axes[1].set_title("superpixel segmentation") axes[1].set_title("superpixel segmentation")
...@@ -179,7 +197,7 @@ class LIME(object): ...@@ -179,7 +197,7 @@ class LIME(object):
l, positive_only=False, hide_rest=False, num_features=num_to_show l, positive_only=False, hide_rest=False, num_features=num_to_show
) )
axes[ncols + i].imshow(mark_boundaries(temp, mask)) axes[ncols + i].imshow(mark_boundaries(temp, mask))
axes[ncols + i].set_title(f"label {l}, first {num_to_show} superpixels") axes[ncols + i].set_title(f"label {ln}, first {num_to_show} superpixels")
if save_to_disk and save_outdir is not None: if save_to_disk and save_outdir is not None:
os.makedirs(save_outdir, exist_ok=True) os.makedirs(save_outdir, exist_ok=True)
...@@ -192,7 +210,7 @@ class LIME(object): ...@@ -192,7 +210,7 @@ class LIME(object):
class NormLIME(object): class NormLIME(object):
def __init__(self, predict_fn, num_samples=3000, batch_size=50, def __init__(self, predict_fn, label_names, num_samples=3000, batch_size=50,
kmeans_model_for_normlime=None, normlime_weights=None): kmeans_model_for_normlime=None, normlime_weights=None):
if kmeans_model_for_normlime is None: if kmeans_model_for_normlime is None:
try: try:
...@@ -218,6 +236,7 @@ class NormLIME(object): ...@@ -218,6 +236,7 @@ class NormLIME(object):
self.labels = None self.labels = None
self.image = None self.image = None
self.label_names = label_names
def predict_cluster_labels(self, feature_map, segments): def predict_cluster_labels(self, feature_map, segments):
return self.kmeans_model.predict(get_feature_for_kmeans(feature_map, segments)) return self.kmeans_model.predict(get_feature_for_kmeans(feature_map, segments))
...@@ -239,6 +258,7 @@ class NormLIME(object): ...@@ -239,6 +258,7 @@ class NormLIME(object):
def preparation_normlime(self, data_): def preparation_normlime(self, data_):
self._lime = LIME( self._lime = LIME(
self.predict_fn, self.predict_fn,
self.label_names,
self.num_samples, self.num_samples,
self.batch_size self.batch_size
) )
...@@ -273,6 +293,9 @@ class NormLIME(object): ...@@ -273,6 +293,9 @@ class NormLIME(object):
import matplotlib.pyplot as plt import matplotlib.pyplot as plt
from skimage.segmentation import mark_boundaries from skimage.segmentation import mark_boundaries
l = self.labels[0] l = self.labels[0]
ln = l
if self.label_names is not None:
ln = self.label_names[l]
psize = 5 psize = 5
nrows = 4 nrows = 4
...@@ -287,7 +310,7 @@ class NormLIME(object): ...@@ -287,7 +310,7 @@ class NormLIME(object):
axes = axes.ravel() axes = axes.ravel()
axes[0].imshow(self.image) axes[0].imshow(self.image)
axes[0].set_title(f"label {l}, proba: {self.predicted_probability: .3f}") axes[0].set_title(f"label {ln}, proba: {self.predicted_probability: .3f}")
axes[1].imshow(mark_boundaries(self.image, self._lime.lime_explainer.segments)) axes[1].imshow(mark_boundaries(self.image, self._lime.lime_explainer.segments))
axes[1].set_title("superpixel segmentation") axes[1].set_title("superpixel segmentation")
...@@ -416,4 +439,4 @@ def save_fig(data_, save_outdir, algorithm_name, num_samples=3000): ...@@ -416,4 +439,4 @@ def save_fig(data_, save_outdir, algorithm_name, num_samples=3000):
os.path.join( os.path.join(
save_outdir, f_out save_outdir, f_out
) )
) )
\ No newline at end of file
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册