未验证 提交 d51704bd 编写于 作者: S SunAhong1993 提交者: GitHub

Update explanation_algorithms.py

上级 5dc2716d
......@@ -25,7 +25,7 @@ import cv2
class CAM(object):
def __init__(self, predict_fn):
def __init__(self, predict_fn, label_names):
"""
Args:
......@@ -37,6 +37,7 @@ class CAM(object):
"""
self.predict_fn = predict_fn
self.label_names = label_names
def preparation_cam(self, data_):
image_show = read_image(data_)
......@@ -61,8 +62,13 @@ class CAM(object):
fc_weights = paddle_get_fc_weights()
feature_maps = result[1]
l = pred_label[0]
ln = l
if self.label_names is not None:
ln = self.label_names[l]
print(f'predicted result: {pred_label[0]} with probability {probability[pred_label[0]]:.3f}')
print(f'predicted result: {ln} with probability {probability[pred_label[0]]:.3f}')
return feature_maps, fc_weights
def explain(self, data_, visualization=True, save_to_disk=True, save_outdir=None):
......@@ -73,6 +79,9 @@ class CAM(object):
import matplotlib.pyplot as plt
from skimage.segmentation import mark_boundaries
l = self.labels[0]
ln = l
if self.label_names is not None:
ln = self.label_names[l]
psize = 5
nrows = 1
......@@ -84,7 +93,7 @@ class CAM(object):
ax.axis("off")
axes = axes.ravel()
axes[0].imshow(self.image)
axes[0].set_title(f"label {l}, proba: {self.predicted_probability: .3f}")
axes[0].set_title(f"label {ln}, proba: {self.predicted_probability: .3f}")
axes[1].imshow(cam)
axes[1].set_title("CAM")
......@@ -100,7 +109,7 @@ class CAM(object):
class LIME(object):
def __init__(self, predict_fn, num_samples=3000, batch_size=50):
def __init__(self, predict_fn, label_names, num_samples=3000, batch_size=50):
"""
LIME wrapper. See lime_base.py for the detailed LIME implementation.
Args:
......@@ -115,6 +124,7 @@ class LIME(object):
self.labels = None
self.image = None
self.lime_explainer = None
self.label_names = label_names
def preparation_lime(self, data_):
image_show = read_image(data_)
......@@ -137,8 +147,13 @@ class LIME(object):
self.predicted_probability = probability[pred_label[0]]
self.image = image_show[0]
self.labels = pred_label
print(f'predicted result: {pred_label[0]} with probability {probability[pred_label[0]]:.3f}')
l = pred_label[0]
ln = l
if self.label_names is not None:
ln = self.label_names[l]
print(f'predicted result: {ln} with probability {probability[pred_label[0]]:.3f}')
end = time.time()
algo = lime_base.LimeImageExplainer()
......@@ -155,6 +170,9 @@ class LIME(object):
import matplotlib.pyplot as plt
from skimage.segmentation import mark_boundaries
l = self.labels[0]
ln = l
if self.label_names is not None:
ln = self.label_names[l]
psize = 5
nrows = 2
......@@ -167,7 +185,7 @@ class LIME(object):
ax.axis("off")
axes = axes.ravel()
axes[0].imshow(self.image)
axes[0].set_title(f"label {l}, proba: {self.predicted_probability: .3f}")
axes[0].set_title(f"label {ln}, proba: {self.predicted_probability: .3f}")
axes[1].imshow(mark_boundaries(self.image, self.lime_explainer.segments))
axes[1].set_title("superpixel segmentation")
......@@ -179,7 +197,7 @@ class LIME(object):
l, positive_only=False, hide_rest=False, num_features=num_to_show
)
axes[ncols + i].imshow(mark_boundaries(temp, mask))
axes[ncols + i].set_title(f"label {l}, first {num_to_show} superpixels")
axes[ncols + i].set_title(f"label {ln}, first {num_to_show} superpixels")
if save_to_disk and save_outdir is not None:
os.makedirs(save_outdir, exist_ok=True)
......@@ -192,7 +210,7 @@ class LIME(object):
class NormLIME(object):
def __init__(self, predict_fn, num_samples=3000, batch_size=50,
def __init__(self, predict_fn, label_names, num_samples=3000, batch_size=50,
kmeans_model_for_normlime=None, normlime_weights=None):
if kmeans_model_for_normlime is None:
try:
......@@ -218,6 +236,7 @@ class NormLIME(object):
self.labels = None
self.image = None
self.label_names = label_names
def predict_cluster_labels(self, feature_map, segments):
return self.kmeans_model.predict(get_feature_for_kmeans(feature_map, segments))
......@@ -239,6 +258,7 @@ class NormLIME(object):
def preparation_normlime(self, data_):
self._lime = LIME(
self.predict_fn,
self.label_names,
self.num_samples,
self.batch_size
)
......@@ -273,6 +293,9 @@ class NormLIME(object):
import matplotlib.pyplot as plt
from skimage.segmentation import mark_boundaries
l = self.labels[0]
ln = l
if self.label_names is not None:
ln = self.label_names[l]
psize = 5
nrows = 4
......@@ -287,7 +310,7 @@ class NormLIME(object):
axes = axes.ravel()
axes[0].imshow(self.image)
axes[0].set_title(f"label {l}, proba: {self.predicted_probability: .3f}")
axes[0].set_title(f"label {ln}, proba: {self.predicted_probability: .3f}")
axes[1].imshow(mark_boundaries(self.image, self._lime.lime_explainer.segments))
axes[1].set_title("superpixel segmentation")
......@@ -416,4 +439,4 @@ def save_fig(data_, save_outdir, algorithm_name, num_samples=3000):
os.path.join(
save_outdir, f_out
)
)
\ No newline at end of file
)
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册