提交 011cff21 编写于 作者: S sunyanfang01

add lime

上级 2484756a
......@@ -27,7 +27,6 @@ from .base import BaseAPI
class BaseClassifier(BaseAPI):
"""构建分类器,并实现其训练、评估、预测和模型导出。
Args:
model_name (str): 分类器的模型名字,取值范围为['ResNet18',
'ResNet34', 'ResNet50', 'ResNet101',
......@@ -61,10 +60,10 @@ class BaseClassifier(BaseAPI):
if mode != 'test':
label = fluid.data(dtype='int64', shape=[None, 1], name='label')
model = getattr(paddlex.cv.nets, str.lower(self.model_name))
net_out = model(image, num_classes=self.num_classes)
net_out, feat = model(image, num_classes=self.num_classes)
softmax_out = fluid.layers.softmax(net_out, use_cudnn=False)
inputs = OrderedDict([('image', image)])
outputs = OrderedDict([('predict', softmax_out)])
outputs = OrderedDict([('predict', softmax_out), ('net_out', feat[-1])])
if mode != 'test':
cost = fluid.layers.cross_entropy(input=softmax_out, label=label)
avg_cost = fluid.layers.mean(cost)
......@@ -115,7 +114,6 @@ class BaseClassifier(BaseAPI):
early_stop_patience=5,
resume_checkpoint=None):
"""训练。
Args:
num_epochs (int): 训练迭代轮数。
train_dataset (paddlex.datasets): 训练数据读取器。
......@@ -139,7 +137,6 @@ class BaseClassifier(BaseAPI):
early_stop_patience (int): 当使用提前终止训练策略时,如果验证集精度在`early_stop_patience`个epoch内
连续下降或持平,则终止训练。默认值为5。
resume_checkpoint (str): 恢复训练时指定上次训练保存的模型路径。若为None,则不会恢复训练。默认值为None。
Raises:
ValueError: 模型从inference model进行加载。
"""
......@@ -183,13 +180,11 @@ class BaseClassifier(BaseAPI):
epoch_id=None,
return_details=False):
"""评估。
Args:
eval_dataset (paddlex.datasets): 验证数据读取器。
batch_size (int): 验证数据批大小。默认为1。
epoch_id (int): 当前评估模型所在的训练轮数。
return_details (bool): 是否返回详细信息。
Returns:
dict: 当return_details为False时,返回dict, 包含关键字:'acc1'、'acc5',
分别表示最大值的accuracy、前5个最大值的accuracy。
......@@ -248,12 +243,10 @@ class BaseClassifier(BaseAPI):
def predict(self, img_file, transforms=None, topk=1):
"""预测。
Args:
img_file (str): 预测图像路径。
transforms (paddlex.cls.transforms): 数据预处理操作。
topk (int): 预测时前k个最大值。
Returns:
list: 其中元素均为字典。字典的关键字为'category_id'、'category'、'score',
分别对应预测类别id、预测类别标签、预测得分。
......@@ -279,7 +272,20 @@ class BaseClassifier(BaseAPI):
'score': result[0][0][l]
} for l in pred_label]
return res
def explanation_predict(self, images):
self.arrange_transforms(
transforms=self.test_transforms, mode='test')
new_imgs = []
for i in range(images.shape[0]):
img = images[i]
new_imgs.append(self.test_transforms(img)[0])
new_imgs = np.array(new_imgs)
result = self.exe.run(
self.test_prog,
feed={'image': new_imgs},
fetch_list=list(self.test_outputs.values()))
return result[1:]
class ResNet18(BaseClassifier):
def __init__(self, num_classes=1000):
......
import os
def imagenet_val_files_and_labels(dataset_directory):
classes = open(os.path.join(dataset_directory, 'imagenet_lsvrc_2015_synsets.txt')).readlines()
class_to_indx = {classes[i].split('\n')[0]: i for i in range(len(classes))}
images_path = os.path.join(dataset_directory, 'val')
filenames = []
labels = []
lines = open(os.path.join(dataset_directory, 'imagenet_2012_validation_synset_labels.txt'), 'r').readlines()
for i, line in enumerate(lines):
class_name = line.split('\n')[0]
a = 'ILSVRC2012_val_%08d.JPEG' % (i + 1)
filenames.append(f'{images_path}/{a}')
labels.append(class_to_indx[class_name])
# print(filenames[-1], labels[-1])
return filenames, labels
def _find_classes(dir):
# Faster and available in Python 3.5 and above
classes = [d.name for d in os.scandir(dir) if d.is_dir()]
classes.sort()
class_to_idx = {classes[i]: i for i in range(len(classes))}
return classes, class_to_idx
\ No newline at end of file
import os
import sys; sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
import cv2
import numpy as np
import six
import glob
from as_data_reader.data_path_utils import _find_classes
from PIL import Image
def resize_short(img, target_size, interpolation=None):
"""resize image
Args:
img: image data
target_size: resize short target size
interpolation: interpolation mode
Returns:
resized image data
"""
percent = float(target_size) / min(img.shape[0], img.shape[1])
resized_width = int(round(img.shape[1] * percent))
resized_height = int(round(img.shape[0] * percent))
if interpolation:
resized = cv2.resize(
img, (resized_width, resized_height), interpolation=interpolation)
else:
resized = cv2.resize(img, (resized_width, resized_height))
return resized
def crop_image(img, target_size, center=True):
"""crop image
Args:
img: images data
target_size: crop target size
center: crop mode
Returns:
img: cropped image data
"""
height, width = img.shape[:2]
size = target_size
if center:
w_start = (width - size) // 2
h_start = (height - size) // 2
else:
w_start = np.random.randint(0, width - size + 1)
h_start = np.random.randint(0, height - size + 1)
w_end = w_start + size
h_end = h_start + size
img = img[h_start:h_end, w_start:w_end, :]
return img
def preprocess_image(img, random_mirror=False):
"""
centered, scaled by 1/255.
:param img: np.array: shape: [ns, h, w, 3], color order: rgb.
:return: np.array: shape: [ns, h, w, 3]
"""
mean = [0.485, 0.456, 0.406]
std = [0.229, 0.224, 0.225]
# transpose to [ns, 3, h, w]
img = img.astype('float32').transpose((0, 3, 1, 2)) / 255
img_mean = np.array(mean).reshape((3, 1, 1))
img_std = np.array(std).reshape((3, 1, 1))
img -= img_mean
img /= img_std
if random_mirror:
mirror = int(np.random.uniform(0, 2))
if mirror == 1:
img = img[:, :, ::-1, :]
return img
def read_image(img_path, target_size=256, crop_size=224):
"""
resize_short to 256, then center crop to 224.
:param img_path: one image path
:return: np.array: shape: [1, h, w, 3], color order: rgb.
"""
if isinstance(img_path, str):
with open(img_path, 'rb') as f:
img = Image.open(f)
img = img.convert('RGB')
img = np.array(img)
# img = cv2.imread(img_path)
img = resize_short(img, target_size, interpolation=None)
img = crop_image(img, target_size=crop_size, center=True)
# img = img[:, :, ::-1]
img = np.expand_dims(img, axis=0)
return img
elif isinstance(img_path, np.ndarray):
assert len(img_path.shape) == 4
return img_path
else:
ValueError(f"Not recognized data type {type(img_path)}.")
class ReaderConfig(object):
"""
A generic data loader where the images are arranged in this way:
root/train/dog/xxy.jpg
root/train/dog/xxz.jpg
...
root/train/cat/nsdf3.jpg
root/train/cat/asd932_.jpg
...
root/test/dog/xxx.jpg
...
root/test/cat/123.jpg
...
"""
def __init__(self, dataset_dir, is_test):
image_paths, labels, self.num_classes = self.get_dataset_info(dataset_dir, is_test)
random_per = np.random.permutation(range(len(image_paths)))
self.image_paths = image_paths[random_per]
self.labels = labels[random_per]
self.is_test = is_test
def get_reader(self):
def reader():
IMG_EXTENSIONS = ('.jpg', '.jpeg', '.png', '.ppm', '.bmp', '.pgm', '.tif', '.tiff', '.webp')
target_size = 256
crop_size = 224
for i, img_path in enumerate(self.image_paths):
if not img_path.lower().endswith(IMG_EXTENSIONS):
continue
img = cv2.imread(img_path)
if img is None:
print(img_path)
continue
img = resize_short(img, target_size, interpolation=None)
img = crop_image(img, crop_size, center=self.is_test)
img = img[:, :, ::-1]
img = np.expand_dims(img, axis=0)
img = preprocess_image(img, not self.is_test)
yield img, self.labels[i]
return reader
def get_dataset_info(self, dataset_dir, is_test=False):
IMG_EXTENSIONS = ('.jpg', '.jpeg', '.png', '.ppm', '.bmp', '.pgm', '.tif', '.tiff', '.webp')
# read
if is_test:
datasubset_dir = os.path.join(dataset_dir, 'test')
else:
datasubset_dir = os.path.join(dataset_dir, 'train')
class_names, class_to_idx = _find_classes(datasubset_dir)
# num_classes = len(class_names)
image_paths = []
labels = []
for class_name in class_names:
classes_dir = os.path.join(datasubset_dir, class_name)
for img_path in glob.glob(os.path.join(classes_dir, '*')):
if not img_path.lower().endswith(IMG_EXTENSIONS):
continue
image_paths.append(img_path)
labels.append(class_to_idx[class_name])
image_paths = np.array(image_paths)
labels = np.array(labels)
return image_paths, labels, len(class_names)
def create_reader(list_image_path, list_label=None, is_test=False):
def reader():
IMG_EXTENSIONS = ('.jpg', '.jpeg', '.png', '.ppm', '.bmp', '.pgm', '.tif', '.tiff', '.webp')
target_size = 256
crop_size = 224
for i, img_path in enumerate(list_image_path):
if not img_path.lower().endswith(IMG_EXTENSIONS):
continue
img = cv2.imread(img_path)
if img is None:
print(img_path)
continue
img = resize_short(img, target_size, interpolation=None)
img = crop_image(img, crop_size, center=is_test)
img = img[:, :, ::-1]
img_show = np.expand_dims(img, axis=0)
img = preprocess_image(img_show, not is_test)
label = 0 if list_label is None else list_label[i]
yield img_show, img, label
return reader
\ No newline at end of file
import os
import paddle.fluid as fluid
import numpy as np
def paddle_get_fc_weights(var_name="fc_0.w_0"):
fc_weights = fluid.global_scope().find_var(var_name).get_tensor()
return np.array(fc_weights)
def paddle_resize(extracted_features, outsize):
resized_features = fluid.layers.resize_bilinear(extracted_features, outsize)
return resized_features
\ No newline at end of file
from .explanation_algorithms import CAM, LIME, NormLIME
class Explanation(object):
"""
Base class for all explanation algorithms.
"""
def __init__(self, explanation_algorithm_name, predict_fn, **kwargs):
supported_algorithms = {
'cam': CAM,
'lime': LIME,
'normlime': NormLIME
}
self.algorithm_name = explanation_algorithm_name.lower()
assert self.algorithm_name in supported_algorithms.keys()
self.predict_fn = predict_fn
# initialization for the explanation algorithm.
self.explain_algorithm = supported_algorithms[self.algorithm_name](
self.predict_fn, **kwargs
)
def explain(self, data_, visualization=True, save_to_disk=True, save_dir='./tmp'):
"""
Args:
data_: data_ can be a path or numpy.ndarray.
visualization: whether to show using matplotlib.
save_to_disk: whether to save the figure in local disk.
save_dir: dir to save figure if save_to_disk is True.
Returns:
"""
return self.explain_algorithm.explain(data_, visualization, save_to_disk, save_dir)
import os
import numpy as np
import time
from . import lime_base
from ..as_data_reader.readers import read_image
from ._session_preparation import paddle_get_fc_weights
import cv2
class CAM(object):
def __init__(self, predict_fn):
"""
Args:
predict_fn: input: images_show [N, H, W, 3], RGB range(0, 255)
output: [
logits [N, num_classes],
feature map before global average pooling [N, num_channels, h_, w_]
]
"""
self.predict_fn = predict_fn
def preparation_cam(self, data_path):
image_show = read_image(data_path)
result = self.predict_fn(image_show)
logit = result[0][0]
if abs(np.sum(logit) - 1.0) > 1e-4:
# softmax
exp_result = np.exp(logit)
probability = exp_result / np.sum(exp_result)
else:
probability = logit
# only explain top 1
pred_label = np.argsort(probability)
pred_label = pred_label[-1:]
self.predicted_label = pred_label[0]
self.predicted_probability = probability[pred_label[0]]
self.image = image_show[0]
self.labels = pred_label
fc_weights = paddle_get_fc_weights()
feature_maps = result[1]
print('predicted result: ', pred_label[0], probability[pred_label[0]])
return feature_maps, fc_weights
def explain(self, data_, visualization=True, save_to_disk=True, save_outdir=None):
feature_maps, fc_weights = self.preparation_cam(data_)
cam = get_cam(self.image, feature_maps, fc_weights, self.predicted_label)
if visualization or save_to_disk:
import matplotlib.pyplot as plt
from skimage.segmentation import mark_boundaries
l = self.labels[0]
psize = 5
nrows = 1
ncols = 2
plt.close()
f, axes = plt.subplots(nrows, ncols, figsize=(psize * ncols, psize * nrows))
for ax in axes.ravel():
ax.axis("off")
axes = axes.ravel()
axes[0].imshow(self.image)
axes[0].set_title(f"label {l}, proba: {self.predicted_probability: .3f}")
axes[1].imshow(cam)
axes[1].set_title("CAM")
if save_to_disk and save_outdir is not None:
os.makedirs(save_outdir, exist_ok=True)
save_fig(data_, save_outdir, 'cam')
if visualization:
plt.show()
return
class LIME(object):
def __init__(self, predict_fn, num_samples=3000, batch_size=50):
"""
LIME wrapper. See lime_base.py for the detailed LIME implementation.
Args:
predict_fn: from image [N, H, W, 3] to logits [N, num_classes], this is necessary for computing LIME.
num_samples: the number of samples that LIME takes for fitting.
batch_size: batch size for model inference each time.
"""
self.num_samples = num_samples
self.batch_size = batch_size
self.predict_fn = predict_fn
self.labels = None
self.image = None
self.lime_explainer = None
def preparation_lime(self, data_path):
image_show = read_image(data_path)
result = self.predict_fn(image_show)
result = result[0] # only one image here.
if abs(np.sum(result) - 1.0) > 1e-4:
# softmax
exp_result = np.exp(result)
probability = exp_result / np.sum(exp_result)
else:
probability = result
# only explain top 1
pred_label = np.argsort(probability)
pred_label = pred_label[-1:]
self.predicted_label = pred_label[0]
self.predicted_probability = probability[pred_label[0]]
self.image = image_show[0]
self.labels = pred_label
print(f'predicted result: {pred_label[0]} with probability {probability[pred_label[0]]: .3f}')
end = time.time()
algo = lime_base.LimeImageExplainer()
explainer = algo.explain_instance(self.image, self.predict_fn, self.labels, 0,
num_samples=self.num_samples, batch_size=self.batch_size)
self.lime_explainer = explainer
print('lime time: ', time.time() - end, 's.')
def explain(self, data_, visualization=True, save_to_disk=True, save_outdir=None):
if self.lime_explainer is None:
self.preparation_lime(data_)
if visualization or save_to_disk:
import matplotlib.pyplot as plt
from skimage.segmentation import mark_boundaries
l = self.labels[0]
psize = 5
nrows = 2
weights_choices = [0.6, 0.75, 0.85]
ncols = len(weights_choices)
plt.close()
f, axes = plt.subplots(nrows, ncols, figsize=(psize * ncols, psize * nrows))
for ax in axes.ravel():
ax.axis("off")
axes = axes.ravel()
axes[0].imshow(self.image)
axes[0].set_title(f"label {l}, proba: {self.predicted_probability: .3f}")
axes[1].imshow(mark_boundaries(self.image, self.lime_explainer.segments))
axes[1].set_title("superpixel segmentation")
# LIME visualization
for i, w in enumerate(weights_choices):
num_to_show = auto_choose_num_features_to_show(self.lime_explainer, l, w)
temp, mask = self.lime_explainer.get_image_and_mask(
l, positive_only=False, hide_rest=False, num_features=num_to_show
)
axes[ncols + i].imshow(mark_boundaries(temp, mask))
axes[ncols + i].set_title(f"label {l}, first {num_to_show} superpixels")
if save_to_disk and save_outdir is not None:
os.makedirs(save_outdir, exist_ok=True)
save_fig(data_, save_outdir, 'lime', self.num_samples)
if visualization:
plt.show()
return
class NormLIME(object):
def __init__(self, predict_fn, num_samples=3000, batch_size=50,
kmeans_model_for_normlime=None, normlime_weights=None):
assert kmeans_model_for_normlime is not None, "NormLIME needs the KMeans model."
if normlime_weights is None:
raise NotImplementedError("Computing NormLIME weights is not implemented yet.")
self.num_samples = num_samples
self.batch_size = batch_size
self.kmeans_model = load_kmeans_model(kmeans_model_for_normlime)
self.normlime_weights = np.load(normlime_weights, allow_pickle=True).item()
self.predict_fn = predict_fn
self.labels = None
self.image = None
def predict_cluster_labels(self, feature_map, segments):
return self.kmeans_model.predict(get_feature_for_kmeans(feature_map, segments))
def predict_using_normlime_weights(self, pred_labels, predicted_cluster_labels):
# global weights
g_weights = {y: [] for y in pred_labels}
for y in pred_labels:
cluster_weights_y = self.normlime_weights[y]
g_weights[y] = [
# some are not in the dict, 3000 samples may be not enough.
(i, cluster_weights_y.get(k, 0.0)) for i, k in enumerate(predicted_cluster_labels)
]
g_weights[y] = sorted(g_weights[y],
key=lambda x: np.abs(x[1]), reverse=True)
return g_weights
def preparation_normlime(self, data_path):
self._lime = LIME(
lambda images: self.predict_fn(images)[0],
self.num_samples,
self.batch_size
)
self._lime.preparation_lime(data_path)
image_show = read_image(data_path)
result = self.predict_fn(image_show)
logit = result[0][0] # only one image here.
if abs(np.sum(logit) - 1.0) > 1e-4:
# softmax
exp_result = np.exp(logit)
probability = exp_result / np.sum(exp_result)
else:
probability = logit
# only explain top 1
pred_label = np.argsort(probability)
pred_label = pred_label[-1:]
self.predicted_label = pred_label[0]
self.predicted_probability = probability[pred_label[0]]
self.image = image_show[0]
self.labels = pred_label
print('predicted result: ', pred_label[0], probability[pred_label[0]])
local_feature_map = result[1][0]
cluster_labels = self.predict_cluster_labels(
local_feature_map.transpose((1, 2, 0)), self._lime.lime_explainer.segments
)
g_weights = self.predict_using_normlime_weights(self.labels, cluster_labels)
return g_weights
def explain(self, data_, visualization=True, save_to_disk=True, save_outdir=None):
g_weights = self.preparation_normlime(data_)
lime_weights = self._lime.lime_explainer.local_exp
if visualization or save_to_disk:
import matplotlib.pyplot as plt
from skimage.segmentation import mark_boundaries
l = self.labels[0]
psize = 5
nrows = 4
weights_choices = [0.6, 0.85, 0.99]
ncols = len(weights_choices)
plt.close()
f, axes = plt.subplots(nrows, ncols, figsize=(psize * ncols, psize * nrows))
for ax in axes.ravel():
ax.axis("off")
axes = axes.ravel()
axes[0].imshow(self.image)
axes[0].set_title(f"label {l}, proba: {self.predicted_probability: .3f}")
axes[1].imshow(mark_boundaries(self.image, self._lime.lime_explainer.segments))
axes[1].set_title("superpixel segmentation")
# LIME visualization
for i, w in enumerate(weights_choices):
num_to_show = auto_choose_num_features_to_show(self._lime.lime_explainer, l, w)
temp, mask = self._lime.lime_explainer.get_image_and_mask(
l, positive_only=False, hide_rest=False, num_features=num_to_show
)
axes[ncols + i].imshow(mark_boundaries(temp, mask))
axes[ncols + i].set_title(f"label {l}, first {num_to_show} superpixels")
# NormLIME visualization
self._lime.lime_explainer.local_exp = g_weights
for i, w in enumerate(weights_choices):
num_to_show = auto_choose_num_features_to_show(self._lime.lime_explainer, l, w)
temp, mask = self._lime.lime_explainer.get_image_and_mask(
l, positive_only=False, hide_rest=False, num_features=num_to_show
)
axes[ncols * 2 + i].imshow(mark_boundaries(temp, mask))
axes[ncols * 2 + i].set_title(f"label {l}, first {num_to_show} superpixels")
# NormLIME*LIME visualization
combined_weights = combine_normlime_and_lime(lime_weights, g_weights)
self._lime.lime_explainer.local_exp = combined_weights
for i, w in enumerate(weights_choices):
num_to_show = auto_choose_num_features_to_show(self._lime.lime_explainer, l, w)
temp, mask = self._lime.lime_explainer.get_image_and_mask(
l, positive_only=False, hide_rest=False, num_features=num_to_show
)
axes[ncols * 3 + i].imshow(mark_boundaries(temp, mask))
axes[ncols * 3 + i].set_title(f"label {l}, first {num_to_show} superpixels")
self._lime.lime_explainer.local_exp = lime_weights
if save_to_disk and save_outdir is not None:
os.makedirs(save_outdir, exist_ok=True)
save_fig(data_, save_outdir, 'normlime', self.num_samples)
if visualization:
plt.show()
def load_kmeans_model(fname):
import pickle
with open(fname, 'rb') as f:
kmeans_model = pickle.load(f)
return kmeans_model
def auto_choose_num_features_to_show(lime_explainer, label, percentage_to_show):
segments = lime_explainer.segments
lime_weights = lime_explainer.local_exp[label]
num_pixels_threshold_in_a_sp = segments.shape[0] * segments.shape[1] // len(np.unique(segments)) // 8
# l1 norm with filtered weights.
used_weights = [(tuple_w[0], tuple_w[1]) for i, tuple_w in enumerate(lime_weights) if tuple_w[1] > 0]
norm = np.sum([tuple_w[1] for i, tuple_w in enumerate(used_weights)])
normalized_weights = [(tuple_w[0], tuple_w[1] / norm) for i, tuple_w in enumerate(lime_weights)]
a = 0.0
n = 0
for i, tuple_w in enumerate(normalized_weights):
if tuple_w[1] < 0:
continue
if len(np.where(segments == tuple_w[0])[0]) < num_pixels_threshold_in_a_sp:
continue
a += tuple_w[1]
if a > percentage_to_show:
n = i + 1
break
if n == 0:
return auto_choose_num_features_to_show(lime_explainer, label, percentage_to_show-0.1)
return n
def get_cam(image_show, feature_maps, fc_weights, label_index, cam_min=None, cam_max=None):
_, nc, h, w = feature_maps.shape
cam = feature_maps * fc_weights[:, label_index].reshape(1, nc, 1, 1)
cam = cam.sum((0, 1))
if cam_min is None:
cam_min = np.min(cam)
if cam_max is None:
cam_max = np.max(cam)
cam = cam - cam_min
cam = cam / cam_max
cam = np.uint8(255 * cam)
cam_img = cv2.resize(cam, image_show.shape[0:2], interpolation=cv2.INTER_LINEAR)
heatmap = cv2.applyColorMap(np.uint8(255 * cam_img), cv2.COLORMAP_JET)
heatmap = np.float32(heatmap)
cam = heatmap + np.float32(image_show)
cam = cam / np.max(cam)
return cam
def avg_using_superpixels(features, segments):
one_list = np.zeros((len(np.unique(segments)), features.shape[2]))
for x in np.unique(segments):
one_list[x] = np.mean(features[segments == x], axis=0)
return one_list
def centroid_using_superpixels(features, segments):
from skimage.measure import regionprops
regions = regionprops(segments + 1)
one_list = np.zeros((len(np.unique(segments)), features.shape[2]))
for i, r in enumerate(regions):
one_list[i] = features[int(r.centroid[0] + 0.5), int(r.centroid[1] + 0.5), :]
# print(one_list.shape)
return one_list
def get_feature_for_kmeans(feature_map, segments):
from sklearn.preprocessing import normalize
centroid_feature = centroid_using_superpixels(feature_map, segments)
avg_feature = avg_using_superpixels(feature_map, segments)
x = np.concatenate((centroid_feature, avg_feature), axis=-1)
x = normalize(x)
return x
def combine_normlime_and_lime(lime_weights, g_weights):
pred_labels = lime_weights.keys()
combined_weights = {y: [] for y in pred_labels}
for y in pred_labels:
normlized_lime_weights_y = lime_weights[y]
lime_weights_dict = {tuple_w[0]: tuple_w[1] for tuple_w in normlized_lime_weights_y}
normlized_g_weight_y = g_weights[y]
normlime_weights_dict = {tuple_w[0]: tuple_w[1] for tuple_w in normlized_g_weight_y}
combined_weights[y] = [
(seg_k, lime_weights_dict[seg_k] * normlime_weights_dict[seg_k])
for seg_k in lime_weights_dict.keys()
]
combined_weights[y] = sorted(combined_weights[y],
key=lambda x: np.abs(x[1]), reverse=True)
return combined_weights
def save_fig(data_, save_outdir, algorithm_name, num_samples=3000):
import matplotlib.pyplot as plt
if isinstance(data_, str):
if algorithm_name == 'cam':
f_out = f"{algorithm_name}_{data_.split('/')[-1]}.png"
else:
f_out = f"{algorithm_name}_{data_.split('/')[-1]}_s{num_samples}.png"
plt.savefig(
os.path.join(save_outdir, f_out)
)
else:
n = 0
if algorithm_name == 'cam':
f_out = f'cam-{n}.png'
else:
f_out = f'{algorithm_name}_s{num_samples}-{n}.png'
while os.path.exists(
os.path.join(save_outdir, f_out)
):
n += 1
if algorithm_name == 'cam':
f_out = f'cam-{n}.png'
else:
f_out = f'{algorithm_name}_s{num_samples}-{n}.png'
continue
plt.savefig(
os.path.join(
save_outdir, f_out
)
)
此差异已折叠。
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import cv2
import copy
import os.path as osp
import numpy as np
from .core.explanation import Explanation
def visualize(img_file,
model,
explanation_type='lime',
num_samples=3000,
batch_size=50,
save_dir='./'):
model.arrange_transforms(
transforms=model.test_transforms, mode='test')
tmp_transforms = copy.deepcopy(model.test_transforms)
tmp_transforms.transforms = tmp_transforms.transforms[:-2]
img = tmp_transforms(img_file)[0]
img = np.around(img).astype('uint8')
img = np.expand_dims(img, axis=0)
explaier = None
if explanation_type == 'lime':
explaier = get_lime_explaier(img, model, num_samples=num_samples, batch_size=batch_size)
else:
raise Exception('The {} explanantion method is not supported yet!'.format(explanation_type))
img_name = osp.splitext(osp.split(img_file)[-1])[0]
explaier.explain(img, save_dir=save_dir)
def get_lime_explaier(img, model, num_samples=3000, batch_size=50):
def predict_func(image):
image = image.astype('float32')
model.test_transforms.transforms = model.test_transforms.transforms[-2:]
out = model.explanation_predict(image)
return out[0]
explaier = Explanation('lime',
predict_func,
num_samples=num_samples,
batch_size=batch_size)
return explaier
\ No newline at end of file
......@@ -120,6 +120,7 @@ class ResNet(object):
self.num_classes = num_classes
self.lr_mult_list = lr_mult_list
self.curr_stage = 0
self.features = []
def _conv_offset(self,
input,
......@@ -474,7 +475,9 @@ class ResNet(object):
size=self.num_classes,
param_attr=fluid.param_attr.ParamAttr(
initializer=fluid.initializer.Uniform(-stdv, stdv)))
return out
self.features.append(out)
# out.persistable=True
return out, self.features
return OrderedDict([('res{}_sum'.format(self.feature_maps[idx]), feat)
for idx, feat in enumerate(res_endpoints)])
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册