Commit 7fcfbf80 authored by sunyanfang01

remove download

Parent 441634a1
@@ -15,4 +15,5 @@
 from __future__ import absolute_import
 from . import visualize
-visualize = visualize.visualize
\ No newline at end of file
+lime = visualize.lime
+normlime = visualize.normlime
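A hedged usage sketch of what the re-export above means for callers: only the attribute names lime and normlime come from the diff; the paddlex.interpret package path, pdx.load_model, and the call arguments are assumptions for illustration.

import paddlex as pdx

# Load a trained model (the path below is a placeholder).
model = pdx.load_model('output/mobilenetv2/best_model')

# After this commit the interpretation entry points are exposed directly
# as `lime` and `normlime` instead of a single `visualize` callable.
# Argument names here are assumed, not taken from the diff.
pdx.interpret.lime('test.jpg', model, save_dir='./interpret_lime')
pdx.interpret.normlime('test.jpg', model, save_dir='./interpret_normlime')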
@@ -28,17 +28,6 @@ def gen_user_home():
     return os.path.expanduser('~')
-root_path = gen_user_home()
-root_path = osp.join(root_path, '.paddlex')
-h_pre_models = osp.join(root_path, "pre_models")
-if not osp.exists(h_pre_models):
-    if not osp.exists(root_path):
-        os.makedirs(root_path)
-    url = "https://bj.bcebos.com/paddlex/interpret/pre_models.tar.gz"
-    pdx.utils.download_and_decompress(url, path=root_path)
-h_pre_models_kmeans = osp.join(h_pre_models, "kmeans_model.pkl")
 def paddle_get_fc_weights(var_name="fc_0.w_0"):
     fc_weights = fluid.global_scope().find_var(var_name).get_tensor()
     return np.array(fc_weights)
@@ -50,6 +39,14 @@ def paddle_resize(extracted_features, outsize):
 def compute_features_for_kmeans(data_content):
+    root_path = gen_user_home()
+    root_path = osp.join(root_path, '.paddlex')
+    h_pre_models = osp.join(root_path, "pre_models")
+    if not osp.exists(h_pre_models):
+        if not osp.exists(root_path):
+            os.makedirs(root_path)
+        url = "https://bj.bcebos.com/paddlex/interpret/pre_models.tar.gz"
+        pdx.utils.download_and_decompress(url, path=root_path)
     def conv_bn_layer(input,
                       num_filters,
                       filter_size,
...
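The block moved into compute_features_for_kmeans above is a download-on-first-use pattern: the pre_models archive is fetched into ~/.paddlex only when the function is actually called, instead of at import time. A minimal standalone sketch of that pattern; ensure_pre_models is a hypothetical helper name, while the URL, paths, and pdx.utils.download_and_decompress come from the diff.

import os
import os.path as osp

import paddlex as pdx

PRE_MODELS_URL = "https://bj.bcebos.com/paddlex/interpret/pre_models.tar.gz"

def ensure_pre_models():
    # Hypothetical helper mirroring the lazy-download block from the diff.
    root_path = osp.join(osp.expanduser('~'), '.paddlex')
    h_pre_models = osp.join(root_path, 'pre_models')
    if not osp.exists(h_pre_models):
        if not osp.exists(root_path):
            os.makedirs(root_path)
        # Fetch and unpack the archive on first use only.
        pdx.utils.download_and_decompress(PRE_MODELS_URL, path=root_path)
    return osp.join(h_pre_models, 'kmeans_model.pkl')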
@@ -13,11 +13,12 @@
 #limitations under the License.
 import os
+import os.path as osp
 import numpy as np
 import time
 from . import lime_base
-from ._session_preparation import paddle_get_fc_weights, compute_features_for_kmeans, h_pre_models_kmeans
+from ._session_preparation import paddle_get_fc_weights, compute_features_for_kmeans, gen_user_home
 from .normlime_base import combine_normlime_and_lime, get_feature_for_kmeans, load_kmeans_model
 from paddlex.interpret.as_data_reader.readers import read_image
@@ -215,6 +216,15 @@ class LIME(object):
 class NormLIME(object):
     def __init__(self, predict_fn, label_names, num_samples=3000, batch_size=50,
                  kmeans_model_for_normlime=None, normlime_weights=None):
+        root_path = gen_user_home()
+        root_path = osp.join(root_path, '.paddlex')
+        h_pre_models = osp.join(root_path, "pre_models")
+        if not osp.exists(h_pre_models):
+            if not osp.exists(root_path):
+                os.makedirs(root_path)
+            url = "https://bj.bcebos.com/paddlex/interpret/pre_models.tar.gz"
+            pdx.utils.download_and_decompress(url, path=root_path)
+        h_pre_models_kmeans = osp.join(h_pre_models, "kmeans_model.pkl")
         if kmeans_model_for_normlime is None:
             try:
                 self.kmeans_model = load_kmeans_model(h_pre_models_kmeans)
@@ -242,7 +252,13 @@ class NormLIME(object):
         self.label_names = label_names
     def predict_cluster_labels(self, feature_map, segments):
-        return self.kmeans_model.predict(get_feature_for_kmeans(feature_map, segments))
+        X = get_feature_for_kmeans(feature_map, segments)
+        try:
+            cluster_labels = self.kmeans_model.predict(X)
+        except AttributeError:
+            from sklearn.metrics import pairwise_distances_argmin_min
+            cluster_labels, _ = pairwise_distances_argmin_min(X, self.kmeans_model.cluster_centers_)
+        return cluster_labels
     def predict_using_normlime_weights(self, pred_labels, predicted_cluster_labels):
         # global weights
...
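The try/except added to predict_cluster_labels guards against a loaded kmeans object that has no predict method (for example, one that only carries cluster_centers_), falling back to a nearest-center assignment. A minimal standalone sketch of the same fallback; assign_clusters is a hypothetical helper name, while pairwise_distances_argmin_min is the scikit-learn function used in the diff.

from sklearn.metrics import pairwise_distances_argmin_min

def assign_clusters(X, kmeans_model):
    # Hypothetical helper mirroring the fallback logic from the diff.
    try:
        return kmeans_model.predict(X)  # a full KMeans estimator
    except AttributeError:
        # The object only exposes cluster_centers_: assign each row of X
        # to the nearest center by Euclidean distance.
        labels, _ = pairwise_distances_argmin_min(X, kmeans_model.cluster_centers_)
        return labels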
@@ -13,12 +13,14 @@
 #limitations under the License.
 import os
+import os.path as osp
 import numpy as np
 import glob
 from paddlex.interpret.as_data_reader.readers import read_image
+import paddlex.utils.logging as logging
 from . import lime_base
-from ._session_preparation import compute_features_for_kmeans, h_pre_models_kmeans
+from ._session_preparation import compute_features_for_kmeans, gen_user_home
 def load_kmeans_model(fname):
@@ -102,6 +104,15 @@ def save_one_lime_predict_and_kmean_labels(lime_all_weights, image_pred_labels,
 def precompute_lime_weights(list_data_, predict_fn, num_samples, batch_size, save_dir):
+    root_path = gen_user_home()
+    root_path = osp.join(root_path, '.paddlex')
+    h_pre_models = osp.join(root_path, "pre_models")
+    if not osp.exists(h_pre_models):
+        if not osp.exists(root_path):
+            os.makedirs(root_path)
+        url = "https://bj.bcebos.com/paddlex/interpret/pre_models.tar.gz"
+        pdx.utils.download_and_decompress(url, path=root_path)
+    h_pre_models_kmeans = osp.join(h_pre_models, "kmeans_model.pkl")
     kmeans_model = load_kmeans_model(h_pre_models_kmeans)
     for data_index, each_data_ in enumerate(list_data_):
@@ -113,11 +124,10 @@ def precompute_lime_weights(list_data_, predict_fn, num_samples, batch_size, save_dir):
         save_path = os.path.join(save_dir, save_path)
         if os.path.exists(save_path):
-            print(f'{save_path} exists, not computing this one.')
+            logging.info(save_path + ' exists, not computing this one.', use_color=True)
             continue
-        print('processing', each_data_ if isinstance(each_data_, str) else data_index,
-              f', {data_index}/{len(list_data_)}')
+        img_file_name = each_data_ if isinstance(each_data_, str) else data_index
+        logging.info('processing ' + img_file_name + ' [{}/{}]'.format(data_index, len(list_data_)), use_color=True)
         image_show = read_image(each_data_)
         result = predict_fn(image_show)
@@ -149,9 +159,12 @@ def precompute_lime_weights(list_data_, predict_fn, num_samples, batch_size, save_dir):
         interpreter = algo.interpret_instance(image_show[0], predict_fn, pred_label, 0,
                                               num_samples=num_samples, batch_size=batch_size)
-        cluster_labels = kmeans_model.predict(
-            get_feature_for_kmeans(compute_features_for_kmeans(image_show).transpose((1, 2, 0)), interpreter.segments)
-        )
+        X = get_feature_for_kmeans(compute_features_for_kmeans(image_show).transpose((1, 2, 0)), interpreter.segments)
+        try:
+            cluster_labels = kmeans_model.predict(X)
+        except AttributeError:
+            from sklearn.metrics import pairwise_distances_argmin_min
+            cluster_labels, _ = pairwise_distances_argmin_min(X, kmeans_model.cluster_centers_)
         save_one_lime_predict_and_kmean_labels(
             interpreter.local_weights, pred_label,
             cluster_labels,
...
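As a side note on the fallback used in both files above, the nearest-center assignment from pairwise_distances_argmin_min matches KMeans.predict for a fitted estimator, since both assign each sample to its closest cluster center under Euclidean distance. A small self-contained check (not part of the commit):

import numpy as np
from sklearn.cluster import KMeans
from sklearn.metrics import pairwise_distances_argmin_min

rng = np.random.RandomState(0)
X_train = rng.rand(100, 5)
km = KMeans(n_clusters=3, random_state=0).fit(X_train)

X_new = rng.rand(10, 5)
labels_via_predict = km.predict(X_new)
labels_via_centers, _ = pairwise_distances_argmin_min(X_new, km.cluster_centers_)
assert (labels_via_predict == labels_via_centers).all()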