提交 baa0036f 编写于 作者: S sunyanfang01

move sklearn

上级 441634a1
...@@ -242,7 +242,13 @@ class NormLIME(object): ...@@ -242,7 +242,13 @@ class NormLIME(object):
self.label_names = label_names self.label_names = label_names
def predict_cluster_labels(self, feature_map, segments): def predict_cluster_labels(self, feature_map, segments):
return self.kmeans_model.predict(get_feature_for_kmeans(feature_map, segments)) X = get_feature_for_kmeans(feature_map, segments)
try:
cluster_labels = self.kmeans_model.predict(X)
except AttributeError:
from sklearn.metrics import pairwise_distances_argmin_min
cluster_labels, _ = pairwise_distances_argmin_min(X, self.kmeans_model.cluster_centers_)
return cluster_labels
def predict_using_normlime_weights(self, pred_labels, predicted_cluster_labels): def predict_using_normlime_weights(self, pred_labels, predicted_cluster_labels):
# global weights # global weights
......
...@@ -30,17 +30,10 @@ The code in this file (lime_base.py) is modified from https://github.com/marcotc ...@@ -30,17 +30,10 @@ The code in this file (lime_base.py) is modified from https://github.com/marcotc
import numpy as np import numpy as np
import scipy as sp import scipy as sp
import sklearn
import sklearn.preprocessing
from skimage.color import gray2rgb
from sklearn.linear_model import Ridge, lars_path
from sklearn.utils import check_random_state
import tqdm import tqdm
import copy import copy
from functools import partial from functools import partial
from skimage.segmentation import quickshift
from skimage.measure import regionprops
class LimeBase(object): class LimeBase(object):
...@@ -59,6 +52,7 @@ class LimeBase(object): ...@@ -59,6 +52,7 @@ class LimeBase(object):
generate random numbers. If None, the random state will be generate random numbers. If None, the random state will be
initialized using the internal numpy seed. initialized using the internal numpy seed.
""" """
from sklearn.utils import check_random_state
self.kernel_fn = kernel_fn self.kernel_fn = kernel_fn
self.verbose = verbose self.verbose = verbose
self.random_state = check_random_state(random_state) self.random_state = check_random_state(random_state)
...@@ -75,6 +69,7 @@ class LimeBase(object): ...@@ -75,6 +69,7 @@ class LimeBase(object):
(alphas, coefs), both are arrays corresponding to the (alphas, coefs), both are arrays corresponding to the
regularization parameter and coefficients, respectively regularization parameter and coefficients, respectively
""" """
from sklearn.linear_model import lars_path
x_vector = weighted_data x_vector = weighted_data
alphas, _, coefs = lars_path(x_vector, alphas, _, coefs = lars_path(x_vector,
weighted_labels, weighted_labels,
...@@ -106,6 +101,7 @@ class LimeBase(object): ...@@ -106,6 +101,7 @@ class LimeBase(object):
def feature_selection(self, data, labels, weights, num_features, method): def feature_selection(self, data, labels, weights, num_features, method):
"""Selects features for the model. see interpret_instance_with_data to """Selects features for the model. see interpret_instance_with_data to
understand the parameters.""" understand the parameters."""
from sklearn.linear_model import Ridge
if method == 'none': if method == 'none':
return np.array(range(data.shape[1])) return np.array(range(data.shape[1]))
elif method == 'forward_selection': elif method == 'forward_selection':
...@@ -213,7 +209,7 @@ class LimeBase(object): ...@@ -213,7 +209,7 @@ class LimeBase(object):
score is the R^2 value of the returned interpretation score is the R^2 value of the returned interpretation
local_pred is the prediction of the interpretation model on the original instance local_pred is the prediction of the interpretation model on the original instance
""" """
from sklearn.linear_model import Ridge
weights = self.kernel_fn(distances) weights = self.kernel_fn(distances)
labels_column = neighborhood_labels[:, label] labels_column = neighborhood_labels[:, label]
used_features = self.feature_selection(neighborhood_data, used_features = self.feature_selection(neighborhood_data,
...@@ -376,6 +372,7 @@ class LimeImageInterpreter(object): ...@@ -376,6 +372,7 @@ class LimeImageInterpreter(object):
generate random numbers. If None, the random state will be generate random numbers. If None, the random state will be
initialized using the internal numpy seed. initialized using the internal numpy seed.
""" """
from sklearn.utils import check_random_state
kernel_width = float(kernel_width) kernel_width = float(kernel_width)
if kernel is None: if kernel is None:
...@@ -422,6 +419,10 @@ class LimeImageInterpreter(object): ...@@ -422,6 +419,10 @@ class LimeImageInterpreter(object):
An ImageIinterpretation object (see lime_image.py) with the corresponding An ImageIinterpretation object (see lime_image.py) with the corresponding
interpretations. interpretations.
""" """
import sklearn
from skimage.measure import regionprops
from skimage.segmentation import quickshift
from skimage.color import gray2rgb
if len(image.shape) == 2: if len(image.shape) == 2:
image = gray2rgb(image) image = gray2rgb(image)
......
...@@ -17,6 +17,7 @@ import numpy as np ...@@ -17,6 +17,7 @@ import numpy as np
import glob import glob
from paddlex.interpret.as_data_reader.readers import read_image from paddlex.interpret.as_data_reader.readers import read_image
import paddlex.utils.logging as logging
from . import lime_base from . import lime_base
from ._session_preparation import compute_features_for_kmeans, h_pre_models_kmeans from ._session_preparation import compute_features_for_kmeans, h_pre_models_kmeans
...@@ -113,11 +114,11 @@ def precompute_lime_weights(list_data_, predict_fn, num_samples, batch_size, sav ...@@ -113,11 +114,11 @@ def precompute_lime_weights(list_data_, predict_fn, num_samples, batch_size, sav
save_path = os.path.join(save_dir, save_path) save_path = os.path.join(save_dir, save_path)
if os.path.exists(save_path): if os.path.exists(save_path):
print(f'{save_path} exists, not computing this one.') logging.info(save_path + ' exists, not computing this one.', use_color=True)
continue continue
print('processing', each_data_ if isinstance(each_data_, str) else data_index, logging.info('processing'+each_data_ if isinstance(each_data_, str) else data_index + \
f', {data_index}/{len(list_data_)}') f'+{data_index}/{len(list_data_)}', use_color=True)
image_show = read_image(each_data_) image_show = read_image(each_data_)
result = predict_fn(image_show) result = predict_fn(image_show)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册