diff --git a/deploy/cpp/src/paddlex.cpp b/deploy/cpp/src/paddlex.cpp index b3e292c23e781d675ad7e23512fe96672d4b8121..90a4a4452b9e5f3eba1c0b4c7ab88f5b91e03971 100644 --- a/deploy/cpp/src/paddlex.cpp +++ b/deploy/cpp/src/paddlex.cpp @@ -98,7 +98,7 @@ bool Model::load_config(const std::string& model_dir) { bool Model::preprocess(const cv::Mat& input_im, ImageBlob* blob) { cv::Mat im = input_im.clone(); - if (!transforms_.Run(&im, &inputs_)) { + if (!transforms_.Run(&im, blob)) { return false; } return true; diff --git a/docs/FAQ.md b/docs/FAQ.md index 8da14f32e428f868f637a395223855aa66371bbf..b120ebd10ed791c65c3f65e611c5b45da2a9211f 100755 --- a/docs/FAQ.md +++ b/docs/FAQ.md @@ -60,3 +60,9 @@ ## 11. 每次训练新的模型,都需要重新下载预训练模型,怎样可以下载一次就搞定 > 1.可以按照9的方式来解决这个问题 > 2.每次训练前都设定`paddlex.pretrain_dir`路径,如设定`paddlex.pretrain_dir='/usrname/paddlex`,如此下载完的预训练模型会存放至`/usrname/paddlex`目录下,而已经下载在该目录的模型也不会再次重复下载 + +## 12. 程序启动时提示"Failed to execute script PaddleX",如何解决? +> 1. 请检查目标机器上PaddleX程序所在路径是否包含中文。目前暂不支持中文路径,请尝试将程序移动到英文目录。 +> 2. 如果您的系统是Windows 7或者Windows Server 2012时,原因是缺少MFPlat.DLL/MF.dll/MFReadWrite.dll等OpenCV依赖的DLL,请按如下方式安装桌面体验:通过“我的电脑”-->“属性”-->"管理"打开服务器管理器,点击右上角“管理”选择“添加角色和功能”。点击“服务器选择”-->“功能”,拖动滚动条到最下端,点开“用户界面和基础结构”,勾选“桌面体验”后点击“安装”,等安装完成尝试再次运行PaddleX。 +> 3. 请检查目标机器上是否有其他的PaddleX程序或者进程在运行中,如有请退出或者重启机器看是否解决 +> 4. 请确认运行程序的用户是否有管理员权限,如非管理员权限用户请尝试使用管理员运行看是否成功 \ No newline at end of file diff --git a/paddlex/__init__.py b/paddlex/__init__.py index d1656161a0d764c0a7fbd125f246d0e43125bcda..b80363f2e6adfdbd6ce712cfec486540753abbb7 100644 --- a/paddlex/__init__.py +++ b/paddlex/__init__.py @@ -53,4 +53,4 @@ log_level = 2 from . import interpret -__version__ = '1.0.5' +__version__ = '1.0.6' diff --git a/paddlex/cls.py b/paddlex/cls.py index 0dce289d7ee77c9559a4fce2104cca8786b81f52..90c5eefce512c966a04975ebfe6457613012c872 100644 --- a/paddlex/cls.py +++ b/paddlex/cls.py @@ -37,5 +37,6 @@ DenseNet161 = cv.models.DenseNet161 DenseNet201 = cv.models.DenseNet201 ShuffleNetV2 = cv.models.ShuffleNetV2 HRNet_W18 = cv.models.HRNet_W18 +AlexNet = cv.models.AlexNet transforms = cv.transforms.cls_transforms diff --git a/paddlex/cv/models/__init__.py b/paddlex/cv/models/__init__.py index 22485f2701e1e06c6e050c0c15238c32ed4a6a02..622878933c12f1934960eb42aed1f992e7164708 100644 --- a/paddlex/cv/models/__init__.py +++ b/paddlex/cv/models/__init__.py @@ -35,6 +35,7 @@ from .classifier import DenseNet161 from .classifier import DenseNet201 from .classifier import ShuffleNetV2 from .classifier import HRNet_W18 +from .classifier import AlexNet from .base import BaseAPI from .yolo_v3 import YOLOv3 from .faster_rcnn import FasterRCNN diff --git a/paddlex/cv/models/base.py b/paddlex/cv/models/base.py index ac8989ff83980bf45d7705985353435e6e19a9e6..14db42b8aed39674f2911f3fe5ee472435b8da34 100644 --- a/paddlex/cv/models/base.py +++ b/paddlex/cv/models/base.py @@ -221,8 +221,8 @@ class BaseAPI: logging.info( "Load pretrain weights from {}.".format(pretrain_weights), use_color=True) - paddlex.utils.utils.load_pretrain_weights( - self.exe, self.train_prog, pretrain_weights, fuse_bn) + paddlex.utils.utils.load_pretrain_weights(self.exe, self.train_prog, + pretrain_weights, fuse_bn) # 进行裁剪 if sensitivities_file is not None: import paddleslim @@ -262,6 +262,7 @@ class BaseAPI: info['_Attributes']['num_classes'] = self.num_classes info['_Attributes']['labels'] = self.labels + info['_Attributes']['fixed_input_shape'] = self.fixed_input_shape try: primary_metric_key = list(self.eval_metrics.keys())[0] primary_metric_value = float(self.eval_metrics[primary_metric_key]) @@ -325,9 +326,7 @@ class BaseAPI: logging.info("Model saved in {}.".format(save_dir)) def export_inference_model(self, save_dir): - test_input_names = [ - var.name for var in list(self.test_inputs.values()) - ] + test_input_names = [var.name for var in list(self.test_inputs.values())] test_outputs = list(self.test_outputs.values()) if self.__class__.__name__ == 'MaskRCNN': from paddlex.utils.save import save_mask_inference_model @@ -364,8 +363,7 @@ class BaseAPI: # 模型保存成功的标志 open(osp.join(save_dir, '.success'), 'w').close() - logging.info("Model for inference deploy saved in {}.".format( - save_dir)) + logging.info("Model for inference deploy saved in {}.".format(save_dir)) def train_loop(self, num_epochs, @@ -489,13 +487,11 @@ class BaseAPI: eta = ((num_epochs - i) * total_num_steps - step - 1 ) * avg_step_time if time_eval_one_epoch is not None: - eval_eta = ( - total_eval_times - i // save_interval_epochs - ) * time_eval_one_epoch + eval_eta = (total_eval_times - i // save_interval_epochs + ) * time_eval_one_epoch else: - eval_eta = ( - total_eval_times - i // save_interval_epochs - ) * total_num_steps_eval * avg_step_time + eval_eta = (total_eval_times - i // save_interval_epochs + ) * total_num_steps_eval * avg_step_time eta_str = seconds_to_hms(eta + eval_eta) logging.info( diff --git a/paddlex/cv/models/classifier.py b/paddlex/cv/models/classifier.py index 7248b211cf72dfeb1d7f2a750dbc4549d34553d0..48a0d17604e7af59377af49967fb8c527f094b09 100644 --- a/paddlex/cv/models/classifier.py +++ b/paddlex/cv/models/classifier.py @@ -48,6 +48,8 @@ class BaseClassifier(BaseAPI): self.fixed_input_shape = None def build_net(self, mode='train'): + if self.__class__.__name__ == "AlexNet": + assert self.fixed_input_shape is not None, "In AlexNet, input_shape should be defined, e.g. model = paddlex.cls.AlexNet(num_classes=1000, input_shape=[224, 224])" if self.fixed_input_shape is not None: input_shape = [ None, 3, self.fixed_input_shape[1], self.fixed_input_shape[0] @@ -427,3 +429,10 @@ class HRNet_W18(BaseClassifier): def __init__(self, num_classes=1000): super(HRNet_W18, self).__init__( model_name='HRNet_W18', num_classes=num_classes) + + +class AlexNet(BaseClassifier): + def __init__(self, num_classes=1000, input_shape=None): + super(AlexNet, self).__init__( + model_name='AlexNet', num_classes=num_classes) + self.fixed_input_shape = input_shape diff --git a/paddlex/cv/models/load_model.py b/paddlex/cv/models/load_model.py index 738f4ff00452d278b3988d9303bb15b0d8885979..5138445afcd8c8fd8f4d0d396703b3280d4e3e51 100644 --- a/paddlex/cv/models/load_model.py +++ b/paddlex/cv/models/load_model.py @@ -41,7 +41,16 @@ def load_model(model_dir, fixed_input_shape=None): if 'model_name' in info['_init_params']: del info['_init_params']['model_name'] model = getattr(paddlex.cv.models, info['Model'])(**info['_init_params']) + model.fixed_input_shape = fixed_input_shape + if '_Attributes' in info: + if 'fixed_input_shape' in info['_Attributes']: + fixed_input_shape = info['_Attributes']['fixed_input_shape'] + if fixed_input_shape is not None: + logging.info("Model already has fixed_input_shape with {}". + format(fixed_input_shape)) + model.fixed_input_shape = fixed_input_shape + if status == "Normal" or \ status == "Prune" or status == "fluid.save": startup_prog = fluid.Program() @@ -88,8 +97,8 @@ def load_model(model_dir, fixed_input_shape=None): model.model_type, info['Transforms'], info['BatchTransforms']) model.eval_transforms = copy.deepcopy(model.test_transforms) else: - model.test_transforms = build_transforms( - model.model_type, info['Transforms'], to_rgb) + model.test_transforms = build_transforms(model.model_type, + info['Transforms'], to_rgb) model.eval_transforms = copy.deepcopy(model.test_transforms) if '_Attributes' in info: diff --git a/paddlex/cv/models/utils/pretrain_weights.py b/paddlex/cv/models/utils/pretrain_weights.py index 3abbdd93d80efd5eb41ead32ac321d758d080104..a7bd78c4bb3e6d1c9d7ffe714aac721873e1ab38 100644 --- a/paddlex/cv/models/utils/pretrain_weights.py +++ b/paddlex/cv/models/utils/pretrain_weights.py @@ -70,6 +70,8 @@ image_pretrain = { 'https://paddle-imagenet-models-name.bj.bcebos.com/HRNet_W60_C_pretrained.tar', 'HRNet_W64': 'https://paddle-imagenet-models-name.bj.bcebos.com/HRNet_W64_C_pretrained.tar', + 'AlexNet': + 'http://paddle-imagenet-models-name.bj.bcebos.com/AlexNet_pretrained.tar' } coco_pretrain = { @@ -99,10 +101,12 @@ def get_pretrain_weights(flag, model_type, backbone, save_dir): backbone = 'DetResNet50' assert backbone in image_pretrain, "There is not ImageNet pretrain weights for {}, you may try COCO.".format( backbone) - # url = image_pretrain[backbone] - # fname = osp.split(url)[-1].split('.')[0] - # paddlex.utils.download_and_decompress(url, path=new_save_dir) - # return osp.join(new_save_dir, fname) + + # if backbone == 'AlexNet': + # url = image_pretrain[backbone] + # fname = osp.split(url)[-1].split('.')[0] + # paddlex.utils.download_and_decompress(url, path=new_save_dir) + # return osp.join(new_save_dir, fname) try: hub.download(backbone, save_path=new_save_dir) except Exception as e: diff --git a/paddlex/cv/nets/__init__.py b/paddlex/cv/nets/__init__.py index b1441c59395c2f7788dbab937ab5ad629d4aa940..6e5102a26c9a573db25ad63984dad41c633c987d 100644 --- a/paddlex/cv/nets/__init__.py +++ b/paddlex/cv/nets/__init__.py @@ -24,6 +24,7 @@ from .xception import Xception from .densenet import DenseNet from .shufflenet_v2 import ShuffleNetV2 from .hrnet import HRNet +from .alexnet import AlexNet def resnet18(input, num_classes=1000): @@ -153,3 +154,8 @@ def shufflenetv2(input, num_classes=1000): def hrnet_w18(input, num_classes=1000): model = HRNet(width=18, num_classes=num_classes) return model(input) + + +def alexnet(input, num_classes=1000): + model = AlexNet(num_classes=num_classes) + return model(input) diff --git a/paddlex/cv/nets/alexnet.py b/paddlex/cv/nets/alexnet.py new file mode 100644 index 0000000000000000000000000000000000000000..6770f437d982428cd8d5ed7edb44e00915754139 --- /dev/null +++ b/paddlex/cv/nets/alexnet.py @@ -0,0 +1,170 @@ +#copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve. +# +#Licensed under the Apache License, Version 2.0 (the "License"); +#you may not use this file except in compliance with the License. +#You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +#Unless required by applicable law or agreed to in writing, software +#distributed under the License is distributed on an "AS IS" BASIS, +#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +#See the License for the specific language governing permissions and +#limitations under the License. + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import math + +import paddle +import paddle.fluid as fluid + + +class AlexNet(): + def __init__(self, num_classes=1000): + assert num_classes is not None, "In AlextNet, num_classes cannot be None" + self.num_classes = num_classes + + def __call__(self, input): + stdv = 1.0 / math.sqrt(input.shape[1] * 11 * 11) + layer_name = [ + "conv1", "conv2", "conv3", "conv4", "conv5", "fc6", "fc7", "fc8" + ] + conv1 = fluid.layers.conv2d( + input=input, + num_filters=64, + filter_size=11, + stride=4, + padding=2, + groups=1, + act='relu', + bias_attr=fluid.param_attr.ParamAttr( + initializer=fluid.initializer.Uniform(-stdv, stdv), + name=layer_name[0] + "_offset"), + param_attr=fluid.param_attr.ParamAttr( + initializer=fluid.initializer.Uniform(-stdv, stdv), + name=layer_name[0] + "_weights")) + pool1 = fluid.layers.pool2d( + input=conv1, + pool_size=3, + pool_stride=2, + pool_padding=0, + pool_type='max') + + stdv = 1.0 / math.sqrt(pool1.shape[1] * 5 * 5) + conv2 = fluid.layers.conv2d( + input=pool1, + num_filters=192, + filter_size=5, + stride=1, + padding=2, + groups=1, + act='relu', + bias_attr=fluid.param_attr.ParamAttr( + initializer=fluid.initializer.Uniform(-stdv, stdv), + name=layer_name[1] + "_offset"), + param_attr=fluid.param_attr.ParamAttr( + initializer=fluid.initializer.Uniform(-stdv, stdv), + name=layer_name[1] + "_weights")) + pool2 = fluid.layers.pool2d( + input=conv2, + pool_size=3, + pool_stride=2, + pool_padding=0, + pool_type='max') + + stdv = 1.0 / math.sqrt(pool2.shape[1] * 3 * 3) + conv3 = fluid.layers.conv2d( + input=pool2, + num_filters=384, + filter_size=3, + stride=1, + padding=1, + groups=1, + act='relu', + bias_attr=fluid.param_attr.ParamAttr( + initializer=fluid.initializer.Uniform(-stdv, stdv), + name=layer_name[2] + "_offset"), + param_attr=fluid.param_attr.ParamAttr( + initializer=fluid.initializer.Uniform(-stdv, stdv), + name=layer_name[2] + "_weights")) + + stdv = 1.0 / math.sqrt(conv3.shape[1] * 3 * 3) + conv4 = fluid.layers.conv2d( + input=conv3, + num_filters=256, + filter_size=3, + stride=1, + padding=1, + groups=1, + act='relu', + bias_attr=fluid.param_attr.ParamAttr( + initializer=fluid.initializer.Uniform(-stdv, stdv), + name=layer_name[3] + "_offset"), + param_attr=fluid.param_attr.ParamAttr( + initializer=fluid.initializer.Uniform(-stdv, stdv), + name=layer_name[3] + "_weights")) + + stdv = 1.0 / math.sqrt(conv4.shape[1] * 3 * 3) + conv5 = fluid.layers.conv2d( + input=conv4, + num_filters=256, + filter_size=3, + stride=1, + padding=1, + groups=1, + act='relu', + bias_attr=fluid.param_attr.ParamAttr( + initializer=fluid.initializer.Uniform(-stdv, stdv), + name=layer_name[4] + "_offset"), + param_attr=fluid.param_attr.ParamAttr( + initializer=fluid.initializer.Uniform(-stdv, stdv), + name=layer_name[4] + "_weights")) + pool5 = fluid.layers.pool2d( + input=conv5, + pool_size=3, + pool_stride=2, + pool_padding=0, + pool_type='max') + + drop6 = fluid.layers.dropout(x=pool5, dropout_prob=0.5) + stdv = 1.0 / math.sqrt(drop6.shape[1] * drop6.shape[2] * + drop6.shape[3] * 1.0) + + fc6 = fluid.layers.fc( + input=drop6, + size=4096, + act='relu', + bias_attr=fluid.param_attr.ParamAttr( + initializer=fluid.initializer.Uniform(-stdv, stdv), + name=layer_name[5] + "_offset"), + param_attr=fluid.param_attr.ParamAttr( + initializer=fluid.initializer.Uniform(-stdv, stdv), + name=layer_name[5] + "_weights")) + drop7 = fluid.layers.dropout(x=fc6, dropout_prob=0.5) + stdv = 1.0 / math.sqrt(drop7.shape[1] * 1.0) + + fc7 = fluid.layers.fc( + input=drop7, + size=4096, + act='relu', + bias_attr=fluid.param_attr.ParamAttr( + initializer=fluid.initializer.Uniform(-stdv, stdv), + name=layer_name[6] + "_offset"), + param_attr=fluid.param_attr.ParamAttr( + initializer=fluid.initializer.Uniform(-stdv, stdv), + name=layer_name[6] + "_weights")) + + stdv = 1.0 / math.sqrt(fc7.shape[1] * 1.0) + out = fluid.layers.fc( + input=fc7, + size=self.num_classes, + bias_attr=fluid.param_attr.ParamAttr( + initializer=fluid.initializer.Uniform(-stdv, stdv), + name=layer_name[7] + "_offset"), + param_attr=fluid.param_attr.ParamAttr( + initializer=fluid.initializer.Uniform(-stdv, stdv), + name=layer_name[7] + "_weights")) + return out diff --git a/paddlex/cv/nets/hrnet.py b/paddlex/cv/nets/hrnet.py index 19f9cb336bce66a7dc68d65e316440adf46857e4..a7934d385d4a53fd936410e37d3896fe21cb17ee 100644 --- a/paddlex/cv/nets/hrnet.py +++ b/paddlex/cv/nets/hrnet.py @@ -71,7 +71,7 @@ class HRNet(object): self.end_points = [] return - def net(self, input, class_dim=1000): + def net(self, input): width = self.width channels_2, channels_3, channels_4 = self.channels[width] num_modules_2, num_modules_3, num_modules_4 = 1, 4, 3 @@ -125,7 +125,7 @@ class HRNet(object): stdv = 1.0 / math.sqrt(pool.shape[1] * 1.0) out = fluid.layers.fc( input=pool, - size=class_dim, + size=self.num_classes, param_attr=ParamAttr( name='fc_weights', initializer=fluid.initializer.Uniform(-stdv, stdv)), diff --git a/paddlex/interpret/core/_session_preparation.py b/paddlex/interpret/core/_session_preparation.py index f75fa2464fe43969ec76c557c43344c0f2ae877f..08eda36fb873c4e5824f8131aca77c7cdc352c22 100644 --- a/paddlex/interpret/core/_session_preparation.py +++ b/paddlex/interpret/core/_session_preparation.py @@ -20,6 +20,7 @@ import numpy as np from paddle.fluid.param_attr import ParamAttr from paddlex.interpret.as_data_reader.readers import preprocess_image + def gen_user_home(): if "HOME" in os.environ: home_path = os.environ["HOME"] @@ -34,10 +35,20 @@ def paddle_get_fc_weights(var_name="fc_0.w_0"): def paddle_resize(extracted_features, outsize): - resized_features = fluid.layers.resize_bilinear(extracted_features, outsize) + resized_features = fluid.layers.resize_bilinear(extracted_features, + outsize) return resized_features +def get_precomputed_normlime_weights(): + root_path = gen_user_home() + root_path = osp.join(root_path, '.paddlex') + h_pre_models = osp.join(root_path, "pre_models") + normlime_weights_file = osp.join( + h_pre_models, "normlime_weights_imagenet_resnet50vc.npy") + return np.load(normlime_weights_file, allow_pickle=True).item() + + def compute_features_for_kmeans(data_content): root_path = gen_user_home() root_path = osp.join(root_path, '.paddlex') @@ -47,6 +58,7 @@ def compute_features_for_kmeans(data_content): os.makedirs(root_path) url = "https://bj.bcebos.com/paddlex/interpret/pre_models.tar.gz" pdx.utils.download_and_decompress(url, path=root_path) + def conv_bn_layer(input, num_filters, filter_size, @@ -55,7 +67,7 @@ def compute_features_for_kmeans(data_content): act=None, name=None, is_test=True, - global_name=''): + global_name='for_kmeans_'): conv = fluid.layers.conv2d( input=input, num_filters=num_filters, @@ -79,14 +91,14 @@ def compute_features_for_kmeans(data_content): bias_attr=ParamAttr(global_name + bn_name + '_offset'), moving_mean_name=global_name + bn_name + '_mean', moving_variance_name=global_name + bn_name + '_variance', - use_global_stats=is_test - ) + use_global_stats=is_test) startup_prog = fluid.default_startup_program().clone(for_test=True) prog = fluid.Program() with fluid.program_guard(prog, startup_prog): with fluid.unique_name.guard(): - image_op = fluid.data(name='image', shape=[None, 3, 224, 224], dtype='float32') + image_op = fluid.data( + name='image', shape=[None, 3, 224, 224], dtype='float32') conv = conv_bn_layer( input=image_op, @@ -110,7 +122,8 @@ def compute_features_for_kmeans(data_content): act='relu', name='conv1_3') extracted_features = conv - resized_features = fluid.layers.resize_bilinear(extracted_features, image_op.shape[2:]) + resized_features = fluid.layers.resize_bilinear(extracted_features, + image_op.shape[2:]) gpu_id = int(os.environ.get('FLAGS_selected_gpus', 0)) place = fluid.CUDAPlace(gpu_id) @@ -119,7 +132,10 @@ def compute_features_for_kmeans(data_content): exe.run(startup_prog) fluid.io.load_persistables(exe, h_pre_models, prog) - images = preprocess_image(data_content) # transpose to [N, 3, H, W], scaled to [0.0, 1.0] - result = exe.run(prog, fetch_list=[resized_features], feed={'image': images}) + images = preprocess_image( + data_content) # transpose to [N, 3, H, W], scaled to [0.0, 1.0] + result = exe.run(prog, + fetch_list=[resized_features], + feed={'image': images}) return result[0][0] diff --git a/paddlex/interpret/core/interpretation.py b/paddlex/interpret/core/interpretation.py index 72d8c238a2e1817098eefcae18b0a3b56aedeb6b..5b1a5e45b5804acc005a407893c9ceeea8261863 100644 --- a/paddlex/interpret/core/interpretation.py +++ b/paddlex/interpret/core/interpretation.py @@ -20,12 +20,10 @@ class Interpretation(object): """ Base class for all interpretation algorithms. """ - def __init__(self, interpretation_algorithm_name, predict_fn, label_names, **kwargs): - supported_algorithms = { - 'cam': CAM, - 'lime': LIME, - 'normlime': NormLIME - } + + def __init__(self, interpretation_algorithm_name, predict_fn, label_names, + **kwargs): + supported_algorithms = {'cam': CAM, 'lime': LIME, 'normlime': NormLIME} self.algorithm_name = interpretation_algorithm_name.lower() assert self.algorithm_name in supported_algorithms.keys() @@ -33,19 +31,17 @@ class Interpretation(object): # initialization for the interpretation algorithm. self.algorithm = supported_algorithms[self.algorithm_name]( - self.predict_fn, label_names, **kwargs - ) + self.predict_fn, label_names, **kwargs) - def interpret(self, data_, visualization=True, save_to_disk=True, save_dir='./tmp'): + def interpret(self, data_, visualization=True, save_dir='./'): """ Args: data_: data_ can be a path or numpy.ndarray. visualization: whether to show using matplotlib. - save_to_disk: whether to save the figure in local disk. save_dir: dir to save figure if save_to_disk is True. Returns: """ - return self.algorithm.interpret(data_, visualization, save_to_disk, save_dir) + return self.algorithm.interpret(data_, visualization, save_dir) diff --git a/paddlex/interpret/core/interpretation_algorithms.py b/paddlex/interpret/core/interpretation_algorithms.py index afcea8d2d92531590a1aef986014c5bfd792ea5e..a54f46632567f54d934af69dfde64cacea7c5622 100644 --- a/paddlex/interpret/core/interpretation_algorithms.py +++ b/paddlex/interpret/core/interpretation_algorithms.py @@ -23,7 +23,6 @@ from .normlime_base import combine_normlime_and_lime, get_feature_for_kmeans, lo from paddlex.interpret.as_data_reader.readers import read_image import paddlex.utils.logging as logging - import cv2 @@ -66,25 +65,27 @@ class CAM(object): fc_weights = paddle_get_fc_weights() feature_maps = result[1] - + l = pred_label[0] ln = l if self.label_names is not None: ln = self.label_names[l] prob_str = "%.3f" % (probability[pred_label[0]]) - logging.info("predicted result: {} with probability {}.".format(ln, prob_str)) + logging.info("predicted result: {} with probability {}.".format( + ln, prob_str)) return feature_maps, fc_weights - def interpret(self, data_, visualization=True, save_to_disk=True, save_outdir=None): + def interpret(self, data_, visualization=True, save_outdir=None): feature_maps, fc_weights = self.preparation_cam(data_) - cam = get_cam(self.image, feature_maps, fc_weights, self.predicted_label) + cam = get_cam(self.image, feature_maps, fc_weights, + self.predicted_label) - if visualization or save_to_disk: + if visualization or save_outdir is not None: import matplotlib.pyplot as plt from skimage.segmentation import mark_boundaries l = self.labels[0] - ln = l + ln = l if self.label_names is not None: ln = self.label_names[l] @@ -93,7 +94,8 @@ class CAM(object): ncols = 2 plt.close() - f, axes = plt.subplots(nrows, ncols, figsize=(psize * ncols, psize * nrows)) + f, axes = plt.subplots( + nrows, ncols, figsize=(psize * ncols, psize * nrows)) for ax in axes.ravel(): ax.axis("off") axes = axes.ravel() @@ -104,8 +106,7 @@ class CAM(object): axes[1].imshow(cam) axes[1].set_title("CAM") - if save_to_disk and save_outdir is not None: - os.makedirs(save_outdir, exist_ok=True) + if save_outdir is not None: save_fig(data_, save_outdir, 'cam') if visualization: @@ -115,7 +116,11 @@ class CAM(object): class LIME(object): - def __init__(self, predict_fn, label_names, num_samples=3000, batch_size=50): + def __init__(self, + predict_fn, + label_names, + num_samples=3000, + batch_size=50): """ LIME wrapper. See lime_base.py for the detailed LIME implementation. Args: @@ -154,31 +159,37 @@ class LIME(object): self.predicted_probability = probability[pred_label[0]] self.image = image_show[0] self.labels = pred_label - + l = pred_label[0] ln = l if self.label_names is not None: ln = self.label_names[l] - + prob_str = "%.3f" % (probability[pred_label[0]]) - logging.info("predicted result: {} with probability {}.".format(ln, prob_str)) + logging.info("predicted result: {} with probability {}.".format( + ln, prob_str)) end = time.time() algo = lime_base.LimeImageInterpreter() - interpreter = algo.interpret_instance(self.image, self.predict_fn, self.labels, 0, - num_samples=self.num_samples, batch_size=self.batch_size) + interpreter = algo.interpret_instance( + self.image, + self.predict_fn, + self.labels, + 0, + num_samples=self.num_samples, + batch_size=self.batch_size) self.lime_interpreter = interpreter logging.info('lime time: ' + str(time.time() - end) + 's.') - def interpret(self, data_, visualization=True, save_to_disk=True, save_outdir=None): + def interpret(self, data_, visualization=True, save_outdir=None): if self.lime_interpreter is None: self.preparation_lime(data_) - if visualization or save_to_disk: + if visualization or save_outdir is not None: import matplotlib.pyplot as plt from skimage.segmentation import mark_boundaries l = self.labels[0] - ln = l + ln = l if self.label_names is not None: ln = self.label_names[l] @@ -188,7 +199,8 @@ class LIME(object): ncols = len(weights_choices) plt.close() - f, axes = plt.subplots(nrows, ncols, figsize=(psize * ncols, psize * nrows)) + f, axes = plt.subplots( + nrows, ncols, figsize=(psize * ncols, psize * nrows)) for ax in axes.ravel(): ax.axis("off") axes = axes.ravel() @@ -196,20 +208,24 @@ class LIME(object): prob_str = "{%.3f}" % (self.predicted_probability) axes[0].set_title("label {}, proba: {}".format(ln, prob_str)) - axes[1].imshow(mark_boundaries(self.image, self.lime_interpreter.segments)) + axes[1].imshow( + mark_boundaries(self.image, self.lime_interpreter.segments)) axes[1].set_title("superpixel segmentation") # LIME visualization for i, w in enumerate(weights_choices): - num_to_show = auto_choose_num_features_to_show(self.lime_interpreter, l, w) + num_to_show = auto_choose_num_features_to_show( + self.lime_interpreter, l, w) temp, mask = self.lime_interpreter.get_image_and_mask( - l, positive_only=False, hide_rest=False, num_features=num_to_show - ) + l, + positive_only=True, + hide_rest=False, + num_features=num_to_show) axes[ncols + i].imshow(mark_boundaries(temp, mask)) - axes[ncols + i].set_title("label {}, first {} superpixels".format(ln, num_to_show)) + axes[ncols + i].set_title( + "label {}, first {} superpixels".format(ln, num_to_show)) - if save_to_disk and save_outdir is not None: - os.makedirs(save_outdir, exist_ok=True) + if save_outdir is not None: save_fig(data_, save_outdir, 'lime', self.num_samples) if visualization: @@ -218,9 +234,196 @@ class LIME(object): return +class NormLIMEStandard(object): + def __init__(self, + predict_fn, + label_names, + num_samples=3000, + batch_size=50, + kmeans_model_for_normlime=None, + normlime_weights=None): + root_path = gen_user_home() + root_path = osp.join(root_path, '.paddlex') + h_pre_models = osp.join(root_path, "pre_models") + if not osp.exists(h_pre_models): + if not osp.exists(root_path): + os.makedirs(root_path) + url = "https://bj.bcebos.com/paddlex/interpret/pre_models.tar.gz" + pdx.utils.download_and_decompress(url, path=root_path) + h_pre_models_kmeans = osp.join(h_pre_models, "kmeans_model.pkl") + if kmeans_model_for_normlime is None: + try: + self.kmeans_model = load_kmeans_model(h_pre_models_kmeans) + except: + raise ValueError( + "NormLIME needs the KMeans model, where we provided a default one in " + "pre_models/kmeans_model.pkl.") + else: + logging.debug("Warning: It is *strongly* suggested to use the \ + default KMeans model in pre_models/kmeans_model.pkl. \ + Use another one will change the final result.") + self.kmeans_model = load_kmeans_model(kmeans_model_for_normlime) + + self.num_samples = num_samples + self.batch_size = batch_size + + try: + self.normlime_weights = np.load( + normlime_weights, allow_pickle=True).item() + except: + self.normlime_weights = None + logging.debug( + "Warning: not find the correct precomputed Normlime result.") + + self.predict_fn = predict_fn + + self.labels = None + self.image = None + self.label_names = label_names + + def predict_cluster_labels(self, feature_map, segments): + X = get_feature_for_kmeans(feature_map, segments) + try: + cluster_labels = self.kmeans_model.predict(X) + except AttributeError: + from sklearn.metrics import pairwise_distances_argmin_min + cluster_labels, _ = pairwise_distances_argmin_min( + X, self.kmeans_model.cluster_centers_) + return cluster_labels + + def predict_using_normlime_weights(self, pred_labels, + predicted_cluster_labels): + # global weights + g_weights = {y: [] for y in pred_labels} + for y in pred_labels: + cluster_weights_y = self.normlime_weights.get(y, {}) + g_weights[y] = [(i, cluster_weights_y.get(k, 0.0)) + for i, k in enumerate(predicted_cluster_labels)] + + g_weights[y] = sorted( + g_weights[y], key=lambda x: np.abs(x[1]), reverse=True) + + return g_weights + + def preparation_normlime(self, data_): + self._lime = LIME(self.predict_fn, self.label_names, self.num_samples, + self.batch_size) + self._lime.preparation_lime(data_) + + image_show = read_image(data_) + + self.predicted_label = self._lime.predicted_label + self.predicted_probability = self._lime.predicted_probability + self.image = image_show[0] + self.labels = self._lime.labels + logging.info('performing NormLIME operations ...') + + cluster_labels = self.predict_cluster_labels( + compute_features_for_kmeans(image_show).transpose((1, 2, 0)), + self._lime.lime_interpreter.segments) + + g_weights = self.predict_using_normlime_weights(self.labels, + cluster_labels) + + return g_weights + + def interpret(self, data_, visualization=True, save_outdir=None): + if self.normlime_weights is None: + raise ValueError( + "Not find the correct precomputed NormLIME result. \n" + "\t Try to call compute_normlime_weights() first or load the correct path." + ) + + g_weights = self.preparation_normlime(data_) + lime_weights = self._lime.lime_interpreter.local_weights + + if visualization or save_outdir is not None: + import matplotlib.pyplot as plt + from skimage.segmentation import mark_boundaries + l = self.labels[0] + ln = l + if self.label_names is not None: + ln = self.label_names[l] + + psize = 5 + nrows = 4 + weights_choices = [0.6, 0.7, 0.75, 0.8, 0.85] + nums_to_show = [] + ncols = len(weights_choices) + + plt.close() + f, axes = plt.subplots( + nrows, ncols, figsize=(psize * ncols, psize * nrows)) + for ax in axes.ravel(): + ax.axis("off") + + axes = axes.ravel() + axes[0].imshow(self.image) + prob_str = "{%.3f}" % (self.predicted_probability) + axes[0].set_title("label {}, proba: {}".format(ln, prob_str)) + + axes[1].imshow( + mark_boundaries(self.image, + self._lime.lime_interpreter.segments)) + axes[1].set_title("superpixel segmentation") + + # LIME visualization + for i, w in enumerate(weights_choices): + num_to_show = auto_choose_num_features_to_show( + self._lime.lime_interpreter, l, w) + nums_to_show.append(num_to_show) + temp, mask = self._lime.lime_interpreter.get_image_and_mask( + l, + positive_only=False, + hide_rest=False, + num_features=num_to_show) + axes[ncols + i].imshow(mark_boundaries(temp, mask)) + axes[ncols + i].set_title("LIME: first {} superpixels".format( + num_to_show)) + + # NormLIME visualization + self._lime.lime_interpreter.local_weights = g_weights + for i, num_to_show in enumerate(nums_to_show): + temp, mask = self._lime.lime_interpreter.get_image_and_mask( + l, + positive_only=False, + hide_rest=False, + num_features=num_to_show) + axes[ncols * 2 + i].imshow(mark_boundaries(temp, mask)) + axes[ncols * 2 + i].set_title( + "NormLIME: first {} superpixels".format(num_to_show)) + + # NormLIME*LIME visualization + combined_weights = combine_normlime_and_lime(lime_weights, + g_weights) + self._lime.lime_interpreter.local_weights = combined_weights + for i, num_to_show in enumerate(nums_to_show): + temp, mask = self._lime.lime_interpreter.get_image_and_mask( + l, + positive_only=False, + hide_rest=False, + num_features=num_to_show) + axes[ncols * 3 + i].imshow(mark_boundaries(temp, mask)) + axes[ncols * 3 + i].set_title( + "Combined: first {} superpixels".format(num_to_show)) + + self._lime.lime_interpreter.local_weights = lime_weights + + if save_outdir is not None: + save_fig(data_, save_outdir, 'normlime', self.num_samples) + + if visualization: + plt.show() + + class NormLIME(object): - def __init__(self, predict_fn, label_names, num_samples=3000, batch_size=50, - kmeans_model_for_normlime=None, normlime_weights=None): + def __init__(self, + predict_fn, + label_names, + num_samples=3000, + batch_size=50, + kmeans_model_for_normlime=None, + normlime_weights=None): root_path = gen_user_home() root_path = osp.join(root_path, '.paddlex') h_pre_models = osp.join(root_path, "pre_models") @@ -234,8 +437,9 @@ class NormLIME(object): try: self.kmeans_model = load_kmeans_model(h_pre_models_kmeans) except: - raise ValueError("NormLIME needs the KMeans model, where we provided a default one in " - "pre_models/kmeans_model.pkl.") + raise ValueError( + "NormLIME needs the KMeans model, where we provided a default one in " + "pre_models/kmeans_model.pkl.") else: logging.debug("Warning: It is *strongly* suggested to use the \ default KMeans model in pre_models/kmeans_model.pkl. \ @@ -246,10 +450,12 @@ class NormLIME(object): self.batch_size = batch_size try: - self.normlime_weights = np.load(normlime_weights, allow_pickle=True).item() + self.normlime_weights = np.load( + normlime_weights, allow_pickle=True).item() except: self.normlime_weights = None - logging.debug("Warning: not find the correct precomputed Normlime result.") + logging.debug( + "Warning: not find the correct precomputed Normlime result.") self.predict_fn = predict_fn @@ -263,30 +469,27 @@ class NormLIME(object): cluster_labels = self.kmeans_model.predict(X) except AttributeError: from sklearn.metrics import pairwise_distances_argmin_min - cluster_labels, _ = pairwise_distances_argmin_min(X, self.kmeans_model.cluster_centers_) + cluster_labels, _ = pairwise_distances_argmin_min( + X, self.kmeans_model.cluster_centers_) return cluster_labels - def predict_using_normlime_weights(self, pred_labels, predicted_cluster_labels): + def predict_using_normlime_weights(self, pred_labels, + predicted_cluster_labels): # global weights g_weights = {y: [] for y in pred_labels} for y in pred_labels: cluster_weights_y = self.normlime_weights.get(y, {}) - g_weights[y] = [ - (i, cluster_weights_y.get(k, 0.0)) for i, k in enumerate(predicted_cluster_labels) - ] + g_weights[y] = [(i, cluster_weights_y.get(k, 0.0)) + for i, k in enumerate(predicted_cluster_labels)] - g_weights[y] = sorted(g_weights[y], - key=lambda x: np.abs(x[1]), reverse=True) + g_weights[y] = sorted( + g_weights[y], key=lambda x: np.abs(x[1]), reverse=True) return g_weights def preparation_normlime(self, data_): - self._lime = LIME( - self.predict_fn, - self.label_names, - self.num_samples, - self.batch_size - ) + self._lime = LIME(self.predict_fn, self.label_names, self.num_samples, + self.batch_size) self._lime.preparation_lime(data_) image_show = read_image(data_) @@ -298,22 +501,25 @@ class NormLIME(object): logging.info('performing NormLIME operations ...') cluster_labels = self.predict_cluster_labels( - compute_features_for_kmeans(image_show).transpose((1, 2, 0)), self._lime.lime_interpreter.segments - ) + compute_features_for_kmeans(image_show).transpose((1, 2, 0)), + self._lime.lime_interpreter.segments) - g_weights = self.predict_using_normlime_weights(self.labels, cluster_labels) + g_weights = self.predict_using_normlime_weights(self.labels, + cluster_labels) return g_weights - def interpret(self, data_, visualization=True, save_to_disk=True, save_outdir=None): + def interpret(self, data_, visualization=True, save_outdir=None): if self.normlime_weights is None: - raise ValueError("Not find the correct precomputed NormLIME result. \n" - "\t Try to call compute_normlime_weights() first or load the correct path.") + raise ValueError( + "Not find the correct precomputed NormLIME result. \n" + "\t Try to call compute_normlime_weights() first or load the correct path." + ) g_weights = self.preparation_normlime(data_) lime_weights = self._lime.lime_interpreter.local_weights - if visualization or save_to_disk: + if visualization or save_outdir is not None: import matplotlib.pyplot as plt from skimage.segmentation import mark_boundaries l = self.labels[0] @@ -328,7 +534,8 @@ class NormLIME(object): ncols = len(weights_choices) plt.close() - f, axes = plt.subplots(nrows, ncols, figsize=(psize * ncols, psize * nrows)) + f, axes = plt.subplots( + nrows, ncols, figsize=(psize * ncols, psize * nrows)) for ax in axes.ravel(): ax.axis("off") @@ -337,64 +544,83 @@ class NormLIME(object): prob_str = "{%.3f}" % (self.predicted_probability) axes[0].set_title("label {}, proba: {}".format(ln, prob_str)) - axes[1].imshow(mark_boundaries(self.image, self._lime.lime_interpreter.segments)) + axes[1].imshow( + mark_boundaries(self.image, + self._lime.lime_interpreter.segments)) axes[1].set_title("superpixel segmentation") # LIME visualization for i, w in enumerate(weights_choices): - num_to_show = auto_choose_num_features_to_show(self._lime.lime_interpreter, l, w) + num_to_show = auto_choose_num_features_to_show( + self._lime.lime_interpreter, l, w) nums_to_show.append(num_to_show) temp, mask = self._lime.lime_interpreter.get_image_and_mask( - l, positive_only=False, hide_rest=False, num_features=num_to_show - ) + l, + positive_only=True, + hide_rest=False, + num_features=num_to_show) axes[ncols + i].imshow(mark_boundaries(temp, mask)) - axes[ncols + i].set_title("LIME: first {} superpixels".format(num_to_show)) + axes[ncols + i].set_title("LIME: first {} superpixels".format( + num_to_show)) # NormLIME visualization self._lime.lime_interpreter.local_weights = g_weights for i, num_to_show in enumerate(nums_to_show): temp, mask = self._lime.lime_interpreter.get_image_and_mask( - l, positive_only=False, hide_rest=False, num_features=num_to_show - ) + l, + positive_only=True, + hide_rest=False, + num_features=num_to_show) axes[ncols * 2 + i].imshow(mark_boundaries(temp, mask)) - axes[ncols * 2 + i].set_title("NormLIME: first {} superpixels".format(num_to_show)) + axes[ncols * 2 + i].set_title( + "NormLIME: first {} superpixels".format(num_to_show)) # NormLIME*LIME visualization - combined_weights = combine_normlime_and_lime(lime_weights, g_weights) + combined_weights = combine_normlime_and_lime(lime_weights, + g_weights) + self._lime.lime_interpreter.local_weights = combined_weights for i, num_to_show in enumerate(nums_to_show): temp, mask = self._lime.lime_interpreter.get_image_and_mask( - l, positive_only=False, hide_rest=False, num_features=num_to_show - ) + l, + positive_only=True, + hide_rest=False, + num_features=num_to_show) axes[ncols * 3 + i].imshow(mark_boundaries(temp, mask)) - axes[ncols * 3 + i].set_title("Combined: first {} superpixels".format(num_to_show)) + axes[ncols * 3 + i].set_title( + "Combined: first {} superpixels".format(num_to_show)) self._lime.lime_interpreter.local_weights = lime_weights - if save_to_disk and save_outdir is not None: - os.makedirs(save_outdir, exist_ok=True) + if save_outdir is not None: save_fig(data_, save_outdir, 'normlime', self.num_samples) if visualization: plt.show() -def auto_choose_num_features_to_show(lime_interpreter, label, percentage_to_show): +def auto_choose_num_features_to_show(lime_interpreter, label, + percentage_to_show): segments = lime_interpreter.segments lime_weights = lime_interpreter.local_weights[label] - num_pixels_threshold_in_a_sp = segments.shape[0] * segments.shape[1] // len(np.unique(segments)) // 8 + num_pixels_threshold_in_a_sp = segments.shape[0] * segments.shape[ + 1] // len(np.unique(segments)) // 8 # l1 norm with filtered weights. - used_weights = [(tuple_w[0], tuple_w[1]) for i, tuple_w in enumerate(lime_weights) if tuple_w[1] > 0] + used_weights = [(tuple_w[0], tuple_w[1]) + for i, tuple_w in enumerate(lime_weights) + if tuple_w[1] > 0] norm = np.sum([tuple_w[1] for i, tuple_w in enumerate(used_weights)]) - normalized_weights = [(tuple_w[0], tuple_w[1] / norm) for i, tuple_w in enumerate(lime_weights)] + normalized_weights = [(tuple_w[0], tuple_w[1] / norm) + for i, tuple_w in enumerate(lime_weights)] a = 0.0 n = 0 for i, tuple_w in enumerate(normalized_weights): if tuple_w[1] < 0: continue - if len(np.where(segments == tuple_w[0])[0]) < num_pixels_threshold_in_a_sp: + if len(np.where(segments == tuple_w[0])[ + 0]) < num_pixels_threshold_in_a_sp: continue a += tuple_w[1] @@ -406,12 +632,18 @@ def auto_choose_num_features_to_show(lime_interpreter, label, percentage_to_show return 5 if n == 0: - return auto_choose_num_features_to_show(lime_interpreter, label, percentage_to_show-0.1) + return auto_choose_num_features_to_show(lime_interpreter, label, + percentage_to_show - 0.1) return n -def get_cam(image_show, feature_maps, fc_weights, label_index, cam_min=None, cam_max=None): +def get_cam(image_show, + feature_maps, + fc_weights, + label_index, + cam_min=None, + cam_max=None): _, nc, h, w = feature_maps.shape cam = feature_maps * fc_weights[:, label_index].reshape(1, nc, 1, 1) @@ -425,7 +657,8 @@ def get_cam(image_show, feature_maps, fc_weights, label_index, cam_min=None, cam cam = cam - cam_min cam = cam / cam_max cam = np.uint8(255 * cam) - cam_img = cv2.resize(cam, image_show.shape[0:2], interpolation=cv2.INTER_LINEAR) + cam_img = cv2.resize( + cam, image_show.shape[0:2], interpolation=cv2.INTER_LINEAR) heatmap = cv2.applyColorMap(np.uint8(255 * cam_img), cv2.COLORMAP_JET) heatmap = np.float32(heatmap) @@ -437,34 +670,11 @@ def get_cam(image_show, feature_maps, fc_weights, label_index, cam_min=None, cam def save_fig(data_, save_outdir, algorithm_name, num_samples=3000): import matplotlib.pyplot as plt - if isinstance(data_, str): - if algorithm_name == 'cam': - f_out = "{}_{}.png".format(algorithm_name, data_.split('/')[-1]) - else: - f_out = "{}_{}_s{}.png".format(algorithm_name, data_.split('/')[-1], num_samples) - plt.savefig( - os.path.join(save_outdir, f_out) - ) + if algorithm_name == 'cam': + f_out = "{}_{}.png".format(algorithm_name, data_.split('/')[-1]) else: - n = 0 - if algorithm_name == 'cam': - f_out = 'cam-{}.png'.format(n) - else: - f_out = '{}_s{}-{}.png'.format(algorithm_name, num_samples, n) - while os.path.exists( - os.path.join(save_outdir, f_out) - ): - n += 1 - if algorithm_name == 'cam': - f_out = 'cam-{}.png'.format(n) - else: - f_out = '{}_s{}-{}.png'.format(algorithm_name, num_samples, n) - continue - plt.savefig( - os.path.join( - save_outdir, f_out - ) - ) - logging.info('The image of intrepretation result save in {}'.format(os.path.join( - save_outdir, f_out - ))) + f_out = "{}_{}_s{}.png".format(save_outdir, algorithm_name, + num_samples) + + plt.savefig(f_out) + logging.info('The image of intrepretation result save in {}'.format(f_out)) diff --git a/paddlex/interpret/core/lime_base.py b/paddlex/interpret/core/lime_base.py index 3d3bd96d0e7b5ffb0de2d2f8156a03021cfad312..d7b44016ae41656c41db25572133e5a6cfc57675 100644 --- a/paddlex/interpret/core/lime_base.py +++ b/paddlex/interpret/core/lime_base.py @@ -27,7 +27,6 @@ OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. The code in this file (lime_base.py) is modified from https://github.com/marcotcr/lime. """ - import numpy as np import scipy as sp @@ -39,10 +38,8 @@ import paddlex.utils.logging as logging class LimeBase(object): """Class for learning a locally linear sparse model from perturbed data""" - def __init__(self, - kernel_fn, - verbose=False, - random_state=None): + + def __init__(self, kernel_fn, verbose=False, random_state=None): """Init function Args: @@ -72,15 +69,14 @@ class LimeBase(object): """ from sklearn.linear_model import lars_path x_vector = weighted_data - alphas, _, coefs = lars_path(x_vector, - weighted_labels, - method='lasso', - verbose=False) + alphas, _, coefs = lars_path( + x_vector, weighted_labels, method='lasso', verbose=False) return alphas, coefs def forward_selection(self, data, labels, weights, num_features): """Iteratively adds features to the model""" - clf = Ridge(alpha=0, fit_intercept=True, random_state=self.random_state) + clf = Ridge( + alpha=0, fit_intercept=True, random_state=self.random_state) used_features = [] for _ in range(min(num_features, data.shape[1])): max_ = -100000000 @@ -88,11 +84,13 @@ class LimeBase(object): for feature in range(data.shape[1]): if feature in used_features: continue - clf.fit(data[:, used_features + [feature]], labels, + clf.fit(data[:, used_features + [feature]], + labels, sample_weight=weights) - score = clf.score(data[:, used_features + [feature]], - labels, - sample_weight=weights) + score = clf.score( + data[:, used_features + [feature]], + labels, + sample_weight=weights) if score > max_: best = feature max_ = score @@ -108,8 +106,8 @@ class LimeBase(object): elif method == 'forward_selection': return self.forward_selection(data, labels, weights, num_features) elif method == 'highest_weights': - clf = Ridge(alpha=0.01, fit_intercept=True, - random_state=self.random_state) + clf = Ridge( + alpha=0.01, fit_intercept=True, random_state=self.random_state) clf.fit(data, labels, sample_weight=weights) coef = clf.coef_ @@ -125,7 +123,8 @@ class LimeBase(object): nnz_indexes = argsort_data[::-1] indices = weighted_data.indices[nnz_indexes] num_to_pad = num_features - sdata - indices = np.concatenate((indices, np.zeros(num_to_pad, dtype=indices.dtype))) + indices = np.concatenate((indices, np.zeros( + num_to_pad, dtype=indices.dtype))) indices_set = set(indices) pad_counter = 0 for i in range(data.shape[1]): @@ -135,7 +134,8 @@ class LimeBase(object): if pad_counter >= num_to_pad: break else: - nnz_indexes = argsort_data[sdata - num_features:sdata][::-1] + nnz_indexes = argsort_data[sdata - num_features:sdata][:: + -1] indices = weighted_data.indices[nnz_indexes] return indices else: @@ -146,13 +146,13 @@ class LimeBase(object): reverse=True) return np.array([x[0] for x in feature_weights[:num_features]]) elif method == 'lasso_path': - weighted_data = ((data - np.average(data, axis=0, weights=weights)) - * np.sqrt(weights[:, np.newaxis])) - weighted_labels = ((labels - np.average(labels, weights=weights)) - * np.sqrt(weights)) + weighted_data = ((data - np.average( + data, axis=0, weights=weights)) * + np.sqrt(weights[:, np.newaxis])) + weighted_labels = ((labels - np.average( + labels, weights=weights)) * np.sqrt(weights)) nonzero = range(weighted_data.shape[1]) - _, coefs = self.generate_lars_path(weighted_data, - weighted_labels) + _, coefs = self.generate_lars_path(weighted_data, weighted_labels) for i in range(len(coefs.T) - 1, 0, -1): nonzero = coefs.T[i].nonzero()[0] if len(nonzero) <= num_features: @@ -164,8 +164,8 @@ class LimeBase(object): n_method = 'forward_selection' else: n_method = 'highest_weights' - return self.feature_selection(data, labels, weights, - num_features, n_method) + return self.feature_selection(data, labels, weights, num_features, + n_method) def interpret_instance_with_data(self, neighborhood_data, @@ -214,30 +214,31 @@ class LimeBase(object): weights = self.kernel_fn(distances) labels_column = neighborhood_labels[:, label] used_features = self.feature_selection(neighborhood_data, - labels_column, - weights, - num_features, - feature_selection) + labels_column, weights, + num_features, feature_selection) if model_regressor is None: - model_regressor = Ridge(alpha=1, fit_intercept=True, - random_state=self.random_state) + model_regressor = Ridge( + alpha=1, fit_intercept=True, random_state=self.random_state) easy_model = model_regressor easy_model.fit(neighborhood_data[:, used_features], - labels_column, sample_weight=weights) + labels_column, + sample_weight=weights) prediction_score = easy_model.score( neighborhood_data[:, used_features], - labels_column, sample_weight=weights) + labels_column, + sample_weight=weights) - local_pred = easy_model.predict(neighborhood_data[0, used_features].reshape(1, -1)) + local_pred = easy_model.predict(neighborhood_data[0, used_features] + .reshape(1, -1)) if self.verbose: logging.info('Intercept' + str(easy_model.intercept_)) logging.info('Prediction_local' + str(local_pred)) logging.info('Right:' + str(neighborhood_labels[0, label])) - return (easy_model.intercept_, - sorted(zip(used_features, easy_model.coef_), - key=lambda x: np.abs(x[1]), reverse=True), - prediction_score, local_pred) + return (easy_model.intercept_, sorted( + zip(used_features, easy_model.coef_), + key=lambda x: np.abs(x[1]), + reverse=True), prediction_score, local_pred) class ImageInterpretation(object): @@ -254,8 +255,13 @@ class ImageInterpretation(object): self.local_weights = {} self.local_pred = None - def get_image_and_mask(self, label, positive_only=True, negative_only=False, hide_rest=False, - num_features=5, min_weight=0.): + def get_image_and_mask(self, + label, + positive_only=True, + negative_only=False, + hide_rest=False, + num_features=5, + min_weight=0.): """Init function. Args: @@ -279,7 +285,9 @@ class ImageInterpretation(object): if label not in self.local_weights: raise KeyError('Label not in interpretation') if positive_only & negative_only: - raise ValueError("Positive_only and negative_only cannot be true at the same time.") + raise ValueError( + "Positive_only and negative_only cannot be true at the same time." + ) segments = self.segments image = self.image local_weights_label = self.local_weights[label] @@ -289,14 +297,20 @@ class ImageInterpretation(object): else: temp = self.image.copy() if positive_only: - fs = [x[0] for x in local_weights_label - if x[1] > 0 and x[1] > min_weight][:num_features] + fs = [ + x[0] for x in local_weights_label + if x[1] > 0 and x[1] > min_weight + ][:num_features] if negative_only: - fs = [x[0] for x in local_weights_label - if x[1] < 0 and abs(x[1]) > min_weight][:num_features] + fs = [ + x[0] for x in local_weights_label + if x[1] < 0 and abs(x[1]) > min_weight + ][:num_features] if positive_only or negative_only: + c = 1 if positive_only else 0 for f in fs: - temp[segments == f] = image[segments == f].copy() + temp[segments == f] = [0, 255, 0] + # temp[segments == f, c] = np.max(image) mask[segments == f] = 1 return temp, mask else: @@ -330,8 +344,11 @@ class ImageInterpretation(object): temp = np.zeros_like(image) weight_max = abs(local_weights_label[0][1]) - local_weights_label = [(f, w/weight_max) for f, w in local_weights_label] - local_weights_label = sorted(local_weights_label, key=lambda x: x[1], reverse=True) # negatives are at last. + local_weights_label = [(f, w / weight_max) + for f, w in local_weights_label] + local_weights_label = sorted( + local_weights_label, key=lambda x: x[1], + reverse=True) # negatives are at last. cmaps = cm.get_cmap('Spectral') colors = cmaps(np.linspace(0, 1, len(local_weights_label))) @@ -354,8 +371,12 @@ class LimeImageInterpreter(object): feature that is 1 when the value is the same as the instance being interpreted.""" - def __init__(self, kernel_width=.25, kernel=None, verbose=False, - feature_selection='auto', random_state=None): + def __init__(self, + kernel_width=.25, + kernel=None, + verbose=False, + feature_selection='auto', + random_state=None): """Init function. Args: @@ -377,22 +398,27 @@ class LimeImageInterpreter(object): kernel_width = float(kernel_width) if kernel is None: + def kernel(d, kernel_width): - return np.sqrt(np.exp(-(d ** 2) / kernel_width ** 2)) + return np.sqrt(np.exp(-(d**2) / kernel_width**2)) kernel_fn = partial(kernel, kernel_width=kernel_width) self.random_state = check_random_state(random_state) self.feature_selection = feature_selection - self.base = LimeBase(kernel_fn, verbose, random_state=self.random_state) + self.base = LimeBase( + kernel_fn, verbose, random_state=self.random_state) - def interpret_instance(self, image, classifier_fn, labels=(1,), + def interpret_instance(self, + image, + classifier_fn, + labels=(1, ), hide_color=None, - num_features=100000, num_samples=1000, + num_features=100000, + num_samples=1000, batch_size=10, distance_metric='cosine', - model_regressor=None - ): + model_regressor=None): """Generates interpretations for a prediction. First, we generate neighborhood data by randomly perturbing features @@ -435,6 +461,7 @@ class LimeImageInterpreter(object): self.segments = segments fudged_image = image.copy() + # global_mean = np.mean(image, (0, 1)) if hide_color is None: # if no hide_color, use the mean for x in np.unique(segments): @@ -461,24 +488,30 @@ class LimeImageInterpreter(object): top = labels - data, labels = self.data_labels(image, fudged_image, segments, - classifier_fn, num_samples, - batch_size=batch_size) + data, labels = self.data_labels( + image, + fudged_image, + segments, + classifier_fn, + num_samples, + batch_size=batch_size) distances = sklearn.metrics.pairwise_distances( - data, - data[0].reshape(1, -1), - metric=distance_metric - ).ravel() + data, data[0].reshape(1, -1), metric=distance_metric).ravel() interpretation_image = ImageInterpretation(image, segments) for label in top: (interpretation_image.intercept[label], interpretation_image.local_weights[label], - interpretation_image.score, interpretation_image.local_pred) = self.base.interpret_instance_with_data( - data, labels, distances, label, num_features, - model_regressor=model_regressor, - feature_selection=self.feature_selection) + interpretation_image.score, interpretation_image.local_pred + ) = self.base.interpret_instance_with_data( + data, + labels, + distances, + label, + num_features, + model_regressor=model_regressor, + feature_selection=self.feature_selection) return interpretation_image def data_labels(self, @@ -511,6 +544,9 @@ class LimeImageInterpreter(object): labels = [] data[0, :] = 1 imgs = [] + + logging.info("Computing LIME.", use_color=True) + for row in tqdm.tqdm(data): temp = copy.deepcopy(image) zeros = np.where(row == 0)[0] diff --git a/paddlex/interpret/core/normlime_base.py b/paddlex/interpret/core/normlime_base.py index 3b3a94212ded51d1300b9ae78f4cdab0e1589903..471078129cdd96df10ae0af1ced39ccf344c7564 100644 --- a/paddlex/interpret/core/normlime_base.py +++ b/paddlex/interpret/core/normlime_base.py @@ -16,6 +16,7 @@ import os import os.path as osp import numpy as np import glob +import tqdm from paddlex.interpret.as_data_reader.readers import read_image import paddlex.utils.logging as logging @@ -38,18 +39,24 @@ def combine_normlime_and_lime(lime_weights, g_weights): for y in pred_labels: normlized_lime_weights_y = lime_weights[y] - lime_weights_dict = {tuple_w[0]: tuple_w[1] for tuple_w in normlized_lime_weights_y} + lime_weights_dict = { + tuple_w[0]: tuple_w[1] + for tuple_w in normlized_lime_weights_y + } normlized_g_weight_y = g_weights[y] - normlime_weights_dict = {tuple_w[0]: tuple_w[1] for tuple_w in normlized_g_weight_y} + normlime_weights_dict = { + tuple_w[0]: tuple_w[1] + for tuple_w in normlized_g_weight_y + } combined_weights[y] = [ (seg_k, lime_weights_dict[seg_k] * normlime_weights_dict[seg_k]) for seg_k in lime_weights_dict.keys() ] - combined_weights[y] = sorted(combined_weights[y], - key=lambda x: np.abs(x[1]), reverse=True) + combined_weights[y] = sorted( + combined_weights[y], key=lambda x: np.abs(x[1]), reverse=True) return combined_weights @@ -67,7 +74,8 @@ def centroid_using_superpixels(features, segments): regions = regionprops(segments + 1) one_list = np.zeros((len(np.unique(segments)), features.shape[2])) for i, r in enumerate(regions): - one_list[i] = features[int(r.centroid[0] + 0.5), int(r.centroid[1] + 0.5), :] + one_list[i] = features[int(r.centroid[0] + 0.5), int(r.centroid[1] + + 0.5), :] return one_list @@ -80,30 +88,39 @@ def get_feature_for_kmeans(feature_map, segments): return x -def precompute_normlime_weights(list_data_, predict_fn, num_samples=3000, batch_size=50, save_dir='./tmp'): +def precompute_normlime_weights(list_data_, + predict_fn, + num_samples=3000, + batch_size=50, + save_dir='./tmp'): # save lime weights and kmeans cluster labels - precompute_lime_weights(list_data_, predict_fn, num_samples, batch_size, save_dir) + precompute_lime_weights(list_data_, predict_fn, num_samples, batch_size, + save_dir) # load precomputed results, compute normlime weights and save. - fname_list = glob.glob(os.path.join(save_dir, 'lime_weights_s{}*.npy'.format(num_samples))) + fname_list = glob.glob( + os.path.join(save_dir, 'lime_weights_s{}*.npy'.format(num_samples))) return compute_normlime_weights(fname_list, save_dir, num_samples) -def save_one_lime_predict_and_kmean_labels(lime_all_weights, image_pred_labels, cluster_labels, save_path): +def save_one_lime_predict_and_kmean_labels(lime_all_weights, image_pred_labels, + cluster_labels, save_path): lime_weights = {} for label in image_pred_labels: lime_weights[label] = lime_all_weights[label] for_normlime_weights = { - 'lime_weights': lime_weights, # a dict: class_label: (seg_label, weight) + 'lime_weights': + lime_weights, # a dict: class_label: (seg_label, weight) 'cluster': cluster_labels # a list with segments as indices. } np.save(save_path, for_normlime_weights) -def precompute_lime_weights(list_data_, predict_fn, num_samples, batch_size, save_dir): +def precompute_lime_weights(list_data_, predict_fn, num_samples, batch_size, + save_dir): root_path = gen_user_home() root_path = osp.join(root_path, '.paddlex') h_pre_models = osp.join(root_path, "pre_models") @@ -117,17 +134,24 @@ def precompute_lime_weights(list_data_, predict_fn, num_samples, batch_size, sav for data_index, each_data_ in enumerate(list_data_): if isinstance(each_data_, str): - save_path = "lime_weights_s{}_{}.npy".format(num_samples, each_data_.split('/')[-1].split('.')[0]) + save_path = "lime_weights_s{}_{}.npy".format( + num_samples, each_data_.split('/')[-1].split('.')[0]) save_path = os.path.join(save_dir, save_path) else: - save_path = "lime_weights_s{}_{}.npy".format(num_samples, data_index) + save_path = "lime_weights_s{}_{}.npy".format(num_samples, + data_index) save_path = os.path.join(save_dir, save_path) if os.path.exists(save_path): - logging.info(save_path + ' exists, not computing this one.', use_color=True) + logging.info( + save_path + ' exists, not computing this one.', use_color=True) continue - img_file_name = each_data_ if isinstance(each_data_, str) else data_index - logging.info('processing '+ img_file_name + ' [{}/{}]'.format(data_index, len(list_data_)), use_color=True) + img_file_name = each_data_ if isinstance(each_data_, + str) else data_index + logging.info( + 'processing ' + img_file_name + ' [{}/{}]'.format(data_index, + len(list_data_)), + use_color=True) image_show = read_image(each_data_) result = predict_fn(image_show) @@ -156,32 +180,38 @@ def precompute_lime_weights(list_data_, predict_fn, num_samples, batch_size, sav pred_label = pred_label[:top_k] algo = lime_base.LimeImageInterpreter() - interpreter = algo.interpret_instance(image_show[0], predict_fn, pred_label, 0, - num_samples=num_samples, batch_size=batch_size) - - X = get_feature_for_kmeans(compute_features_for_kmeans(image_show).transpose((1, 2, 0)), interpreter.segments) + interpreter = algo.interpret_instance( + image_show[0], + predict_fn, + pred_label, + 0, + num_samples=num_samples, + batch_size=batch_size) + + X = get_feature_for_kmeans( + compute_features_for_kmeans(image_show).transpose((1, 2, 0)), + interpreter.segments) try: cluster_labels = kmeans_model.predict(X) except AttributeError: from sklearn.metrics import pairwise_distances_argmin_min - cluster_labels, _ = pairwise_distances_argmin_min(X, kmeans_model.cluster_centers_) + cluster_labels, _ = pairwise_distances_argmin_min( + X, kmeans_model.cluster_centers_) save_one_lime_predict_and_kmean_labels( - interpreter.local_weights, pred_label, - cluster_labels, - save_path - ) + interpreter.local_weights, pred_label, cluster_labels, save_path) def compute_normlime_weights(a_list_lime_fnames, save_dir, lime_num_samples): normlime_weights_all_labels = {} - + for f in a_list_lime_fnames: try: lime_weights_and_cluster = np.load(f, allow_pickle=True).item() lime_weights = lime_weights_and_cluster['lime_weights'] cluster = lime_weights_and_cluster['cluster'] except: - logging.info('When loading precomputed LIME result, skipping' + str(f)) + logging.info('When loading precomputed LIME result, skipping' + + str(f)) continue logging.info('Loading precomputed LIME result,' + str(f)) pred_labels = lime_weights.keys() @@ -203,10 +233,12 @@ def compute_normlime_weights(a_list_lime_fnames, save_dir, lime_num_samples): for y in normlime_weights_all_labels: normlime_weights = normlime_weights_all_labels.get(y, {}) for k in normlime_weights: - normlime_weights[k] = sum(normlime_weights[k]) / len(normlime_weights[k]) + normlime_weights[k] = sum(normlime_weights[k]) / len( + normlime_weights[k]) # check normlime - if len(normlime_weights_all_labels.keys()) < max(normlime_weights_all_labels.keys()) + 1: + if len(normlime_weights_all_labels.keys()) < max( + normlime_weights_all_labels.keys()) + 1: logging.info( "\n" + \ "Warning: !!! \n" + \ @@ -218,17 +250,166 @@ def compute_normlime_weights(a_list_lime_fnames, save_dir, lime_num_samples): ) n = 0 - f_out = 'normlime_weights_s{}_samples_{}-{}.npy'.format(lime_num_samples, len(a_list_lime_fnames), n) - while os.path.exists( - os.path.join(save_dir, f_out) - ): + f_out = 'normlime_weights_s{}_samples_{}-{}.npy'.format( + lime_num_samples, len(a_list_lime_fnames), n) + while os.path.exists(os.path.join(save_dir, f_out)): n += 1 - f_out = 'normlime_weights_s{}_samples_{}-{}.npy'.format(lime_num_samples, len(a_list_lime_fnames), n) + f_out = 'normlime_weights_s{}_samples_{}-{}.npy'.format( + lime_num_samples, len(a_list_lime_fnames), n) continue - np.save( - os.path.join(save_dir, f_out), - normlime_weights_all_labels - ) + np.save(os.path.join(save_dir, f_out), normlime_weights_all_labels) return os.path.join(save_dir, f_out) + +def precompute_global_classifier(dataset, + predict_fn, + save_path, + batch_size=50, + max_num_samples=1000): + from sklearn.linear_model import LogisticRegression + + root_path = gen_user_home() + root_path = osp.join(root_path, '.paddlex') + h_pre_models = osp.join(root_path, "pre_models") + if not osp.exists(h_pre_models): + if not osp.exists(root_path): + os.makedirs(root_path) + url = "https://bj.bcebos.com/paddlex/interpret/pre_models.tar.gz" + pdx.utils.download_and_decompress(url, path=root_path) + h_pre_models_kmeans = osp.join(h_pre_models, "kmeans_model.pkl") + kmeans_model = load_kmeans_model(h_pre_models_kmeans) + + image_list = [] + for item in dataset.file_list: + image_list.append(item[0]) + + x_data = [] + y_labels = [] + + num_features = len(kmeans_model.cluster_centers_) + + logging.info( + "Initialization for NormLIME: Computing each sample in the test list.", + use_color=True) + + for each_data_ in tqdm.tqdm(image_list): + x_data_i = np.zeros((num_features)) + image_show = read_image(each_data_) + result = predict_fn(image_show) + result = result[0] # only one image here. + c = compute_features_for_kmeans(image_show).transpose((1, 2, 0)) + + segments = np.zeros((image_show.shape[1], image_show.shape[2]), + np.int32) + num_blocks = 10 + height_per_i = segments.shape[0] // num_blocks + 1 + width_per_i = segments.shape[1] // num_blocks + 1 + + for i in range(segments.shape[0]): + for j in range(segments.shape[1]): + segments[i, + j] = i // height_per_i * num_blocks + j // width_per_i + + # segments = quickshift(image_show[0], sigma=1) + X = get_feature_for_kmeans(c, segments) + + try: + cluster_labels = kmeans_model.predict(X) + except AttributeError: + from sklearn.metrics import pairwise_distances_argmin_min + cluster_labels, _ = pairwise_distances_argmin_min( + X, kmeans_model.cluster_centers_) + + for c in cluster_labels: + x_data_i[c] = 1 + + # x_data_i /= len(cluster_labels) + + pred_y_i = np.argmax(result) + y_labels.append(pred_y_i) + x_data.append(x_data_i) + + if len(np.unique(y_labels)) < 2: + logging.info("Warning: The test samples in the dataset is limited.\n \ + NormLIME may have no effect on the results.\n \ + Try to add more test samples, or see the results of LIME.") + num_classes = np.max(np.unique(y_labels)) + 1 + normlime_weights_all_labels = {} + for class_index in range(num_classes): + w = np.ones((num_features)) / num_features + normlime_weights_all_labels[class_index] = { + i: wi + for i, wi in enumerate(w) + } + logging.info("Saving the computed normlime_weights in {}".format( + save_path)) + + np.save(save_path, normlime_weights_all_labels) + return save_path + + clf = LogisticRegression(multi_class='multinomial', max_iter=1000) + clf.fit(x_data, y_labels) + + num_classes = np.max(np.unique(y_labels)) + 1 + normlime_weights_all_labels = {} + + if len(y_labels) / len(np.unique(y_labels)) < 3: + logging.info("Warning: The test samples in the dataset is limited.\n \ + NormLIME may have no effect on the results.\n \ + Try to add more test samples, or see the results of LIME.") + + if len(np.unique(y_labels)) == 2: + # binary: clf.coef_ has shape of [1, num_features] + for class_index in range(num_classes): + if class_index not in clf.classes_: + w = np.ones((num_features)) / num_features + normlime_weights_all_labels[class_index] = { + i: wi + for i, wi in enumerate(w) + } + continue + + if clf.classes_[0] == class_index: + w = -clf.coef_[0] + else: + w = clf.coef_[0] + + # softmax + w = w - np.max(w) + exp_w = np.exp(w * 10) + w = exp_w / np.sum(exp_w) + + normlime_weights_all_labels[class_index] = { + i: wi + for i, wi in enumerate(w) + } + else: + # clf.coef_ has shape of [len(np.unique(y_labels)), num_features] + for class_index in range(num_classes): + if class_index not in clf.classes_: + w = np.ones((num_features)) / num_features + normlime_weights_all_labels[class_index] = { + i: wi + for i, wi in enumerate(w) + } + continue + + coef_class_index = np.where(clf.classes_ == class_index)[0][0] + w = clf.coef_[coef_class_index] + + # softmax + w = w - np.max(w) + exp_w = np.exp(w * 10) + w = exp_w / np.sum(exp_w) + + normlime_weights_all_labels[class_index] = { + i: wi + for i, wi in enumerate(w) + } + + logging.info("Saving the computed normlime_weights in {}".format( + save_path)) + np.save(save_path, normlime_weights_all_labels) + + return save_path diff --git a/paddlex/interpret/interpretation_predict.py b/paddlex/interpret/interpretation_predict.py index 198f949ac7f13117fb51b7240d532eabf1c669eb..31b3b47e86613f62ba1c63b4ba2041357cc6bdc7 100644 --- a/paddlex/interpret/interpretation_predict.py +++ b/paddlex/interpret/interpretation_predict.py @@ -13,17 +13,26 @@ # limitations under the License. import numpy as np +import cv2 +import copy + def interpretation_predict(model, images): - model.arrange_transforms( - transforms=model.test_transforms, mode='test') + images = images.astype('float32') + model.arrange_transforms(transforms=model.test_transforms, mode='test') + tmp_transforms = copy.deepcopy(model.test_transforms.transforms) + model.test_transforms.transforms = model.test_transforms.transforms[-2:] + new_imgs = [] for i in range(images.shape[0]): - img = images[i] - new_imgs.append(model.test_transforms(img)[0]) + images[i] = cv2.cvtColor(images[i], cv2.COLOR_RGB2BGR) + new_imgs.append(model.test_transforms(images[i])[0]) + new_imgs = np.array(new_imgs) - result = model.exe.run( - model.test_prog, - feed={'image': new_imgs}, - fetch_list=list(model.interpretation_feats.values())) - return result \ No newline at end of file + out = model.exe.run(model.test_prog, + feed={'image': new_imgs}, + fetch_list=list(model.interpretation_feats.values())) + + model.test_transforms.transforms = tmp_transforms + + return out diff --git a/paddlex/interpret/visualize.py b/paddlex/interpret/visualize.py index de8e9151b9417fd3307c74d7bb67767bed1845c7..f0158402b69ad3eb90aac6b11a134889fda6dc2b 100644 --- a/paddlex/interpret/visualize.py +++ b/paddlex/interpret/visualize.py @@ -20,79 +20,79 @@ import numpy as np import paddlex as pdx from .interpretation_predict import interpretation_predict from .core.interpretation import Interpretation -from .core.normlime_base import precompute_normlime_weights +from .core.normlime_base import precompute_global_classifier from .core._session_preparation import gen_user_home - -def lime(img_file, - model, - num_samples=3000, - batch_size=50, - save_dir='./'): - """使用LIME算法将模型预测结果的可解释性可视化。 - + + +def lime(img_file, model, num_samples=3000, batch_size=50, save_dir='./'): + """使用LIME算法将模型预测结果的可解释性可视化。 + LIME表示与模型无关的局部可解释性,可以解释任何模型。LIME的思想是以输入样本为中心, 在其附近的空间中进行随机采样,每个采样通过原模型得到新的输出,这样得到一系列的输入 和对应的输出,LIME用一个简单的、可解释的模型(比如线性回归模型)来拟合这个映射关系, - 得到每个输入维度的权重,以此来解释模型。 - + 得到每个输入维度的权重,以此来解释模型。 + 注意:LIME可解释性结果可视化目前只支持分类模型。 - + Args: img_file (str): 预测图像路径。 model (paddlex.cv.models): paddlex中的模型。 num_samples (int): LIME用于学习线性模型的采样数,默认为3000。 batch_size (int): 预测数据batch大小,默认为50。 - save_dir (str): 可解释性可视化结果(保存为png格式文件)和中间文件存储路径。 + save_dir (str): 可解释性可视化结果(保存为png格式文件)和中间文件存储路径。 """ assert model.model_type == 'classifier', \ 'Now the interpretation visualize only be supported in classifier!' if model.status != 'Normal': - raise Exception('The interpretation only can deal with the Normal model') + raise Exception( + 'The interpretation only can deal with the Normal model') if not osp.exists(save_dir): os.makedirs(save_dir) - model.arrange_transforms( - transforms=model.test_transforms, mode='test') + model.arrange_transforms(transforms=model.test_transforms, mode='test') tmp_transforms = copy.deepcopy(model.test_transforms) tmp_transforms.transforms = tmp_transforms.transforms[:-2] img = tmp_transforms(img_file)[0] img = np.around(img).astype('uint8') img = np.expand_dims(img, axis=0) interpreter = None - interpreter = get_lime_interpreter(img, model, num_samples=num_samples, batch_size=batch_size) + interpreter = get_lime_interpreter( + img, model, num_samples=num_samples, batch_size=batch_size) img_name = osp.splitext(osp.split(img_file)[-1])[0] - interpreter.interpret(img, save_dir=save_dir) - - -def normlime(img_file, - model, - dataset=None, - num_samples=3000, - batch_size=50, - save_dir='./'): + interpreter.interpret(img, save_dir=osp.join(save_dir, img_name)) + + +def normlime(img_file, + model, + dataset=None, + num_samples=3000, + batch_size=50, + save_dir='./', + normlime_weights_file=None): """使用NormLIME算法将模型预测结果的可解释性可视化。 - + NormLIME是利用一定数量的样本来出一个全局的解释。NormLIME会提前计算一定数量的测 试样本的LIME结果,然后对相同的特征进行权重的归一化,这样来得到一个全局的输入和输出的关系。 - + 注意1:dataset读取的是一个数据集,该数据集不宜过大,否则计算时间会较长,但应包含所有类别的数据。 注意2:NormLIME可解释性结果可视化目前只支持分类模型。 - + Args: img_file (str): 预测图像路径。 model (paddlex.cv.models): paddlex中的模型。 dataset (paddlex.datasets): 数据集读取器,默认为None。 num_samples (int): LIME用于学习线性模型的采样数,默认为3000。 batch_size (int): 预测数据batch大小,默认为50。 - save_dir (str): 可解释性可视化结果(保存为png格式文件)和中间文件存储路径。 + save_dir (str): 可解释性可视化结果(保存为png格式文件)和中间文件存储路径。 + normlime_weights_file (str): NormLIME初始化文件名,若不存在,则计算一次,保存于该路径;若存在,则直接载入。 """ assert model.model_type == 'classifier', \ 'Now the interpretation visualize only be supported in classifier!' if model.status != 'Normal': - raise Exception('The interpretation only can deal with the Normal model') + raise Exception( + 'The interpretation only can deal with the Normal model') if not osp.exists(save_dir): os.makedirs(save_dir) - model.arrange_transforms( - transforms=model.test_transforms, mode='test') + model.arrange_transforms(transforms=model.test_transforms, mode='test') tmp_transforms = copy.deepcopy(model.test_transforms) tmp_transforms.transforms = tmp_transforms.transforms[:-2] img = tmp_transforms(img_file)[0] @@ -100,52 +100,48 @@ def normlime(img_file, img = np.expand_dims(img, axis=0) interpreter = None if dataset is None: - raise Exception('The dataset is None. Cannot implement this kind of interpretation') - interpreter = get_normlime_interpreter(img, model, dataset, - num_samples=num_samples, batch_size=batch_size, - save_dir=save_dir) + raise Exception( + 'The dataset is None. Cannot implement this kind of interpretation') + interpreter = get_normlime_interpreter( + img, + model, + dataset, + num_samples=num_samples, + batch_size=batch_size, + save_dir=save_dir, + normlime_weights_file=normlime_weights_file) img_name = osp.splitext(osp.split(img_file)[-1])[0] - interpreter.interpret(img, save_dir=save_dir) - - + interpreter.interpret(img, save_dir=osp.join(save_dir, img_name)) + + def get_lime_interpreter(img, model, num_samples=3000, batch_size=50): def predict_func(image): - image = image.astype('float32') - for i in range(image.shape[0]): - image[i] = cv2.cvtColor(image[i], cv2.COLOR_RGB2BGR) - tmp_transforms = copy.deepcopy(model.test_transforms.transforms) - model.test_transforms.transforms = model.test_transforms.transforms[-2:] out = interpretation_predict(model, image) - model.test_transforms.transforms = tmp_transforms return out[0] + labels_name = None if hasattr(model, 'labels'): labels_name = model.labels - interpreter = Interpretation('lime', - predict_func, - labels_name, - num_samples=num_samples, - batch_size=batch_size) + interpreter = Interpretation( + 'lime', + predict_func, + labels_name, + num_samples=num_samples, + batch_size=batch_size) return interpreter -def get_normlime_interpreter(img, model, dataset, num_samples=3000, batch_size=50, save_dir='./'): - def precompute_predict_func(image): - image = image.astype('float32') - tmp_transforms = copy.deepcopy(model.test_transforms.transforms) - model.test_transforms.transforms = model.test_transforms.transforms[-2:] - out = interpretation_predict(model, image) - model.test_transforms.transforms = tmp_transforms - return out[0] +def get_normlime_interpreter(img, + model, + dataset, + num_samples=3000, + batch_size=50, + save_dir='./', + normlime_weights_file=None): def predict_func(image): - image = image.astype('float32') - for i in range(image.shape[0]): - image[i] = cv2.cvtColor(image[i], cv2.COLOR_RGB2BGR) - tmp_transforms = copy.deepcopy(model.test_transforms.transforms) - model.test_transforms.transforms = model.test_transforms.transforms[-2:] out = interpretation_predict(model, image) - model.test_transforms.transforms = tmp_transforms return out[0] + labels_name = None if dataset is not None: labels_name = dataset.labels @@ -157,28 +153,29 @@ def get_normlime_interpreter(img, model, dataset, num_samples=3000, batch_size=5 os.makedirs(root_path) url = "https://bj.bcebos.com/paddlex/interpret/pre_models.tar.gz" pdx.utils.download_and_decompress(url, path=root_path) - npy_dir = precompute_for_normlime(precompute_predict_func, - dataset, - num_samples=num_samples, - batch_size=batch_size, - save_dir=save_dir) - interpreter = Interpretation('normlime', - predict_func, - labels_name, - num_samples=num_samples, - batch_size=batch_size, - normlime_weights=npy_dir) - return interpreter - -def precompute_for_normlime(predict_func, dataset, num_samples=3000, batch_size=50, save_dir='./'): - image_list = [] - for item in dataset.file_list: - image_list.append(item[0]) - return precompute_normlime_weights( - image_list, + if osp.exists(osp.join(save_dir, normlime_weights_file)): + normlime_weights_file = osp.join(save_dir, normlime_weights_file) + try: + np.load(normlime_weights_file, allow_pickle=True).item() + except: + normlime_weights_file = precompute_global_classifier( + dataset, + predict_func, + save_path=normlime_weights_file, + batch_size=batch_size) + else: + normlime_weights_file = precompute_global_classifier( + dataset, predict_func, - num_samples=num_samples, - batch_size=batch_size, - save_dir=save_dir) - + save_path=osp.join(save_dir, normlime_weights_file), + batch_size=batch_size) + + interpreter = Interpretation( + 'normlime', + predict_func, + labels_name, + num_samples=num_samples, + batch_size=batch_size, + normlime_weights=normlime_weights_file) + return interpreter diff --git a/setup.py b/setup.py index db62ca5e9e8107f2f32e804a0e92fb48766d3c27..44aca0f9dc2a214ff4bcf4e2817d06423c26812b 100644 --- a/setup.py +++ b/setup.py @@ -19,7 +19,7 @@ long_description = "PaddleX. A end-to-end deeplearning model development toolkit setuptools.setup( name="paddlex", - version='1.0.5', + version='1.0.6', author="paddlex", author_email="paddlex@baidu.com", description=long_description, diff --git a/tutorials/interpret/normlime.py b/tutorials/interpret/normlime.py index 3e501388e44aeab8548ae123831bc3211b08cea7..f3a1129780ab87d6d242010a124760c9a64608bd 100644 --- a/tutorials/interpret/normlime.py +++ b/tutorials/interpret/normlime.py @@ -14,18 +14,22 @@ model_file = 'https://bj.bcebos.com/paddlex/interpret/mini_imagenet_veg_mobilene pdx.utils.download_and_decompress(model_file, path='./') # 加载模型 -model = pdx.load_model('mini_imagenet_veg_mobilenetv2') +model_file = 'mini_imagenet_veg_mobilenetv2' +model = pdx.load_model(model_file) # 定义测试所用的数据集 +dataset = 'mini_imagenet_veg' test_dataset = pdx.datasets.ImageNet( - data_dir='mini_imagenet_veg', - file_list=osp.join('mini_imagenet_veg', 'test_list.txt'), - label_list=osp.join('mini_imagenet_veg', 'labels.txt'), + data_dir=dataset, + file_list=osp.join(dataset, 'test_list.txt'), + label_list=osp.join(dataset, 'labels.txt'), transforms=model.test_transforms) # 可解释性可视化 pdx.interpret.normlime( - 'mini_imagenet_veg/mushroom/n07734744_1106.JPEG', - model, - test_dataset, - save_dir='./') + test_dataset.file_list[0][0], + model, + test_dataset, + save_dir='./', + normlime_weights_file='{}_{}.npy'.format( + dataset.split('/')[-1], model.model_name))