Commit baa0036f authored by sunyanfang01

move sklearn

Parent 441634a1
@@ -242,7 +242,13 @@ class NormLIME(object):
self.label_names = label_names
def predict_cluster_labels(self, feature_map, segments):
return self.kmeans_model.predict(get_feature_for_kmeans(feature_map, segments))
X = get_feature_for_kmeans(feature_map, segments)
try:
cluster_labels = self.kmeans_model.predict(X)
except AttributeError:
from sklearn.metrics import pairwise_distances_argmin_min
cluster_labels, _ = pairwise_distances_argmin_min(X, self.kmeans_model.cluster_centers_)
return cluster_labels
def predict_using_normlime_weights(self, pred_labels, predicted_cluster_labels):
# global weights
......
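The try/except AttributeError added above covers the case where the restored clustering model cannot call predict (for example, only its cluster_centers_ survive deserialization across scikit-learn versions); each sample is then assigned to its nearest center. A minimal sketch of that nearest-center assignment on synthetic data (shapes and values below are illustrative, not taken from the diff):

import numpy as np
from sklearn.metrics import pairwise_distances_argmin_min

# Stand-ins for get_feature_for_kmeans(feature_map, segments) output and the stored centers.
X = np.random.rand(8, 5)                # 8 segment features, 5 dimensions each
cluster_centers = np.random.rand(3, 5)  # 3 cluster centers of the same dimension

# Assign each row of X to its closest center (Euclidean distance),
# which is what KMeans.predict would also compute.
labels, _ = pairwise_distances_argmin_min(X, cluster_centers)
print(labels)  # e.g. [2 0 1 1 0 2 0 1] -- one center index per sample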
@@ -30,17 +30,10 @@ The code in this file (lime_base.py) is modified from https://github.com/marcotc
import numpy as np
import scipy as sp
import sklearn
import sklearn.preprocessing
from skimage.color import gray2rgb
from sklearn.linear_model import Ridge, lars_path
from sklearn.utils import check_random_state
import tqdm
import copy
from functools import partial
from skimage.segmentation import quickshift
from skimage.measure import regionprops
class LimeBase(object):
@@ -59,6 +52,7 @@ class LimeBase(object):
generate random numbers. If None, the random state will be
initialized using the internal numpy seed.
"""
from sklearn.utils import check_random_state
self.kernel_fn = kernel_fn
self.verbose = verbose
self.random_state = check_random_state(random_state)
@@ -75,6 +69,7 @@ class LimeBase(object):
(alphas, coefs), both are arrays corresponding to the
regularization parameter and coefficients, respectively
"""
from sklearn.linear_model import lars_path
x_vector = weighted_data
alphas, _, coefs = lars_path(x_vector,
weighted_labels,
@@ -106,6 +101,7 @@ class LimeBase(object):
def feature_selection(self, data, labels, weights, num_features, method):
"""Selects features for the model. see interpret_instance_with_data to
understand the parameters."""
from sklearn.linear_model import Ridge
if method == 'none':
return np.array(range(data.shape[1]))
elif method == 'forward_selection':
@@ -213,7 +209,7 @@ class LimeBase(object):
score is the R^2 value of the returned interpretation
local_pred is the prediction of the interpretation model on the original instance
"""
from sklearn.linear_model import Ridge
weights = self.kernel_fn(distances)
labels_column = neighborhood_labels[:, label]
used_features = self.feature_selection(neighborhood_data,
@@ -376,6 +372,7 @@ class LimeImageInterpreter(object):
generate random numbers. If None, the random state will be
initialized using the internal numpy seed.
"""
from sklearn.utils import check_random_state
kernel_width = float(kernel_width)
if kernel is None:
@@ -422,6 +419,10 @@ class LimeImageInterpreter(object):
An ImageInterpretation object (see lime_image.py) with the corresponding
interpretations.
"""
import sklearn
from skimage.measure import regionprops
from skimage.segmentation import quickshift
from skimage.color import gray2rgb
if len(image.shape) == 2:
image = gray2rgb(image)
......
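The recurring change in this file moves sklearn and skimage imports from module level into the methods that use them, so importing the module no longer requires those packages until interpretation actually runs. A small sketch of the same deferred-import pattern with a clearer failure message (the helper _require_ridge is illustrative and not part of the diff):

def _require_ridge():
    """Import Ridge lazily so the module itself can be imported without scikit-learn."""
    try:
        from sklearn.linear_model import Ridge
        return Ridge
    except ImportError as e:
        raise ImportError('scikit-learn is required for the interpretation module; '
                          'install it with `pip install scikit-learn`.') from e


def fit_local_model(weighted_data, weighted_labels, sample_weight):
    # The import cost (and the hard dependency) is paid only on first use.
    Ridge = _require_ridge()
    model = Ridge(alpha=1.0, fit_intercept=True)
    model.fit(weighted_data, weighted_labels, sample_weight=sample_weight)
    return model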
@@ -17,6 +17,7 @@ import numpy as np
import glob
from paddlex.interpret.as_data_reader.readers import read_image
import paddlex.utils.logging as logging
from . import lime_base
from ._session_preparation import compute_features_for_kmeans, h_pre_models_kmeans
@@ -113,11 +114,11 @@ def precompute_lime_weights(list_data_, predict_fn, num_samples, batch_size, sav
save_path = os.path.join(save_dir, save_path)
if os.path.exists(save_path):
print(f'{save_path} exists, not computing this one.')
logging.info(save_path + ' exists, not computing this one.', use_color=True)
continue
print('processing', each_data_ if isinstance(each_data_, str) else data_index,
f', {data_index}/{len(list_data_)}')
logging.info('processing ' + (each_data_ if isinstance(each_data_, str) else str(data_index))
             + f', {data_index}/{len(list_data_)}', use_color=True)
image_show = read_image(each_data_)
result = predict_fn(image_show)
......
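One pitfall the rewritten logging call avoids: a conditional expression binds looser than +, so 'processing' + a if cond else b parses as ('processing' + a) if cond else b, which drops the progress suffix in one branch and tries to add an int to a str in the other. A tiny sketch of the parenthesized form (values are made up):

data_index, total = 3, 10
each_data_ = 42  # may be a str path or an int index

msg = 'processing ' + (each_data_ if isinstance(each_data_, str) else str(data_index)) + \
      f', {data_index}/{total}'
print(msg)  # -> processing 3, 3/10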