From 011cff21f24cb34cf1e89beeaecc41c1fd761d9c Mon Sep 17 00:00:00 2001 From: sunyanfang01 Date: Thu, 14 May 2020 15:38:09 +0800 Subject: [PATCH] add lime --- paddlex/cv/models/classifier.py | 26 +- .../as_data_reader/data_path_utils.py | 27 + .../explanation/as_data_reader/readers.py | 211 ++++++++ .../explanation/core/_session_preparation.py | 13 + .../cv/models/explanation/core/explanation.py | 37 ++ .../core/explanation_algorithms.py | 458 ++++++++++++++++ .../cv/models/explanation/core/lime_base.py | 502 ++++++++++++++++++ paddlex/cv/models/explanation/visualize.py | 46 ++ paddlex/cv/nets/resnet.py | 5 +- 9 files changed, 1314 insertions(+), 11 deletions(-) create mode 100644 paddlex/cv/models/explanation/as_data_reader/data_path_utils.py create mode 100644 paddlex/cv/models/explanation/as_data_reader/readers.py create mode 100644 paddlex/cv/models/explanation/core/_session_preparation.py create mode 100644 paddlex/cv/models/explanation/core/explanation.py create mode 100644 paddlex/cv/models/explanation/core/explanation_algorithms.py create mode 100644 paddlex/cv/models/explanation/core/lime_base.py create mode 100644 paddlex/cv/models/explanation/visualize.py diff --git a/paddlex/cv/models/classifier.py b/paddlex/cv/models/classifier.py index c94e530..11f3102 100644 --- a/paddlex/cv/models/classifier.py +++ b/paddlex/cv/models/classifier.py @@ -27,7 +27,6 @@ from .base import BaseAPI class BaseClassifier(BaseAPI): """构建分类器,并实现其训练、评估、预测和模型导出。 - Args: model_name (str): 分类器的模型名字,取值范围为['ResNet18', 'ResNet34', 'ResNet50', 'ResNet101', @@ -61,10 +60,10 @@ class BaseClassifier(BaseAPI): if mode != 'test': label = fluid.data(dtype='int64', shape=[None, 1], name='label') model = getattr(paddlex.cv.nets, str.lower(self.model_name)) - net_out = model(image, num_classes=self.num_classes) + net_out, feat = model(image, num_classes=self.num_classes) softmax_out = fluid.layers.softmax(net_out, use_cudnn=False) inputs = OrderedDict([('image', image)]) - outputs = OrderedDict([('predict', softmax_out)]) + outputs = OrderedDict([('predict', softmax_out), ('net_out', feat[-1])]) if mode != 'test': cost = fluid.layers.cross_entropy(input=softmax_out, label=label) avg_cost = fluid.layers.mean(cost) @@ -115,7 +114,6 @@ class BaseClassifier(BaseAPI): early_stop_patience=5, resume_checkpoint=None): """训练。 - Args: num_epochs (int): 训练迭代轮数。 train_dataset (paddlex.datasets): 训练数据读取器。 @@ -139,7 +137,6 @@ class BaseClassifier(BaseAPI): early_stop_patience (int): 当使用提前终止训练策略时,如果验证集精度在`early_stop_patience`个epoch内 连续下降或持平,则终止训练。默认值为5。 resume_checkpoint (str): 恢复训练时指定上次训练保存的模型路径。若为None,则不会恢复训练。默认值为None。 - Raises: ValueError: 模型从inference model进行加载。 """ @@ -183,13 +180,11 @@ class BaseClassifier(BaseAPI): epoch_id=None, return_details=False): """评估。 - Args: eval_dataset (paddlex.datasets): 验证数据读取器。 batch_size (int): 验证数据批大小。默认为1。 epoch_id (int): 当前评估模型所在的训练轮数。 return_details (bool): 是否返回详细信息。 - Returns: dict: 当return_details为False时,返回dict, 包含关键字:'acc1'、'acc5', 分别表示最大值的accuracy、前5个最大值的accuracy。 @@ -248,12 +243,10 @@ class BaseClassifier(BaseAPI): def predict(self, img_file, transforms=None, topk=1): """预测。 - Args: img_file (str): 预测图像路径。 transforms (paddlex.cls.transforms): 数据预处理操作。 topk (int): 预测时前k个最大值。 - Returns: list: 其中元素均为字典。字典的关键字为'category_id'、'category'、'score', 分别对应预测类别id、预测类别标签、预测得分。 @@ -279,7 +272,20 @@ class BaseClassifier(BaseAPI): 'score': result[0][0][l] } for l in pred_label] return res - + + def explanation_predict(self, images): + self.arrange_transforms( + transforms=self.test_transforms, mode='test') + new_imgs = [] + for i in range(images.shape[0]): + img = images[i] + new_imgs.append(self.test_transforms(img)[0]) + new_imgs = np.array(new_imgs) + result = self.exe.run( + self.test_prog, + feed={'image': new_imgs}, + fetch_list=list(self.test_outputs.values())) + return result[1:] class ResNet18(BaseClassifier): def __init__(self, num_classes=1000): diff --git a/paddlex/cv/models/explanation/as_data_reader/data_path_utils.py b/paddlex/cv/models/explanation/as_data_reader/data_path_utils.py new file mode 100644 index 0000000..225200a --- /dev/null +++ b/paddlex/cv/models/explanation/as_data_reader/data_path_utils.py @@ -0,0 +1,27 @@ +import os + + +def imagenet_val_files_and_labels(dataset_directory): + classes = open(os.path.join(dataset_directory, 'imagenet_lsvrc_2015_synsets.txt')).readlines() + class_to_indx = {classes[i].split('\n')[0]: i for i in range(len(classes))} + + images_path = os.path.join(dataset_directory, 'val') + filenames = [] + labels = [] + lines = open(os.path.join(dataset_directory, 'imagenet_2012_validation_synset_labels.txt'), 'r').readlines() + for i, line in enumerate(lines): + class_name = line.split('\n')[0] + a = 'ILSVRC2012_val_%08d.JPEG' % (i + 1) + filenames.append(f'{images_path}/{a}') + labels.append(class_to_indx[class_name]) + # print(filenames[-1], labels[-1]) + + return filenames, labels + + +def _find_classes(dir): + # Faster and available in Python 3.5 and above + classes = [d.name for d in os.scandir(dir) if d.is_dir()] + classes.sort() + class_to_idx = {classes[i]: i for i in range(len(classes))} + return classes, class_to_idx \ No newline at end of file diff --git a/paddlex/cv/models/explanation/as_data_reader/readers.py b/paddlex/cv/models/explanation/as_data_reader/readers.py new file mode 100644 index 0000000..bc6c201 --- /dev/null +++ b/paddlex/cv/models/explanation/as_data_reader/readers.py @@ -0,0 +1,211 @@ +import os +import sys; sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '..'))) +import cv2 +import numpy as np +import six +import glob +from as_data_reader.data_path_utils import _find_classes +from PIL import Image + + +def resize_short(img, target_size, interpolation=None): + """resize image + + Args: + img: image data + target_size: resize short target size + interpolation: interpolation mode + + Returns: + resized image data + """ + percent = float(target_size) / min(img.shape[0], img.shape[1]) + resized_width = int(round(img.shape[1] * percent)) + resized_height = int(round(img.shape[0] * percent)) + if interpolation: + resized = cv2.resize( + img, (resized_width, resized_height), interpolation=interpolation) + else: + resized = cv2.resize(img, (resized_width, resized_height)) + return resized + + +def crop_image(img, target_size, center=True): + """crop image + + Args: + img: images data + target_size: crop target size + center: crop mode + + Returns: + img: cropped image data + """ + height, width = img.shape[:2] + size = target_size + if center: + w_start = (width - size) // 2 + h_start = (height - size) // 2 + else: + w_start = np.random.randint(0, width - size + 1) + h_start = np.random.randint(0, height - size + 1) + w_end = w_start + size + h_end = h_start + size + img = img[h_start:h_end, w_start:w_end, :] + return img + + +def preprocess_image(img, random_mirror=False): + """ + centered, scaled by 1/255. + :param img: np.array: shape: [ns, h, w, 3], color order: rgb. + :return: np.array: shape: [ns, h, w, 3] + """ + mean = [0.485, 0.456, 0.406] + std = [0.229, 0.224, 0.225] + + # transpose to [ns, 3, h, w] + img = img.astype('float32').transpose((0, 3, 1, 2)) / 255 + + img_mean = np.array(mean).reshape((3, 1, 1)) + img_std = np.array(std).reshape((3, 1, 1)) + img -= img_mean + img /= img_std + + if random_mirror: + mirror = int(np.random.uniform(0, 2)) + if mirror == 1: + img = img[:, :, ::-1, :] + + return img + + +def read_image(img_path, target_size=256, crop_size=224): + """ + resize_short to 256, then center crop to 224. + :param img_path: one image path + :return: np.array: shape: [1, h, w, 3], color order: rgb. + """ + + if isinstance(img_path, str): + with open(img_path, 'rb') as f: + img = Image.open(f) + img = img.convert('RGB') + img = np.array(img) + # img = cv2.imread(img_path) + + img = resize_short(img, target_size, interpolation=None) + img = crop_image(img, target_size=crop_size, center=True) + # img = img[:, :, ::-1] + img = np.expand_dims(img, axis=0) + return img + elif isinstance(img_path, np.ndarray): + assert len(img_path.shape) == 4 + return img_path + else: + ValueError(f"Not recognized data type {type(img_path)}.") + + +class ReaderConfig(object): + """ + A generic data loader where the images are arranged in this way: + + root/train/dog/xxy.jpg + root/train/dog/xxz.jpg + ... + root/train/cat/nsdf3.jpg + root/train/cat/asd932_.jpg + ... + + root/test/dog/xxx.jpg + ... + root/test/cat/123.jpg + ... + + """ + def __init__(self, dataset_dir, is_test): + image_paths, labels, self.num_classes = self.get_dataset_info(dataset_dir, is_test) + random_per = np.random.permutation(range(len(image_paths))) + self.image_paths = image_paths[random_per] + self.labels = labels[random_per] + self.is_test = is_test + + def get_reader(self): + def reader(): + IMG_EXTENSIONS = ('.jpg', '.jpeg', '.png', '.ppm', '.bmp', '.pgm', '.tif', '.tiff', '.webp') + target_size = 256 + crop_size = 224 + + for i, img_path in enumerate(self.image_paths): + if not img_path.lower().endswith(IMG_EXTENSIONS): + continue + + img = cv2.imread(img_path) + if img is None: + print(img_path) + continue + img = resize_short(img, target_size, interpolation=None) + img = crop_image(img, crop_size, center=self.is_test) + img = img[:, :, ::-1] + img = np.expand_dims(img, axis=0) + + img = preprocess_image(img, not self.is_test) + + yield img, self.labels[i] + + return reader + + def get_dataset_info(self, dataset_dir, is_test=False): + IMG_EXTENSIONS = ('.jpg', '.jpeg', '.png', '.ppm', '.bmp', '.pgm', '.tif', '.tiff', '.webp') + + # read + if is_test: + datasubset_dir = os.path.join(dataset_dir, 'test') + else: + datasubset_dir = os.path.join(dataset_dir, 'train') + + class_names, class_to_idx = _find_classes(datasubset_dir) + # num_classes = len(class_names) + image_paths = [] + labels = [] + for class_name in class_names: + classes_dir = os.path.join(datasubset_dir, class_name) + for img_path in glob.glob(os.path.join(classes_dir, '*')): + if not img_path.lower().endswith(IMG_EXTENSIONS): + continue + + image_paths.append(img_path) + labels.append(class_to_idx[class_name]) + + image_paths = np.array(image_paths) + labels = np.array(labels) + return image_paths, labels, len(class_names) + + +def create_reader(list_image_path, list_label=None, is_test=False): + def reader(): + IMG_EXTENSIONS = ('.jpg', '.jpeg', '.png', '.ppm', '.bmp', '.pgm', '.tif', '.tiff', '.webp') + target_size = 256 + crop_size = 224 + + for i, img_path in enumerate(list_image_path): + if not img_path.lower().endswith(IMG_EXTENSIONS): + continue + + img = cv2.imread(img_path) + if img is None: + print(img_path) + continue + + img = resize_short(img, target_size, interpolation=None) + img = crop_image(img, crop_size, center=is_test) + img = img[:, :, ::-1] + img_show = np.expand_dims(img, axis=0) + + img = preprocess_image(img_show, not is_test) + + label = 0 if list_label is None else list_label[i] + + yield img_show, img, label + + return reader \ No newline at end of file diff --git a/paddlex/cv/models/explanation/core/_session_preparation.py b/paddlex/cv/models/explanation/core/_session_preparation.py new file mode 100644 index 0000000..dd43aa4 --- /dev/null +++ b/paddlex/cv/models/explanation/core/_session_preparation.py @@ -0,0 +1,13 @@ +import os +import paddle.fluid as fluid +import numpy as np + + +def paddle_get_fc_weights(var_name="fc_0.w_0"): + fc_weights = fluid.global_scope().find_var(var_name).get_tensor() + return np.array(fc_weights) + + +def paddle_resize(extracted_features, outsize): + resized_features = fluid.layers.resize_bilinear(extracted_features, outsize) + return resized_features \ No newline at end of file diff --git a/paddlex/cv/models/explanation/core/explanation.py b/paddlex/cv/models/explanation/core/explanation.py new file mode 100644 index 0000000..ce2d4e2 --- /dev/null +++ b/paddlex/cv/models/explanation/core/explanation.py @@ -0,0 +1,37 @@ +from .explanation_algorithms import CAM, LIME, NormLIME + + +class Explanation(object): + """ + Base class for all explanation algorithms. + """ + def __init__(self, explanation_algorithm_name, predict_fn, **kwargs): + supported_algorithms = { + 'cam': CAM, + 'lime': LIME, + 'normlime': NormLIME + } + + self.algorithm_name = explanation_algorithm_name.lower() + assert self.algorithm_name in supported_algorithms.keys() + self.predict_fn = predict_fn + + # initialization for the explanation algorithm. + self.explain_algorithm = supported_algorithms[self.algorithm_name]( + self.predict_fn, **kwargs + ) + + def explain(self, data_, visualization=True, save_to_disk=True, save_dir='./tmp'): + """ + + Args: + data_: data_ can be a path or numpy.ndarray. + visualization: whether to show using matplotlib. + save_to_disk: whether to save the figure in local disk. + save_dir: dir to save figure if save_to_disk is True. + + Returns: + + """ + return self.explain_algorithm.explain(data_, visualization, save_to_disk, save_dir) + diff --git a/paddlex/cv/models/explanation/core/explanation_algorithms.py b/paddlex/cv/models/explanation/core/explanation_algorithms.py new file mode 100644 index 0000000..84a4a4c --- /dev/null +++ b/paddlex/cv/models/explanation/core/explanation_algorithms.py @@ -0,0 +1,458 @@ +import os +import numpy as np +import time + +from . import lime_base +from ..as_data_reader.readers import read_image +from ._session_preparation import paddle_get_fc_weights + +import cv2 + + +class CAM(object): + def __init__(self, predict_fn): + """ + + Args: + predict_fn: input: images_show [N, H, W, 3], RGB range(0, 255) + output: [ + logits [N, num_classes], + feature map before global average pooling [N, num_channels, h_, w_] + ] + + """ + self.predict_fn = predict_fn + + def preparation_cam(self, data_path): + image_show = read_image(data_path) + result = self.predict_fn(image_show) + + logit = result[0][0] + if abs(np.sum(logit) - 1.0) > 1e-4: + # softmax + exp_result = np.exp(logit) + probability = exp_result / np.sum(exp_result) + else: + probability = logit + + # only explain top 1 + pred_label = np.argsort(probability) + pred_label = pred_label[-1:] + + self.predicted_label = pred_label[0] + self.predicted_probability = probability[pred_label[0]] + self.image = image_show[0] + self.labels = pred_label + + fc_weights = paddle_get_fc_weights() + feature_maps = result[1] + + print('predicted result: ', pred_label[0], probability[pred_label[0]]) + return feature_maps, fc_weights + + def explain(self, data_, visualization=True, save_to_disk=True, save_outdir=None): + feature_maps, fc_weights = self.preparation_cam(data_) + cam = get_cam(self.image, feature_maps, fc_weights, self.predicted_label) + + if visualization or save_to_disk: + import matplotlib.pyplot as plt + from skimage.segmentation import mark_boundaries + l = self.labels[0] + + psize = 5 + nrows = 1 + ncols = 2 + + plt.close() + f, axes = plt.subplots(nrows, ncols, figsize=(psize * ncols, psize * nrows)) + for ax in axes.ravel(): + ax.axis("off") + axes = axes.ravel() + axes[0].imshow(self.image) + axes[0].set_title(f"label {l}, proba: {self.predicted_probability: .3f}") + + axes[1].imshow(cam) + axes[1].set_title("CAM") + + if save_to_disk and save_outdir is not None: + os.makedirs(save_outdir, exist_ok=True) + save_fig(data_, save_outdir, 'cam') + + if visualization: + plt.show() + + return + + +class LIME(object): + def __init__(self, predict_fn, num_samples=3000, batch_size=50): + """ + LIME wrapper. See lime_base.py for the detailed LIME implementation. + Args: + predict_fn: from image [N, H, W, 3] to logits [N, num_classes], this is necessary for computing LIME. + num_samples: the number of samples that LIME takes for fitting. + batch_size: batch size for model inference each time. + """ + self.num_samples = num_samples + self.batch_size = batch_size + + self.predict_fn = predict_fn + self.labels = None + self.image = None + self.lime_explainer = None + + def preparation_lime(self, data_path): + image_show = read_image(data_path) + result = self.predict_fn(image_show) + + result = result[0] # only one image here. + + if abs(np.sum(result) - 1.0) > 1e-4: + # softmax + exp_result = np.exp(result) + probability = exp_result / np.sum(exp_result) + else: + probability = result + + # only explain top 1 + pred_label = np.argsort(probability) + pred_label = pred_label[-1:] + + self.predicted_label = pred_label[0] + self.predicted_probability = probability[pred_label[0]] + self.image = image_show[0] + self.labels = pred_label + + print(f'predicted result: {pred_label[0]} with probability {probability[pred_label[0]]: .3f}') + + end = time.time() + algo = lime_base.LimeImageExplainer() + explainer = algo.explain_instance(self.image, self.predict_fn, self.labels, 0, + num_samples=self.num_samples, batch_size=self.batch_size) + self.lime_explainer = explainer + print('lime time: ', time.time() - end, 's.') + + def explain(self, data_, visualization=True, save_to_disk=True, save_outdir=None): + if self.lime_explainer is None: + self.preparation_lime(data_) + + if visualization or save_to_disk: + import matplotlib.pyplot as plt + from skimage.segmentation import mark_boundaries + l = self.labels[0] + + psize = 5 + nrows = 2 + weights_choices = [0.6, 0.75, 0.85] + ncols = len(weights_choices) + + plt.close() + f, axes = plt.subplots(nrows, ncols, figsize=(psize * ncols, psize * nrows)) + for ax in axes.ravel(): + ax.axis("off") + axes = axes.ravel() + axes[0].imshow(self.image) + axes[0].set_title(f"label {l}, proba: {self.predicted_probability: .3f}") + + axes[1].imshow(mark_boundaries(self.image, self.lime_explainer.segments)) + axes[1].set_title("superpixel segmentation") + + # LIME visualization + for i, w in enumerate(weights_choices): + num_to_show = auto_choose_num_features_to_show(self.lime_explainer, l, w) + temp, mask = self.lime_explainer.get_image_and_mask( + l, positive_only=False, hide_rest=False, num_features=num_to_show + ) + axes[ncols + i].imshow(mark_boundaries(temp, mask)) + axes[ncols + i].set_title(f"label {l}, first {num_to_show} superpixels") + + if save_to_disk and save_outdir is not None: + os.makedirs(save_outdir, exist_ok=True) + save_fig(data_, save_outdir, 'lime', self.num_samples) + + if visualization: + plt.show() + + return + + +class NormLIME(object): + def __init__(self, predict_fn, num_samples=3000, batch_size=50, + kmeans_model_for_normlime=None, normlime_weights=None): + assert kmeans_model_for_normlime is not None, "NormLIME needs the KMeans model." + if normlime_weights is None: + raise NotImplementedError("Computing NormLIME weights is not implemented yet.") + + self.num_samples = num_samples + self.batch_size = batch_size + + self.kmeans_model = load_kmeans_model(kmeans_model_for_normlime) + self.normlime_weights = np.load(normlime_weights, allow_pickle=True).item() + + self.predict_fn = predict_fn + + self.labels = None + self.image = None + + def predict_cluster_labels(self, feature_map, segments): + return self.kmeans_model.predict(get_feature_for_kmeans(feature_map, segments)) + + def predict_using_normlime_weights(self, pred_labels, predicted_cluster_labels): + # global weights + g_weights = {y: [] for y in pred_labels} + for y in pred_labels: + cluster_weights_y = self.normlime_weights[y] + g_weights[y] = [ + # some are not in the dict, 3000 samples may be not enough. + (i, cluster_weights_y.get(k, 0.0)) for i, k in enumerate(predicted_cluster_labels) + ] + + g_weights[y] = sorted(g_weights[y], + key=lambda x: np.abs(x[1]), reverse=True) + + return g_weights + + def preparation_normlime(self, data_path): + self._lime = LIME( + lambda images: self.predict_fn(images)[0], + self.num_samples, + self.batch_size + ) + self._lime.preparation_lime(data_path) + + image_show = read_image(data_path) + result = self.predict_fn(image_show) + + logit = result[0][0] # only one image here. + if abs(np.sum(logit) - 1.0) > 1e-4: + # softmax + exp_result = np.exp(logit) + probability = exp_result / np.sum(exp_result) + else: + probability = logit + + # only explain top 1 + pred_label = np.argsort(probability) + pred_label = pred_label[-1:] + + self.predicted_label = pred_label[0] + self.predicted_probability = probability[pred_label[0]] + self.image = image_show[0] + self.labels = pred_label + print('predicted result: ', pred_label[0], probability[pred_label[0]]) + + local_feature_map = result[1][0] + cluster_labels = self.predict_cluster_labels( + local_feature_map.transpose((1, 2, 0)), self._lime.lime_explainer.segments + ) + + g_weights = self.predict_using_normlime_weights(self.labels, cluster_labels) + + return g_weights + + def explain(self, data_, visualization=True, save_to_disk=True, save_outdir=None): + g_weights = self.preparation_normlime(data_) + lime_weights = self._lime.lime_explainer.local_exp + + if visualization or save_to_disk: + import matplotlib.pyplot as plt + from skimage.segmentation import mark_boundaries + l = self.labels[0] + + psize = 5 + nrows = 4 + weights_choices = [0.6, 0.85, 0.99] + ncols = len(weights_choices) + + plt.close() + f, axes = plt.subplots(nrows, ncols, figsize=(psize * ncols, psize * nrows)) + for ax in axes.ravel(): + ax.axis("off") + + axes = axes.ravel() + axes[0].imshow(self.image) + axes[0].set_title(f"label {l}, proba: {self.predicted_probability: .3f}") + + axes[1].imshow(mark_boundaries(self.image, self._lime.lime_explainer.segments)) + axes[1].set_title("superpixel segmentation") + + # LIME visualization + for i, w in enumerate(weights_choices): + num_to_show = auto_choose_num_features_to_show(self._lime.lime_explainer, l, w) + temp, mask = self._lime.lime_explainer.get_image_and_mask( + l, positive_only=False, hide_rest=False, num_features=num_to_show + ) + axes[ncols + i].imshow(mark_boundaries(temp, mask)) + axes[ncols + i].set_title(f"label {l}, first {num_to_show} superpixels") + + # NormLIME visualization + self._lime.lime_explainer.local_exp = g_weights + for i, w in enumerate(weights_choices): + num_to_show = auto_choose_num_features_to_show(self._lime.lime_explainer, l, w) + temp, mask = self._lime.lime_explainer.get_image_and_mask( + l, positive_only=False, hide_rest=False, num_features=num_to_show + ) + axes[ncols * 2 + i].imshow(mark_boundaries(temp, mask)) + axes[ncols * 2 + i].set_title(f"label {l}, first {num_to_show} superpixels") + + # NormLIME*LIME visualization + combined_weights = combine_normlime_and_lime(lime_weights, g_weights) + self._lime.lime_explainer.local_exp = combined_weights + for i, w in enumerate(weights_choices): + num_to_show = auto_choose_num_features_to_show(self._lime.lime_explainer, l, w) + temp, mask = self._lime.lime_explainer.get_image_and_mask( + l, positive_only=False, hide_rest=False, num_features=num_to_show + ) + axes[ncols * 3 + i].imshow(mark_boundaries(temp, mask)) + axes[ncols * 3 + i].set_title(f"label {l}, first {num_to_show} superpixels") + + self._lime.lime_explainer.local_exp = lime_weights + + if save_to_disk and save_outdir is not None: + os.makedirs(save_outdir, exist_ok=True) + save_fig(data_, save_outdir, 'normlime', self.num_samples) + + if visualization: + plt.show() + + +def load_kmeans_model(fname): + import pickle + with open(fname, 'rb') as f: + kmeans_model = pickle.load(f) + + return kmeans_model + + +def auto_choose_num_features_to_show(lime_explainer, label, percentage_to_show): + segments = lime_explainer.segments + lime_weights = lime_explainer.local_exp[label] + num_pixels_threshold_in_a_sp = segments.shape[0] * segments.shape[1] // len(np.unique(segments)) // 8 + + # l1 norm with filtered weights. + used_weights = [(tuple_w[0], tuple_w[1]) for i, tuple_w in enumerate(lime_weights) if tuple_w[1] > 0] + norm = np.sum([tuple_w[1] for i, tuple_w in enumerate(used_weights)]) + normalized_weights = [(tuple_w[0], tuple_w[1] / norm) for i, tuple_w in enumerate(lime_weights)] + + a = 0.0 + n = 0 + for i, tuple_w in enumerate(normalized_weights): + if tuple_w[1] < 0: + continue + if len(np.where(segments == tuple_w[0])[0]) < num_pixels_threshold_in_a_sp: + continue + + a += tuple_w[1] + if a > percentage_to_show: + n = i + 1 + break + + if n == 0: + return auto_choose_num_features_to_show(lime_explainer, label, percentage_to_show-0.1) + + return n + + +def get_cam(image_show, feature_maps, fc_weights, label_index, cam_min=None, cam_max=None): + _, nc, h, w = feature_maps.shape + + cam = feature_maps * fc_weights[:, label_index].reshape(1, nc, 1, 1) + cam = cam.sum((0, 1)) + + if cam_min is None: + cam_min = np.min(cam) + if cam_max is None: + cam_max = np.max(cam) + + cam = cam - cam_min + cam = cam / cam_max + cam = np.uint8(255 * cam) + cam_img = cv2.resize(cam, image_show.shape[0:2], interpolation=cv2.INTER_LINEAR) + + heatmap = cv2.applyColorMap(np.uint8(255 * cam_img), cv2.COLORMAP_JET) + heatmap = np.float32(heatmap) + cam = heatmap + np.float32(image_show) + cam = cam / np.max(cam) + + return cam + + +def avg_using_superpixels(features, segments): + one_list = np.zeros((len(np.unique(segments)), features.shape[2])) + for x in np.unique(segments): + one_list[x] = np.mean(features[segments == x], axis=0) + + return one_list + + +def centroid_using_superpixels(features, segments): + from skimage.measure import regionprops + regions = regionprops(segments + 1) + one_list = np.zeros((len(np.unique(segments)), features.shape[2])) + for i, r in enumerate(regions): + one_list[i] = features[int(r.centroid[0] + 0.5), int(r.centroid[1] + 0.5), :] + # print(one_list.shape) + return one_list + + +def get_feature_for_kmeans(feature_map, segments): + from sklearn.preprocessing import normalize + centroid_feature = centroid_using_superpixels(feature_map, segments) + avg_feature = avg_using_superpixels(feature_map, segments) + x = np.concatenate((centroid_feature, avg_feature), axis=-1) + x = normalize(x) + return x + + +def combine_normlime_and_lime(lime_weights, g_weights): + pred_labels = lime_weights.keys() + combined_weights = {y: [] for y in pred_labels} + + for y in pred_labels: + normlized_lime_weights_y = lime_weights[y] + lime_weights_dict = {tuple_w[0]: tuple_w[1] for tuple_w in normlized_lime_weights_y} + + normlized_g_weight_y = g_weights[y] + normlime_weights_dict = {tuple_w[0]: tuple_w[1] for tuple_w in normlized_g_weight_y} + + combined_weights[y] = [ + (seg_k, lime_weights_dict[seg_k] * normlime_weights_dict[seg_k]) + for seg_k in lime_weights_dict.keys() + ] + + combined_weights[y] = sorted(combined_weights[y], + key=lambda x: np.abs(x[1]), reverse=True) + + return combined_weights + + +def save_fig(data_, save_outdir, algorithm_name, num_samples=3000): + import matplotlib.pyplot as plt + if isinstance(data_, str): + if algorithm_name == 'cam': + f_out = f"{algorithm_name}_{data_.split('/')[-1]}.png" + else: + f_out = f"{algorithm_name}_{data_.split('/')[-1]}_s{num_samples}.png" + plt.savefig( + os.path.join(save_outdir, f_out) + ) + else: + n = 0 + if algorithm_name == 'cam': + f_out = f'cam-{n}.png' + else: + f_out = f'{algorithm_name}_s{num_samples}-{n}.png' + while os.path.exists( + os.path.join(save_outdir, f_out) + ): + n += 1 + if algorithm_name == 'cam': + f_out = f'cam-{n}.png' + else: + f_out = f'{algorithm_name}_s{num_samples}-{n}.png' + continue + plt.savefig( + os.path.join( + save_outdir, f_out + ) + ) diff --git a/paddlex/cv/models/explanation/core/lime_base.py b/paddlex/cv/models/explanation/core/lime_base.py new file mode 100644 index 0000000..553e8a4 --- /dev/null +++ b/paddlex/cv/models/explanation/core/lime_base.py @@ -0,0 +1,502 @@ +""" +Contains abstract functionality for learning locally linear sparse model. +""" +from __future__ import print_function +import numpy as np +import scipy as sp +import sklearn +import sklearn.preprocessing +from skimage.color import gray2rgb +from sklearn.linear_model import Ridge, lars_path +from sklearn.utils import check_random_state + +import copy +from functools import partial +from skimage.segmentation import quickshift +from skimage.measure import regionprops + + +class LimeBase(object): + """Class for learning a locally linear sparse model from perturbed data""" + def __init__(self, + kernel_fn, + verbose=False, + random_state=None): + """Init function + + Args: + kernel_fn: function that transforms an array of distances into an + array of proximity values (floats). + verbose: if true, print local prediction values from linear model. + random_state: an integer or numpy.RandomState that will be used to + generate random numbers. If None, the random state will be + initialized using the internal numpy seed. + """ + self.kernel_fn = kernel_fn + self.verbose = verbose + self.random_state = check_random_state(random_state) + + @staticmethod + def generate_lars_path(weighted_data, weighted_labels): + """Generates the lars path for weighted data. + + Args: + weighted_data: data that has been weighted by kernel + weighted_label: labels, weighted by kernel + + Returns: + (alphas, coefs), both are arrays corresponding to the + regularization parameter and coefficients, respectively + """ + x_vector = weighted_data + alphas, _, coefs = lars_path(x_vector, + weighted_labels, + method='lasso', + verbose=False) + return alphas, coefs + + def forward_selection(self, data, labels, weights, num_features): + """Iteratively adds features to the model""" + clf = Ridge(alpha=0, fit_intercept=True, random_state=self.random_state) + used_features = [] + for _ in range(min(num_features, data.shape[1])): + max_ = -100000000 + best = 0 + for feature in range(data.shape[1]): + if feature in used_features: + continue + clf.fit(data[:, used_features + [feature]], labels, + sample_weight=weights) + score = clf.score(data[:, used_features + [feature]], + labels, + sample_weight=weights) + if score > max_: + best = feature + max_ = score + used_features.append(best) + return np.array(used_features) + + def feature_selection(self, data, labels, weights, num_features, method): + """Selects features for the model. see explain_instance_with_data to + understand the parameters.""" + if method == 'none': + return np.array(range(data.shape[1])) + elif method == 'forward_selection': + return self.forward_selection(data, labels, weights, num_features) + elif method == 'highest_weights': + clf = Ridge(alpha=0.01, fit_intercept=True, + random_state=self.random_state) + clf.fit(data, labels, sample_weight=weights) + + coef = clf.coef_ + if sp.sparse.issparse(data): + coef = sp.sparse.csr_matrix(clf.coef_) + weighted_data = coef.multiply(data[0]) + # Note: most efficient to slice the data before reversing + sdata = len(weighted_data.data) + argsort_data = np.abs(weighted_data.data).argsort() + # Edge case where data is more sparse than requested number of feature importances + # In that case, we just pad with zero-valued features + if sdata < num_features: + nnz_indexes = argsort_data[::-1] + indices = weighted_data.indices[nnz_indexes] + num_to_pad = num_features - sdata + indices = np.concatenate((indices, np.zeros(num_to_pad, dtype=indices.dtype))) + indices_set = set(indices) + pad_counter = 0 + for i in range(data.shape[1]): + if i not in indices_set: + indices[pad_counter + sdata] = i + pad_counter += 1 + if pad_counter >= num_to_pad: + break + else: + nnz_indexes = argsort_data[sdata - num_features:sdata][::-1] + indices = weighted_data.indices[nnz_indexes] + return indices + else: + weighted_data = coef * data[0] + feature_weights = sorted( + zip(range(data.shape[1]), weighted_data), + key=lambda x: np.abs(x[1]), + reverse=True) + return np.array([x[0] for x in feature_weights[:num_features]]) + elif method == 'lasso_path': + weighted_data = ((data - np.average(data, axis=0, weights=weights)) + * np.sqrt(weights[:, np.newaxis])) + weighted_labels = ((labels - np.average(labels, weights=weights)) + * np.sqrt(weights)) + nonzero = range(weighted_data.shape[1]) + _, coefs = self.generate_lars_path(weighted_data, + weighted_labels) + for i in range(len(coefs.T) - 1, 0, -1): + nonzero = coefs.T[i].nonzero()[0] + if len(nonzero) <= num_features: + break + used_features = nonzero + return used_features + elif method == 'auto': + if num_features <= 6: + n_method = 'forward_selection' + else: + n_method = 'highest_weights' + return self.feature_selection(data, labels, weights, + num_features, n_method) + + def explain_instance_with_data(self, + neighborhood_data, + neighborhood_labels, + distances, + label, + num_features, + feature_selection='auto', + model_regressor=None): + """Takes perturbed data, labels and distances, returns explanation. + + Args: + neighborhood_data: perturbed data, 2d array. first element is + assumed to be the original data point. + neighborhood_labels: corresponding perturbed labels. should have as + many columns as the number of possible labels. + distances: distances to original data point. + label: label for which we want an explanation + num_features: maximum number of features in explanation + feature_selection: how to select num_features. options are: + 'forward_selection': iteratively add features to the model. + This is costly when num_features is high + 'highest_weights': selects the features that have the highest + product of absolute weight * original data point when + learning with all the features + 'lasso_path': chooses features based on the lasso + regularization path + 'none': uses all features, ignores num_features + 'auto': uses forward_selection if num_features <= 6, and + 'highest_weights' otherwise. + model_regressor: sklearn regressor to use in explanation. + Defaults to Ridge regression if None. Must have + model_regressor.coef_ and 'sample_weight' as a parameter + to model_regressor.fit() + + Returns: + (intercept, exp, score, local_pred): + intercept is a float. + exp is a sorted list of tuples, where each tuple (x,y) corresponds + to the feature id (x) and the local weight (y). The list is sorted + by decreasing absolute value of y. + score is the R^2 value of the returned explanation + local_pred is the prediction of the explanation model on the original instance + """ + + weights = self.kernel_fn(distances) + labels_column = neighborhood_labels[:, label] + used_features = self.feature_selection(neighborhood_data, + labels_column, + weights, + num_features, + feature_selection) + if model_regressor is None: + model_regressor = Ridge(alpha=1, fit_intercept=True, + random_state=self.random_state) + easy_model = model_regressor + easy_model.fit(neighborhood_data[:, used_features], + labels_column, sample_weight=weights) + prediction_score = easy_model.score( + neighborhood_data[:, used_features], + labels_column, sample_weight=weights) + + local_pred = easy_model.predict(neighborhood_data[0, used_features].reshape(1, -1)) + + if self.verbose: + print('Intercept', easy_model.intercept_) + print('Prediction_local', local_pred,) + print('Right:', neighborhood_labels[0, label]) + return (easy_model.intercept_, + sorted(zip(used_features, easy_model.coef_), + key=lambda x: np.abs(x[1]), reverse=True), + prediction_score, local_pred) + + +class ImageExplanation(object): + def __init__(self, image, segments): + """Init function. + + Args: + image: 3d numpy array + segments: 2d numpy array, with the output from skimage.segmentation + """ + self.image = image + self.segments = segments + self.intercept = {} + self.local_exp = {} + self.local_pred = None + + def get_image_and_mask(self, label, positive_only=True, negative_only=False, hide_rest=False, + num_features=5, min_weight=0.): + """Init function. + + Args: + label: label to explain + positive_only: if True, only take superpixels that positively contribute to + the prediction of the label. + negative_only: if True, only take superpixels that negatively contribute to + the prediction of the label. If false, and so is positive_only, then both + negativey and positively contributions will be taken. + Both can't be True at the same time + hide_rest: if True, make the non-explanation part of the return + image gray + num_features: number of superpixels to include in explanation + min_weight: minimum weight of the superpixels to include in explanation + + Returns: + (image, mask), where image is a 3d numpy array and mask is a 2d + numpy array that can be used with + skimage.segmentation.mark_boundaries + """ + if label not in self.local_exp: + raise KeyError('Label not in explanation') + if positive_only & negative_only: + raise ValueError("Positive_only and negative_only cannot be true at the same time.") + segments = self.segments + image = self.image + exp = self.local_exp[label] + mask = np.zeros(segments.shape, segments.dtype) + if hide_rest: + temp = np.zeros(self.image.shape) + else: + temp = self.image.copy() + if positive_only: + fs = [x[0] for x in exp + if x[1] > 0 and x[1] > min_weight][:num_features] + if negative_only: + fs = [x[0] for x in exp + if x[1] < 0 and abs(x[1]) > min_weight][:num_features] + if positive_only or negative_only: + for f in fs: + temp[segments == f] = image[segments == f].copy() + mask[segments == f] = 1 + return temp, mask + else: + for f, w in exp[:num_features]: + if np.abs(w) < min_weight: + continue + c = 0 if w < 0 else 1 + mask[segments == f] = -1 if w < 0 else 1 + temp[segments == f] = image[segments == f].copy() + temp[segments == f, c] = np.max(image) + return temp, mask + + def get_rendered_image(self, label, min_weight=0.005): + """ + + Args: + label: label to explain + min_weight: + + Returns: + image, is a 3d numpy array + """ + if label not in self.local_exp: + raise KeyError('Label not in explanation') + + from matplotlib import cm + + segments = self.segments + image = self.image + exp = self.local_exp[label] + temp = np.zeros_like(image) + + weight_max = abs(exp[0][1]) + exp = [(f, w/weight_max) for f, w in exp] + exp = sorted(exp, key=lambda x: x[1], reverse=True) # negatives are at last. + + cmaps = cm.get_cmap('Spectral') + # sigmoid_space = 1 / (1 + np.exp(-np.linspace(-20, 20, len(exp)))) + colors = cmaps(np.linspace(0, 1, len(exp))) + colors = colors[:, :3] + + for i, (f, w) in enumerate(exp): + if np.abs(w) < min_weight: + continue + temp[segments == f] = image[segments == f].copy() + temp[segments == f] = colors[i] * 255 + return temp + + +class LimeImageExplainer(object): + """Explains predictions on Image (i.e. matrix) data. + For numerical features, perturb them by sampling from a Normal(0,1) and + doing the inverse operation of mean-centering and scaling, according to the + means and stds in the training data. For categorical features, perturb by + sampling according to the training distribution, and making a binary + feature that is 1 when the value is the same as the instance being + explained.""" + + def __init__(self, kernel_width=.25, kernel=None, verbose=False, + feature_selection='auto', random_state=None): + """Init function. + + Args: + kernel_width: kernel width for the exponential kernel. + If None, defaults to sqrt(number of columns) * 0.75. + kernel: similarity kernel that takes euclidean distances and kernel + width as input and outputs weights in (0,1). If None, defaults to + an exponential kernel. + verbose: if true, print local prediction values from linear model + feature_selection: feature selection method. can be + 'forward_selection', 'lasso_path', 'none' or 'auto'. + See function 'explain_instance_with_data' in lime_base.py for + details on what each of the options does. + random_state: an integer or numpy.RandomState that will be used to + generate random numbers. If None, the random state will be + initialized using the internal numpy seed. + """ + kernel_width = float(kernel_width) + + if kernel is None: + def kernel(d, kernel_width): + return np.sqrt(np.exp(-(d ** 2) / kernel_width ** 2)) + + kernel_fn = partial(kernel, kernel_width=kernel_width) + + self.random_state = check_random_state(random_state) + self.feature_selection = feature_selection + self.base = LimeBase(kernel_fn, verbose, random_state=self.random_state) + + def explain_instance(self, image, classifier_fn, labels=(1,), + hide_color=None, + num_features=100000, num_samples=1000, + batch_size=10, + distance_metric='cosine', + model_regressor=None + ): + """Generates explanations for a prediction. + + First, we generate neighborhood data by randomly perturbing features + from the instance (see __data_inverse). We then learn locally weighted + linear models on this neighborhood data to explain each of the classes + in an interpretable way (see lime_base.py). + + Args: + image: 3 dimension RGB image. If this is only two dimensional, + we will assume it's a grayscale image and call gray2rgb. + classifier_fn: classifier prediction probability function, which + takes a numpy array and outputs prediction probabilities. For + ScikitClassifiers , this is classifier.predict_proba. + labels: iterable with labels to be explained. + hide_color: TODO + num_features: maximum number of features present in explanation + num_samples: size of the neighborhood to learn the linear model + batch_size: TODO + distance_metric: the distance metric to use for weights. + model_regressor: sklearn regressor to use in explanation. Defaults + to Ridge regression in LimeBase. Must have model_regressor.coef_ + and 'sample_weight' as a parameter to model_regressor.fit() + + Returns: + An ImageExplanation object (see lime_image.py) with the corresponding + explanations. + """ + if len(image.shape) == 2: + image = gray2rgb(image) + + try: + segments = quickshift(image, sigma=1) + except ValueError as e: + raise e + + self.segments = segments + + fudged_image = image.copy() + if hide_color is None: + # if no hide_color, use the mean + for x in np.unique(segments): + mx = np.mean(image[segments == x], axis=0) + fudged_image[segments == x] = mx + elif hide_color == 'avg_from_neighbor': + from scipy.spatial.distance import cdist + + n_features = np.unique(segments).shape[0] + regions = regionprops(segments + 1) + centroids = np.zeros((n_features, 2)) + for i, x in enumerate(regions): + centroids[i] = np.array(x.centroid) + + d = cdist(centroids, centroids, 'sqeuclidean') + + for x in np.unique(segments): + # print(np.argmin(d[x])) + a = [image[segments == i] for i in np.argsort(d[x])[1:6]] + mx = np.mean(np.concatenate(a), axis=0) + fudged_image[segments == x] = mx + + else: + fudged_image[:] = 0 + + top = labels + + data, labels = self.data_labels(image, fudged_image, segments, + classifier_fn, num_samples, + batch_size=batch_size) + + distances = sklearn.metrics.pairwise_distances( + data, + data[0].reshape(1, -1), + metric=distance_metric + ).ravel() + + ret_exp = ImageExplanation(image, segments) + for label in top: + (ret_exp.intercept[label], + ret_exp.local_exp[label], + ret_exp.score, ret_exp.local_pred) = self.base.explain_instance_with_data( + data, labels, distances, label, num_features, + model_regressor=model_regressor, + feature_selection=self.feature_selection) + return ret_exp + + def data_labels(self, + image, + fudged_image, + segments, + classifier_fn, + num_samples, + batch_size=10): + """Generates images and predictions in the neighborhood of this image. + + Args: + image: 3d numpy array, the image + fudged_image: 3d numpy array, image to replace original image when + superpixel is turned off + segments: segmentation of the image + classifier_fn: function that takes a list of images and returns a + matrix of prediction probabilities + num_samples: size of the neighborhood to learn the linear model + batch_size: classifier_fn will be called on batches of this size. + + Returns: + A tuple (data, labels), where: + data: dense num_samples * num_superpixels + labels: prediction probabilities matrix + """ + n_features = np.unique(segments).shape[0] + data = self.random_state.randint(0, 2, num_samples * n_features) \ + .reshape((num_samples, n_features)) + labels = [] + data[0, :] = 1 + imgs = [] + for row in data: + temp = copy.deepcopy(image) + zeros = np.where(row == 0)[0] + mask = np.zeros(segments.shape).astype(bool) + for z in zeros: + mask[segments == z] = True + temp[mask] = fudged_image[mask] + imgs.append(temp) + if len(imgs) == batch_size: + preds = classifier_fn(np.array(imgs)) + labels.extend(preds) + imgs = [] + if len(imgs) > 0: + preds = classifier_fn(np.array(imgs)) + labels.extend(preds) + return data, np.array(labels) diff --git a/paddlex/cv/models/explanation/visualize.py b/paddlex/cv/models/explanation/visualize.py new file mode 100644 index 0000000..bd28566 --- /dev/null +++ b/paddlex/cv/models/explanation/visualize.py @@ -0,0 +1,46 @@ +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import cv2 +import copy +import os.path as osp +import numpy as np +from .core.explanation import Explanation + + +def visualize(img_file, + model, + explanation_type='lime', + num_samples=3000, + batch_size=50, + save_dir='./'): + model.arrange_transforms( + transforms=model.test_transforms, mode='test') + tmp_transforms = copy.deepcopy(model.test_transforms) + tmp_transforms.transforms = tmp_transforms.transforms[:-2] + img = tmp_transforms(img_file)[0] + img = np.around(img).astype('uint8') + img = np.expand_dims(img, axis=0) + explaier = None + if explanation_type == 'lime': + explaier = get_lime_explaier(img, model, num_samples=num_samples, batch_size=batch_size) + else: + raise Exception('The {} explanantion method is not supported yet!'.format(explanation_type)) + img_name = osp.splitext(osp.split(img_file)[-1])[0] + explaier.explain(img, save_dir=save_dir) + + +def get_lime_explaier(img, model, num_samples=3000, batch_size=50): + def predict_func(image): + image = image.astype('float32') + model.test_transforms.transforms = model.test_transforms.transforms[-2:] + out = model.explanation_predict(image) + return out[0] + explaier = Explanation('lime', + predict_func, + num_samples=num_samples, + batch_size=batch_size) + return explaier + \ No newline at end of file diff --git a/paddlex/cv/nets/resnet.py b/paddlex/cv/nets/resnet.py index 40c6965..4589d79 100644 --- a/paddlex/cv/nets/resnet.py +++ b/paddlex/cv/nets/resnet.py @@ -120,6 +120,7 @@ class ResNet(object): self.num_classes = num_classes self.lr_mult_list = lr_mult_list self.curr_stage = 0 + self.features = [] def _conv_offset(self, input, @@ -474,7 +475,9 @@ class ResNet(object): size=self.num_classes, param_attr=fluid.param_attr.ParamAttr( initializer=fluid.initializer.Uniform(-stdv, stdv))) - return out + self.features.append(out) +# out.persistable=True + return out, self.features return OrderedDict([('res{}_sum'.format(self.feature_maps[idx]), feat) for idx, feat in enumerate(res_endpoints)]) -- GitLab