From 3aff3f2fcf063034d86443d07adf848474042f24 Mon Sep 17 00:00:00 2001 From: sunyanfang01 Date: Wed, 20 May 2020 10:06:43 +0800 Subject: [PATCH] for sklrearn 0.23 --- paddlex/interpret/core/normlime_base.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/paddlex/interpret/core/normlime_base.py b/paddlex/interpret/core/normlime_base.py index 288f3a8..6fdd259 100644 --- a/paddlex/interpret/core/normlime_base.py +++ b/paddlex/interpret/core/normlime_base.py @@ -150,9 +150,12 @@ def precompute_lime_weights(list_data_, predict_fn, num_samples, batch_size, sav interpreter = algo.interpret_instance(image_show[0], predict_fn, pred_label, 0, num_samples=num_samples, batch_size=batch_size) - cluster_labels = kmeans_model.predict( - get_feature_for_kmeans(compute_features_for_kmeans(image_show).transpose((1, 2, 0)), interpreter.segments) - ) + X = get_feature_for_kmeans(compute_features_for_kmeans(image_show).transpose((1, 2, 0)), interpreter.segments) + try: + cluster_labels = kmeans_model.predict(X) + except AttributeError: + from sklearn.metrics import pairwise_distances_argmin_min + cluster_labels, _ = pairwise_distances_argmin_min(X, kmeans_model.cluster_centers_) save_one_lime_predict_and_kmean_labels( interpreter.local_weights, pred_label, cluster_labels, -- GitLab