Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
PaddleX
提交
011cff21
P
PaddleX
项目概览
PaddlePaddle
/
PaddleX
通知
138
Star
4
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
43
列表
看板
标记
里程碑
合并请求
5
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
PaddleX
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
43
Issue
43
列表
看板
标记
里程碑
合并请求
5
合并请求
5
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
“10ec329b7d8613c60d7324395ecc42e10b3ce0c0”上不存在“paddle/capi/git@gitcode.net:paddlepaddle/Paddle.git”
提交
011cff21
编写于
5月 14, 2020
作者:
S
sunyanfang01
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
add lime
上级
2484756a
变更
9
隐藏空白更改
内联
并排
Showing
9 changed file
with
1314 addition
and
11 deletion
+1314
-11
paddlex/cv/models/classifier.py
paddlex/cv/models/classifier.py
+16
-10
paddlex/cv/models/explanation/as_data_reader/data_path_utils.py
...x/cv/models/explanation/as_data_reader/data_path_utils.py
+27
-0
paddlex/cv/models/explanation/as_data_reader/readers.py
paddlex/cv/models/explanation/as_data_reader/readers.py
+211
-0
paddlex/cv/models/explanation/core/_session_preparation.py
paddlex/cv/models/explanation/core/_session_preparation.py
+13
-0
paddlex/cv/models/explanation/core/explanation.py
paddlex/cv/models/explanation/core/explanation.py
+37
-0
paddlex/cv/models/explanation/core/explanation_algorithms.py
paddlex/cv/models/explanation/core/explanation_algorithms.py
+458
-0
paddlex/cv/models/explanation/core/lime_base.py
paddlex/cv/models/explanation/core/lime_base.py
+502
-0
paddlex/cv/models/explanation/visualize.py
paddlex/cv/models/explanation/visualize.py
+46
-0
paddlex/cv/nets/resnet.py
paddlex/cv/nets/resnet.py
+4
-1
未找到文件。
paddlex/cv/models/classifier.py
浏览文件 @
011cff21
...
@@ -27,7 +27,6 @@ from .base import BaseAPI
...
@@ -27,7 +27,6 @@ from .base import BaseAPI
class
BaseClassifier
(
BaseAPI
):
class
BaseClassifier
(
BaseAPI
):
"""构建分类器,并实现其训练、评估、预测和模型导出。
"""构建分类器,并实现其训练、评估、预测和模型导出。
Args:
Args:
model_name (str): 分类器的模型名字,取值范围为['ResNet18',
model_name (str): 分类器的模型名字,取值范围为['ResNet18',
'ResNet34', 'ResNet50', 'ResNet101',
'ResNet34', 'ResNet50', 'ResNet101',
...
@@ -61,10 +60,10 @@ class BaseClassifier(BaseAPI):
...
@@ -61,10 +60,10 @@ class BaseClassifier(BaseAPI):
if
mode
!=
'test'
:
if
mode
!=
'test'
:
label
=
fluid
.
data
(
dtype
=
'int64'
,
shape
=
[
None
,
1
],
name
=
'label'
)
label
=
fluid
.
data
(
dtype
=
'int64'
,
shape
=
[
None
,
1
],
name
=
'label'
)
model
=
getattr
(
paddlex
.
cv
.
nets
,
str
.
lower
(
self
.
model_name
))
model
=
getattr
(
paddlex
.
cv
.
nets
,
str
.
lower
(
self
.
model_name
))
net_out
=
model
(
image
,
num_classes
=
self
.
num_classes
)
net_out
,
feat
=
model
(
image
,
num_classes
=
self
.
num_classes
)
softmax_out
=
fluid
.
layers
.
softmax
(
net_out
,
use_cudnn
=
False
)
softmax_out
=
fluid
.
layers
.
softmax
(
net_out
,
use_cudnn
=
False
)
inputs
=
OrderedDict
([(
'image'
,
image
)])
inputs
=
OrderedDict
([(
'image'
,
image
)])
outputs
=
OrderedDict
([(
'predict'
,
softmax_out
)])
outputs
=
OrderedDict
([(
'predict'
,
softmax_out
)
,
(
'net_out'
,
feat
[
-
1
])
])
if
mode
!=
'test'
:
if
mode
!=
'test'
:
cost
=
fluid
.
layers
.
cross_entropy
(
input
=
softmax_out
,
label
=
label
)
cost
=
fluid
.
layers
.
cross_entropy
(
input
=
softmax_out
,
label
=
label
)
avg_cost
=
fluid
.
layers
.
mean
(
cost
)
avg_cost
=
fluid
.
layers
.
mean
(
cost
)
...
@@ -115,7 +114,6 @@ class BaseClassifier(BaseAPI):
...
@@ -115,7 +114,6 @@ class BaseClassifier(BaseAPI):
early_stop_patience
=
5
,
early_stop_patience
=
5
,
resume_checkpoint
=
None
):
resume_checkpoint
=
None
):
"""训练。
"""训练。
Args:
Args:
num_epochs (int): 训练迭代轮数。
num_epochs (int): 训练迭代轮数。
train_dataset (paddlex.datasets): 训练数据读取器。
train_dataset (paddlex.datasets): 训练数据读取器。
...
@@ -139,7 +137,6 @@ class BaseClassifier(BaseAPI):
...
@@ -139,7 +137,6 @@ class BaseClassifier(BaseAPI):
early_stop_patience (int): 当使用提前终止训练策略时,如果验证集精度在`early_stop_patience`个epoch内
early_stop_patience (int): 当使用提前终止训练策略时,如果验证集精度在`early_stop_patience`个epoch内
连续下降或持平,则终止训练。默认值为5。
连续下降或持平,则终止训练。默认值为5。
resume_checkpoint (str): 恢复训练时指定上次训练保存的模型路径。若为None,则不会恢复训练。默认值为None。
resume_checkpoint (str): 恢复训练时指定上次训练保存的模型路径。若为None,则不会恢复训练。默认值为None。
Raises:
Raises:
ValueError: 模型从inference model进行加载。
ValueError: 模型从inference model进行加载。
"""
"""
...
@@ -183,13 +180,11 @@ class BaseClassifier(BaseAPI):
...
@@ -183,13 +180,11 @@ class BaseClassifier(BaseAPI):
epoch_id
=
None
,
epoch_id
=
None
,
return_details
=
False
):
return_details
=
False
):
"""评估。
"""评估。
Args:
Args:
eval_dataset (paddlex.datasets): 验证数据读取器。
eval_dataset (paddlex.datasets): 验证数据读取器。
batch_size (int): 验证数据批大小。默认为1。
batch_size (int): 验证数据批大小。默认为1。
epoch_id (int): 当前评估模型所在的训练轮数。
epoch_id (int): 当前评估模型所在的训练轮数。
return_details (bool): 是否返回详细信息。
return_details (bool): 是否返回详细信息。
Returns:
Returns:
dict: 当return_details为False时,返回dict, 包含关键字:'acc1'、'acc5',
dict: 当return_details为False时,返回dict, 包含关键字:'acc1'、'acc5',
分别表示最大值的accuracy、前5个最大值的accuracy。
分别表示最大值的accuracy、前5个最大值的accuracy。
...
@@ -248,12 +243,10 @@ class BaseClassifier(BaseAPI):
...
@@ -248,12 +243,10 @@ class BaseClassifier(BaseAPI):
def
predict
(
self
,
img_file
,
transforms
=
None
,
topk
=
1
):
def
predict
(
self
,
img_file
,
transforms
=
None
,
topk
=
1
):
"""预测。
"""预测。
Args:
Args:
img_file (str): 预测图像路径。
img_file (str): 预测图像路径。
transforms (paddlex.cls.transforms): 数据预处理操作。
transforms (paddlex.cls.transforms): 数据预处理操作。
topk (int): 预测时前k个最大值。
topk (int): 预测时前k个最大值。
Returns:
Returns:
list: 其中元素均为字典。字典的关键字为'category_id'、'category'、'score',
list: 其中元素均为字典。字典的关键字为'category_id'、'category'、'score',
分别对应预测类别id、预测类别标签、预测得分。
分别对应预测类别id、预测类别标签、预测得分。
...
@@ -279,7 +272,20 @@ class BaseClassifier(BaseAPI):
...
@@ -279,7 +272,20 @@ class BaseClassifier(BaseAPI):
'score'
:
result
[
0
][
0
][
l
]
'score'
:
result
[
0
][
0
][
l
]
}
for
l
in
pred_label
]
}
for
l
in
pred_label
]
return
res
return
res
def
explanation_predict
(
self
,
images
):
self
.
arrange_transforms
(
transforms
=
self
.
test_transforms
,
mode
=
'test'
)
new_imgs
=
[]
for
i
in
range
(
images
.
shape
[
0
]):
img
=
images
[
i
]
new_imgs
.
append
(
self
.
test_transforms
(
img
)[
0
])
new_imgs
=
np
.
array
(
new_imgs
)
result
=
self
.
exe
.
run
(
self
.
test_prog
,
feed
=
{
'image'
:
new_imgs
},
fetch_list
=
list
(
self
.
test_outputs
.
values
()))
return
result
[
1
:]
class
ResNet18
(
BaseClassifier
):
class
ResNet18
(
BaseClassifier
):
def
__init__
(
self
,
num_classes
=
1000
):
def
__init__
(
self
,
num_classes
=
1000
):
...
...
paddlex/cv/models/explanation/as_data_reader/data_path_utils.py
0 → 100644
浏览文件 @
011cff21
import
os
def
imagenet_val_files_and_labels
(
dataset_directory
):
classes
=
open
(
os
.
path
.
join
(
dataset_directory
,
'imagenet_lsvrc_2015_synsets.txt'
)).
readlines
()
class_to_indx
=
{
classes
[
i
].
split
(
'
\n
'
)[
0
]:
i
for
i
in
range
(
len
(
classes
))}
images_path
=
os
.
path
.
join
(
dataset_directory
,
'val'
)
filenames
=
[]
labels
=
[]
lines
=
open
(
os
.
path
.
join
(
dataset_directory
,
'imagenet_2012_validation_synset_labels.txt'
),
'r'
).
readlines
()
for
i
,
line
in
enumerate
(
lines
):
class_name
=
line
.
split
(
'
\n
'
)[
0
]
a
=
'ILSVRC2012_val_%08d.JPEG'
%
(
i
+
1
)
filenames
.
append
(
f
'
{
images_path
}
/
{
a
}
'
)
labels
.
append
(
class_to_indx
[
class_name
])
# print(filenames[-1], labels[-1])
return
filenames
,
labels
def
_find_classes
(
dir
):
# Faster and available in Python 3.5 and above
classes
=
[
d
.
name
for
d
in
os
.
scandir
(
dir
)
if
d
.
is_dir
()]
classes
.
sort
()
class_to_idx
=
{
classes
[
i
]:
i
for
i
in
range
(
len
(
classes
))}
return
classes
,
class_to_idx
\ No newline at end of file
paddlex/cv/models/explanation/as_data_reader/readers.py
0 → 100644
浏览文件 @
011cff21
import
os
import
sys
;
sys
.
path
.
insert
(
0
,
os
.
path
.
abspath
(
os
.
path
.
join
(
os
.
path
.
dirname
(
__file__
),
'..'
)))
import
cv2
import
numpy
as
np
import
six
import
glob
from
as_data_reader.data_path_utils
import
_find_classes
from
PIL
import
Image
def
resize_short
(
img
,
target_size
,
interpolation
=
None
):
"""resize image
Args:
img: image data
target_size: resize short target size
interpolation: interpolation mode
Returns:
resized image data
"""
percent
=
float
(
target_size
)
/
min
(
img
.
shape
[
0
],
img
.
shape
[
1
])
resized_width
=
int
(
round
(
img
.
shape
[
1
]
*
percent
))
resized_height
=
int
(
round
(
img
.
shape
[
0
]
*
percent
))
if
interpolation
:
resized
=
cv2
.
resize
(
img
,
(
resized_width
,
resized_height
),
interpolation
=
interpolation
)
else
:
resized
=
cv2
.
resize
(
img
,
(
resized_width
,
resized_height
))
return
resized
def
crop_image
(
img
,
target_size
,
center
=
True
):
"""crop image
Args:
img: images data
target_size: crop target size
center: crop mode
Returns:
img: cropped image data
"""
height
,
width
=
img
.
shape
[:
2
]
size
=
target_size
if
center
:
w_start
=
(
width
-
size
)
//
2
h_start
=
(
height
-
size
)
//
2
else
:
w_start
=
np
.
random
.
randint
(
0
,
width
-
size
+
1
)
h_start
=
np
.
random
.
randint
(
0
,
height
-
size
+
1
)
w_end
=
w_start
+
size
h_end
=
h_start
+
size
img
=
img
[
h_start
:
h_end
,
w_start
:
w_end
,
:]
return
img
def
preprocess_image
(
img
,
random_mirror
=
False
):
"""
centered, scaled by 1/255.
:param img: np.array: shape: [ns, h, w, 3], color order: rgb.
:return: np.array: shape: [ns, h, w, 3]
"""
mean
=
[
0.485
,
0.456
,
0.406
]
std
=
[
0.229
,
0.224
,
0.225
]
# transpose to [ns, 3, h, w]
img
=
img
.
astype
(
'float32'
).
transpose
((
0
,
3
,
1
,
2
))
/
255
img_mean
=
np
.
array
(
mean
).
reshape
((
3
,
1
,
1
))
img_std
=
np
.
array
(
std
).
reshape
((
3
,
1
,
1
))
img
-=
img_mean
img
/=
img_std
if
random_mirror
:
mirror
=
int
(
np
.
random
.
uniform
(
0
,
2
))
if
mirror
==
1
:
img
=
img
[:,
:,
::
-
1
,
:]
return
img
def
read_image
(
img_path
,
target_size
=
256
,
crop_size
=
224
):
"""
resize_short to 256, then center crop to 224.
:param img_path: one image path
:return: np.array: shape: [1, h, w, 3], color order: rgb.
"""
if
isinstance
(
img_path
,
str
):
with
open
(
img_path
,
'rb'
)
as
f
:
img
=
Image
.
open
(
f
)
img
=
img
.
convert
(
'RGB'
)
img
=
np
.
array
(
img
)
# img = cv2.imread(img_path)
img
=
resize_short
(
img
,
target_size
,
interpolation
=
None
)
img
=
crop_image
(
img
,
target_size
=
crop_size
,
center
=
True
)
# img = img[:, :, ::-1]
img
=
np
.
expand_dims
(
img
,
axis
=
0
)
return
img
elif
isinstance
(
img_path
,
np
.
ndarray
):
assert
len
(
img_path
.
shape
)
==
4
return
img_path
else
:
ValueError
(
f
"Not recognized data type
{
type
(
img_path
)
}
."
)
class
ReaderConfig
(
object
):
"""
A generic data loader where the images are arranged in this way:
root/train/dog/xxy.jpg
root/train/dog/xxz.jpg
...
root/train/cat/nsdf3.jpg
root/train/cat/asd932_.jpg
...
root/test/dog/xxx.jpg
...
root/test/cat/123.jpg
...
"""
def
__init__
(
self
,
dataset_dir
,
is_test
):
image_paths
,
labels
,
self
.
num_classes
=
self
.
get_dataset_info
(
dataset_dir
,
is_test
)
random_per
=
np
.
random
.
permutation
(
range
(
len
(
image_paths
)))
self
.
image_paths
=
image_paths
[
random_per
]
self
.
labels
=
labels
[
random_per
]
self
.
is_test
=
is_test
def
get_reader
(
self
):
def
reader
():
IMG_EXTENSIONS
=
(
'.jpg'
,
'.jpeg'
,
'.png'
,
'.ppm'
,
'.bmp'
,
'.pgm'
,
'.tif'
,
'.tiff'
,
'.webp'
)
target_size
=
256
crop_size
=
224
for
i
,
img_path
in
enumerate
(
self
.
image_paths
):
if
not
img_path
.
lower
().
endswith
(
IMG_EXTENSIONS
):
continue
img
=
cv2
.
imread
(
img_path
)
if
img
is
None
:
print
(
img_path
)
continue
img
=
resize_short
(
img
,
target_size
,
interpolation
=
None
)
img
=
crop_image
(
img
,
crop_size
,
center
=
self
.
is_test
)
img
=
img
[:,
:,
::
-
1
]
img
=
np
.
expand_dims
(
img
,
axis
=
0
)
img
=
preprocess_image
(
img
,
not
self
.
is_test
)
yield
img
,
self
.
labels
[
i
]
return
reader
def
get_dataset_info
(
self
,
dataset_dir
,
is_test
=
False
):
IMG_EXTENSIONS
=
(
'.jpg'
,
'.jpeg'
,
'.png'
,
'.ppm'
,
'.bmp'
,
'.pgm'
,
'.tif'
,
'.tiff'
,
'.webp'
)
# read
if
is_test
:
datasubset_dir
=
os
.
path
.
join
(
dataset_dir
,
'test'
)
else
:
datasubset_dir
=
os
.
path
.
join
(
dataset_dir
,
'train'
)
class_names
,
class_to_idx
=
_find_classes
(
datasubset_dir
)
# num_classes = len(class_names)
image_paths
=
[]
labels
=
[]
for
class_name
in
class_names
:
classes_dir
=
os
.
path
.
join
(
datasubset_dir
,
class_name
)
for
img_path
in
glob
.
glob
(
os
.
path
.
join
(
classes_dir
,
'*'
)):
if
not
img_path
.
lower
().
endswith
(
IMG_EXTENSIONS
):
continue
image_paths
.
append
(
img_path
)
labels
.
append
(
class_to_idx
[
class_name
])
image_paths
=
np
.
array
(
image_paths
)
labels
=
np
.
array
(
labels
)
return
image_paths
,
labels
,
len
(
class_names
)
def
create_reader
(
list_image_path
,
list_label
=
None
,
is_test
=
False
):
def
reader
():
IMG_EXTENSIONS
=
(
'.jpg'
,
'.jpeg'
,
'.png'
,
'.ppm'
,
'.bmp'
,
'.pgm'
,
'.tif'
,
'.tiff'
,
'.webp'
)
target_size
=
256
crop_size
=
224
for
i
,
img_path
in
enumerate
(
list_image_path
):
if
not
img_path
.
lower
().
endswith
(
IMG_EXTENSIONS
):
continue
img
=
cv2
.
imread
(
img_path
)
if
img
is
None
:
print
(
img_path
)
continue
img
=
resize_short
(
img
,
target_size
,
interpolation
=
None
)
img
=
crop_image
(
img
,
crop_size
,
center
=
is_test
)
img
=
img
[:,
:,
::
-
1
]
img_show
=
np
.
expand_dims
(
img
,
axis
=
0
)
img
=
preprocess_image
(
img_show
,
not
is_test
)
label
=
0
if
list_label
is
None
else
list_label
[
i
]
yield
img_show
,
img
,
label
return
reader
\ No newline at end of file
paddlex/cv/models/explanation/core/_session_preparation.py
0 → 100644
浏览文件 @
011cff21
import
os
import
paddle.fluid
as
fluid
import
numpy
as
np
def
paddle_get_fc_weights
(
var_name
=
"fc_0.w_0"
):
fc_weights
=
fluid
.
global_scope
().
find_var
(
var_name
).
get_tensor
()
return
np
.
array
(
fc_weights
)
def
paddle_resize
(
extracted_features
,
outsize
):
resized_features
=
fluid
.
layers
.
resize_bilinear
(
extracted_features
,
outsize
)
return
resized_features
\ No newline at end of file
paddlex/cv/models/explanation/core/explanation.py
0 → 100644
浏览文件 @
011cff21
from
.explanation_algorithms
import
CAM
,
LIME
,
NormLIME
class
Explanation
(
object
):
"""
Base class for all explanation algorithms.
"""
def
__init__
(
self
,
explanation_algorithm_name
,
predict_fn
,
**
kwargs
):
supported_algorithms
=
{
'cam'
:
CAM
,
'lime'
:
LIME
,
'normlime'
:
NormLIME
}
self
.
algorithm_name
=
explanation_algorithm_name
.
lower
()
assert
self
.
algorithm_name
in
supported_algorithms
.
keys
()
self
.
predict_fn
=
predict_fn
# initialization for the explanation algorithm.
self
.
explain_algorithm
=
supported_algorithms
[
self
.
algorithm_name
](
self
.
predict_fn
,
**
kwargs
)
def
explain
(
self
,
data_
,
visualization
=
True
,
save_to_disk
=
True
,
save_dir
=
'./tmp'
):
"""
Args:
data_: data_ can be a path or numpy.ndarray.
visualization: whether to show using matplotlib.
save_to_disk: whether to save the figure in local disk.
save_dir: dir to save figure if save_to_disk is True.
Returns:
"""
return
self
.
explain_algorithm
.
explain
(
data_
,
visualization
,
save_to_disk
,
save_dir
)
paddlex/cv/models/explanation/core/explanation_algorithms.py
0 → 100644
浏览文件 @
011cff21
import
os
import
numpy
as
np
import
time
from
.
import
lime_base
from
..as_data_reader.readers
import
read_image
from
._session_preparation
import
paddle_get_fc_weights
import
cv2
class
CAM
(
object
):
def
__init__
(
self
,
predict_fn
):
"""
Args:
predict_fn: input: images_show [N, H, W, 3], RGB range(0, 255)
output: [
logits [N, num_classes],
feature map before global average pooling [N, num_channels, h_, w_]
]
"""
self
.
predict_fn
=
predict_fn
def
preparation_cam
(
self
,
data_path
):
image_show
=
read_image
(
data_path
)
result
=
self
.
predict_fn
(
image_show
)
logit
=
result
[
0
][
0
]
if
abs
(
np
.
sum
(
logit
)
-
1.0
)
>
1e-4
:
# softmax
exp_result
=
np
.
exp
(
logit
)
probability
=
exp_result
/
np
.
sum
(
exp_result
)
else
:
probability
=
logit
# only explain top 1
pred_label
=
np
.
argsort
(
probability
)
pred_label
=
pred_label
[
-
1
:]
self
.
predicted_label
=
pred_label
[
0
]
self
.
predicted_probability
=
probability
[
pred_label
[
0
]]
self
.
image
=
image_show
[
0
]
self
.
labels
=
pred_label
fc_weights
=
paddle_get_fc_weights
()
feature_maps
=
result
[
1
]
print
(
'predicted result: '
,
pred_label
[
0
],
probability
[
pred_label
[
0
]])
return
feature_maps
,
fc_weights
def
explain
(
self
,
data_
,
visualization
=
True
,
save_to_disk
=
True
,
save_outdir
=
None
):
feature_maps
,
fc_weights
=
self
.
preparation_cam
(
data_
)
cam
=
get_cam
(
self
.
image
,
feature_maps
,
fc_weights
,
self
.
predicted_label
)
if
visualization
or
save_to_disk
:
import
matplotlib.pyplot
as
plt
from
skimage.segmentation
import
mark_boundaries
l
=
self
.
labels
[
0
]
psize
=
5
nrows
=
1
ncols
=
2
plt
.
close
()
f
,
axes
=
plt
.
subplots
(
nrows
,
ncols
,
figsize
=
(
psize
*
ncols
,
psize
*
nrows
))
for
ax
in
axes
.
ravel
():
ax
.
axis
(
"off"
)
axes
=
axes
.
ravel
()
axes
[
0
].
imshow
(
self
.
image
)
axes
[
0
].
set_title
(
f
"label
{
l
}
, proba:
{
self
.
predicted_probability
:
.
3
f
}
"
)
axes
[
1
].
imshow
(
cam
)
axes
[
1
].
set_title
(
"CAM"
)
if
save_to_disk
and
save_outdir
is
not
None
:
os
.
makedirs
(
save_outdir
,
exist_ok
=
True
)
save_fig
(
data_
,
save_outdir
,
'cam'
)
if
visualization
:
plt
.
show
()
return
class
LIME
(
object
):
def
__init__
(
self
,
predict_fn
,
num_samples
=
3000
,
batch_size
=
50
):
"""
LIME wrapper. See lime_base.py for the detailed LIME implementation.
Args:
predict_fn: from image [N, H, W, 3] to logits [N, num_classes], this is necessary for computing LIME.
num_samples: the number of samples that LIME takes for fitting.
batch_size: batch size for model inference each time.
"""
self
.
num_samples
=
num_samples
self
.
batch_size
=
batch_size
self
.
predict_fn
=
predict_fn
self
.
labels
=
None
self
.
image
=
None
self
.
lime_explainer
=
None
def
preparation_lime
(
self
,
data_path
):
image_show
=
read_image
(
data_path
)
result
=
self
.
predict_fn
(
image_show
)
result
=
result
[
0
]
# only one image here.
if
abs
(
np
.
sum
(
result
)
-
1.0
)
>
1e-4
:
# softmax
exp_result
=
np
.
exp
(
result
)
probability
=
exp_result
/
np
.
sum
(
exp_result
)
else
:
probability
=
result
# only explain top 1
pred_label
=
np
.
argsort
(
probability
)
pred_label
=
pred_label
[
-
1
:]
self
.
predicted_label
=
pred_label
[
0
]
self
.
predicted_probability
=
probability
[
pred_label
[
0
]]
self
.
image
=
image_show
[
0
]
self
.
labels
=
pred_label
print
(
f
'predicted result:
{
pred_label
[
0
]
}
with probability
{
probability
[
pred_label
[
0
]]:
.
3
f
}
'
)
end
=
time
.
time
()
algo
=
lime_base
.
LimeImageExplainer
()
explainer
=
algo
.
explain_instance
(
self
.
image
,
self
.
predict_fn
,
self
.
labels
,
0
,
num_samples
=
self
.
num_samples
,
batch_size
=
self
.
batch_size
)
self
.
lime_explainer
=
explainer
print
(
'lime time: '
,
time
.
time
()
-
end
,
's.'
)
def
explain
(
self
,
data_
,
visualization
=
True
,
save_to_disk
=
True
,
save_outdir
=
None
):
if
self
.
lime_explainer
is
None
:
self
.
preparation_lime
(
data_
)
if
visualization
or
save_to_disk
:
import
matplotlib.pyplot
as
plt
from
skimage.segmentation
import
mark_boundaries
l
=
self
.
labels
[
0
]
psize
=
5
nrows
=
2
weights_choices
=
[
0.6
,
0.75
,
0.85
]
ncols
=
len
(
weights_choices
)
plt
.
close
()
f
,
axes
=
plt
.
subplots
(
nrows
,
ncols
,
figsize
=
(
psize
*
ncols
,
psize
*
nrows
))
for
ax
in
axes
.
ravel
():
ax
.
axis
(
"off"
)
axes
=
axes
.
ravel
()
axes
[
0
].
imshow
(
self
.
image
)
axes
[
0
].
set_title
(
f
"label
{
l
}
, proba:
{
self
.
predicted_probability
:
.
3
f
}
"
)
axes
[
1
].
imshow
(
mark_boundaries
(
self
.
image
,
self
.
lime_explainer
.
segments
))
axes
[
1
].
set_title
(
"superpixel segmentation"
)
# LIME visualization
for
i
,
w
in
enumerate
(
weights_choices
):
num_to_show
=
auto_choose_num_features_to_show
(
self
.
lime_explainer
,
l
,
w
)
temp
,
mask
=
self
.
lime_explainer
.
get_image_and_mask
(
l
,
positive_only
=
False
,
hide_rest
=
False
,
num_features
=
num_to_show
)
axes
[
ncols
+
i
].
imshow
(
mark_boundaries
(
temp
,
mask
))
axes
[
ncols
+
i
].
set_title
(
f
"label
{
l
}
, first
{
num_to_show
}
superpixels"
)
if
save_to_disk
and
save_outdir
is
not
None
:
os
.
makedirs
(
save_outdir
,
exist_ok
=
True
)
save_fig
(
data_
,
save_outdir
,
'lime'
,
self
.
num_samples
)
if
visualization
:
plt
.
show
()
return
class
NormLIME
(
object
):
def
__init__
(
self
,
predict_fn
,
num_samples
=
3000
,
batch_size
=
50
,
kmeans_model_for_normlime
=
None
,
normlime_weights
=
None
):
assert
kmeans_model_for_normlime
is
not
None
,
"NormLIME needs the KMeans model."
if
normlime_weights
is
None
:
raise
NotImplementedError
(
"Computing NormLIME weights is not implemented yet."
)
self
.
num_samples
=
num_samples
self
.
batch_size
=
batch_size
self
.
kmeans_model
=
load_kmeans_model
(
kmeans_model_for_normlime
)
self
.
normlime_weights
=
np
.
load
(
normlime_weights
,
allow_pickle
=
True
).
item
()
self
.
predict_fn
=
predict_fn
self
.
labels
=
None
self
.
image
=
None
def
predict_cluster_labels
(
self
,
feature_map
,
segments
):
return
self
.
kmeans_model
.
predict
(
get_feature_for_kmeans
(
feature_map
,
segments
))
def
predict_using_normlime_weights
(
self
,
pred_labels
,
predicted_cluster_labels
):
# global weights
g_weights
=
{
y
:
[]
for
y
in
pred_labels
}
for
y
in
pred_labels
:
cluster_weights_y
=
self
.
normlime_weights
[
y
]
g_weights
[
y
]
=
[
# some are not in the dict, 3000 samples may be not enough.
(
i
,
cluster_weights_y
.
get
(
k
,
0.0
))
for
i
,
k
in
enumerate
(
predicted_cluster_labels
)
]
g_weights
[
y
]
=
sorted
(
g_weights
[
y
],
key
=
lambda
x
:
np
.
abs
(
x
[
1
]),
reverse
=
True
)
return
g_weights
def
preparation_normlime
(
self
,
data_path
):
self
.
_lime
=
LIME
(
lambda
images
:
self
.
predict_fn
(
images
)[
0
],
self
.
num_samples
,
self
.
batch_size
)
self
.
_lime
.
preparation_lime
(
data_path
)
image_show
=
read_image
(
data_path
)
result
=
self
.
predict_fn
(
image_show
)
logit
=
result
[
0
][
0
]
# only one image here.
if
abs
(
np
.
sum
(
logit
)
-
1.0
)
>
1e-4
:
# softmax
exp_result
=
np
.
exp
(
logit
)
probability
=
exp_result
/
np
.
sum
(
exp_result
)
else
:
probability
=
logit
# only explain top 1
pred_label
=
np
.
argsort
(
probability
)
pred_label
=
pred_label
[
-
1
:]
self
.
predicted_label
=
pred_label
[
0
]
self
.
predicted_probability
=
probability
[
pred_label
[
0
]]
self
.
image
=
image_show
[
0
]
self
.
labels
=
pred_label
print
(
'predicted result: '
,
pred_label
[
0
],
probability
[
pred_label
[
0
]])
local_feature_map
=
result
[
1
][
0
]
cluster_labels
=
self
.
predict_cluster_labels
(
local_feature_map
.
transpose
((
1
,
2
,
0
)),
self
.
_lime
.
lime_explainer
.
segments
)
g_weights
=
self
.
predict_using_normlime_weights
(
self
.
labels
,
cluster_labels
)
return
g_weights
def
explain
(
self
,
data_
,
visualization
=
True
,
save_to_disk
=
True
,
save_outdir
=
None
):
g_weights
=
self
.
preparation_normlime
(
data_
)
lime_weights
=
self
.
_lime
.
lime_explainer
.
local_exp
if
visualization
or
save_to_disk
:
import
matplotlib.pyplot
as
plt
from
skimage.segmentation
import
mark_boundaries
l
=
self
.
labels
[
0
]
psize
=
5
nrows
=
4
weights_choices
=
[
0.6
,
0.85
,
0.99
]
ncols
=
len
(
weights_choices
)
plt
.
close
()
f
,
axes
=
plt
.
subplots
(
nrows
,
ncols
,
figsize
=
(
psize
*
ncols
,
psize
*
nrows
))
for
ax
in
axes
.
ravel
():
ax
.
axis
(
"off"
)
axes
=
axes
.
ravel
()
axes
[
0
].
imshow
(
self
.
image
)
axes
[
0
].
set_title
(
f
"label
{
l
}
, proba:
{
self
.
predicted_probability
:
.
3
f
}
"
)
axes
[
1
].
imshow
(
mark_boundaries
(
self
.
image
,
self
.
_lime
.
lime_explainer
.
segments
))
axes
[
1
].
set_title
(
"superpixel segmentation"
)
# LIME visualization
for
i
,
w
in
enumerate
(
weights_choices
):
num_to_show
=
auto_choose_num_features_to_show
(
self
.
_lime
.
lime_explainer
,
l
,
w
)
temp
,
mask
=
self
.
_lime
.
lime_explainer
.
get_image_and_mask
(
l
,
positive_only
=
False
,
hide_rest
=
False
,
num_features
=
num_to_show
)
axes
[
ncols
+
i
].
imshow
(
mark_boundaries
(
temp
,
mask
))
axes
[
ncols
+
i
].
set_title
(
f
"label
{
l
}
, first
{
num_to_show
}
superpixels"
)
# NormLIME visualization
self
.
_lime
.
lime_explainer
.
local_exp
=
g_weights
for
i
,
w
in
enumerate
(
weights_choices
):
num_to_show
=
auto_choose_num_features_to_show
(
self
.
_lime
.
lime_explainer
,
l
,
w
)
temp
,
mask
=
self
.
_lime
.
lime_explainer
.
get_image_and_mask
(
l
,
positive_only
=
False
,
hide_rest
=
False
,
num_features
=
num_to_show
)
axes
[
ncols
*
2
+
i
].
imshow
(
mark_boundaries
(
temp
,
mask
))
axes
[
ncols
*
2
+
i
].
set_title
(
f
"label
{
l
}
, first
{
num_to_show
}
superpixels"
)
# NormLIME*LIME visualization
combined_weights
=
combine_normlime_and_lime
(
lime_weights
,
g_weights
)
self
.
_lime
.
lime_explainer
.
local_exp
=
combined_weights
for
i
,
w
in
enumerate
(
weights_choices
):
num_to_show
=
auto_choose_num_features_to_show
(
self
.
_lime
.
lime_explainer
,
l
,
w
)
temp
,
mask
=
self
.
_lime
.
lime_explainer
.
get_image_and_mask
(
l
,
positive_only
=
False
,
hide_rest
=
False
,
num_features
=
num_to_show
)
axes
[
ncols
*
3
+
i
].
imshow
(
mark_boundaries
(
temp
,
mask
))
axes
[
ncols
*
3
+
i
].
set_title
(
f
"label
{
l
}
, first
{
num_to_show
}
superpixels"
)
self
.
_lime
.
lime_explainer
.
local_exp
=
lime_weights
if
save_to_disk
and
save_outdir
is
not
None
:
os
.
makedirs
(
save_outdir
,
exist_ok
=
True
)
save_fig
(
data_
,
save_outdir
,
'normlime'
,
self
.
num_samples
)
if
visualization
:
plt
.
show
()
def
load_kmeans_model
(
fname
):
import
pickle
with
open
(
fname
,
'rb'
)
as
f
:
kmeans_model
=
pickle
.
load
(
f
)
return
kmeans_model
def
auto_choose_num_features_to_show
(
lime_explainer
,
label
,
percentage_to_show
):
segments
=
lime_explainer
.
segments
lime_weights
=
lime_explainer
.
local_exp
[
label
]
num_pixels_threshold_in_a_sp
=
segments
.
shape
[
0
]
*
segments
.
shape
[
1
]
//
len
(
np
.
unique
(
segments
))
//
8
# l1 norm with filtered weights.
used_weights
=
[(
tuple_w
[
0
],
tuple_w
[
1
])
for
i
,
tuple_w
in
enumerate
(
lime_weights
)
if
tuple_w
[
1
]
>
0
]
norm
=
np
.
sum
([
tuple_w
[
1
]
for
i
,
tuple_w
in
enumerate
(
used_weights
)])
normalized_weights
=
[(
tuple_w
[
0
],
tuple_w
[
1
]
/
norm
)
for
i
,
tuple_w
in
enumerate
(
lime_weights
)]
a
=
0.0
n
=
0
for
i
,
tuple_w
in
enumerate
(
normalized_weights
):
if
tuple_w
[
1
]
<
0
:
continue
if
len
(
np
.
where
(
segments
==
tuple_w
[
0
])[
0
])
<
num_pixels_threshold_in_a_sp
:
continue
a
+=
tuple_w
[
1
]
if
a
>
percentage_to_show
:
n
=
i
+
1
break
if
n
==
0
:
return
auto_choose_num_features_to_show
(
lime_explainer
,
label
,
percentage_to_show
-
0.1
)
return
n
def
get_cam
(
image_show
,
feature_maps
,
fc_weights
,
label_index
,
cam_min
=
None
,
cam_max
=
None
):
_
,
nc
,
h
,
w
=
feature_maps
.
shape
cam
=
feature_maps
*
fc_weights
[:,
label_index
].
reshape
(
1
,
nc
,
1
,
1
)
cam
=
cam
.
sum
((
0
,
1
))
if
cam_min
is
None
:
cam_min
=
np
.
min
(
cam
)
if
cam_max
is
None
:
cam_max
=
np
.
max
(
cam
)
cam
=
cam
-
cam_min
cam
=
cam
/
cam_max
cam
=
np
.
uint8
(
255
*
cam
)
cam_img
=
cv2
.
resize
(
cam
,
image_show
.
shape
[
0
:
2
],
interpolation
=
cv2
.
INTER_LINEAR
)
heatmap
=
cv2
.
applyColorMap
(
np
.
uint8
(
255
*
cam_img
),
cv2
.
COLORMAP_JET
)
heatmap
=
np
.
float32
(
heatmap
)
cam
=
heatmap
+
np
.
float32
(
image_show
)
cam
=
cam
/
np
.
max
(
cam
)
return
cam
def
avg_using_superpixels
(
features
,
segments
):
one_list
=
np
.
zeros
((
len
(
np
.
unique
(
segments
)),
features
.
shape
[
2
]))
for
x
in
np
.
unique
(
segments
):
one_list
[
x
]
=
np
.
mean
(
features
[
segments
==
x
],
axis
=
0
)
return
one_list
def
centroid_using_superpixels
(
features
,
segments
):
from
skimage.measure
import
regionprops
regions
=
regionprops
(
segments
+
1
)
one_list
=
np
.
zeros
((
len
(
np
.
unique
(
segments
)),
features
.
shape
[
2
]))
for
i
,
r
in
enumerate
(
regions
):
one_list
[
i
]
=
features
[
int
(
r
.
centroid
[
0
]
+
0.5
),
int
(
r
.
centroid
[
1
]
+
0.5
),
:]
# print(one_list.shape)
return
one_list
def
get_feature_for_kmeans
(
feature_map
,
segments
):
from
sklearn.preprocessing
import
normalize
centroid_feature
=
centroid_using_superpixels
(
feature_map
,
segments
)
avg_feature
=
avg_using_superpixels
(
feature_map
,
segments
)
x
=
np
.
concatenate
((
centroid_feature
,
avg_feature
),
axis
=-
1
)
x
=
normalize
(
x
)
return
x
def
combine_normlime_and_lime
(
lime_weights
,
g_weights
):
pred_labels
=
lime_weights
.
keys
()
combined_weights
=
{
y
:
[]
for
y
in
pred_labels
}
for
y
in
pred_labels
:
normlized_lime_weights_y
=
lime_weights
[
y
]
lime_weights_dict
=
{
tuple_w
[
0
]:
tuple_w
[
1
]
for
tuple_w
in
normlized_lime_weights_y
}
normlized_g_weight_y
=
g_weights
[
y
]
normlime_weights_dict
=
{
tuple_w
[
0
]:
tuple_w
[
1
]
for
tuple_w
in
normlized_g_weight_y
}
combined_weights
[
y
]
=
[
(
seg_k
,
lime_weights_dict
[
seg_k
]
*
normlime_weights_dict
[
seg_k
])
for
seg_k
in
lime_weights_dict
.
keys
()
]
combined_weights
[
y
]
=
sorted
(
combined_weights
[
y
],
key
=
lambda
x
:
np
.
abs
(
x
[
1
]),
reverse
=
True
)
return
combined_weights
def
save_fig
(
data_
,
save_outdir
,
algorithm_name
,
num_samples
=
3000
):
import
matplotlib.pyplot
as
plt
if
isinstance
(
data_
,
str
):
if
algorithm_name
==
'cam'
:
f_out
=
f
"
{
algorithm_name
}
_
{
data_
.
split
(
'/'
)[
-
1
]
}
.png"
else
:
f_out
=
f
"
{
algorithm_name
}
_
{
data_
.
split
(
'/'
)[
-
1
]
}
_s
{
num_samples
}
.png"
plt
.
savefig
(
os
.
path
.
join
(
save_outdir
,
f_out
)
)
else
:
n
=
0
if
algorithm_name
==
'cam'
:
f_out
=
f
'cam-
{
n
}
.png'
else
:
f_out
=
f
'
{
algorithm_name
}
_s
{
num_samples
}
-
{
n
}
.png'
while
os
.
path
.
exists
(
os
.
path
.
join
(
save_outdir
,
f_out
)
):
n
+=
1
if
algorithm_name
==
'cam'
:
f_out
=
f
'cam-
{
n
}
.png'
else
:
f_out
=
f
'
{
algorithm_name
}
_s
{
num_samples
}
-
{
n
}
.png'
continue
plt
.
savefig
(
os
.
path
.
join
(
save_outdir
,
f_out
)
)
paddlex/cv/models/explanation/core/lime_base.py
0 → 100644
浏览文件 @
011cff21
"""
Contains abstract functionality for learning locally linear sparse model.
"""
from
__future__
import
print_function
import
numpy
as
np
import
scipy
as
sp
import
sklearn
import
sklearn.preprocessing
from
skimage.color
import
gray2rgb
from
sklearn.linear_model
import
Ridge
,
lars_path
from
sklearn.utils
import
check_random_state
import
copy
from
functools
import
partial
from
skimage.segmentation
import
quickshift
from
skimage.measure
import
regionprops
class
LimeBase
(
object
):
"""Class for learning a locally linear sparse model from perturbed data"""
def
__init__
(
self
,
kernel_fn
,
verbose
=
False
,
random_state
=
None
):
"""Init function
Args:
kernel_fn: function that transforms an array of distances into an
array of proximity values (floats).
verbose: if true, print local prediction values from linear model.
random_state: an integer or numpy.RandomState that will be used to
generate random numbers. If None, the random state will be
initialized using the internal numpy seed.
"""
self
.
kernel_fn
=
kernel_fn
self
.
verbose
=
verbose
self
.
random_state
=
check_random_state
(
random_state
)
@
staticmethod
def
generate_lars_path
(
weighted_data
,
weighted_labels
):
"""Generates the lars path for weighted data.
Args:
weighted_data: data that has been weighted by kernel
weighted_label: labels, weighted by kernel
Returns:
(alphas, coefs), both are arrays corresponding to the
regularization parameter and coefficients, respectively
"""
x_vector
=
weighted_data
alphas
,
_
,
coefs
=
lars_path
(
x_vector
,
weighted_labels
,
method
=
'lasso'
,
verbose
=
False
)
return
alphas
,
coefs
def
forward_selection
(
self
,
data
,
labels
,
weights
,
num_features
):
"""Iteratively adds features to the model"""
clf
=
Ridge
(
alpha
=
0
,
fit_intercept
=
True
,
random_state
=
self
.
random_state
)
used_features
=
[]
for
_
in
range
(
min
(
num_features
,
data
.
shape
[
1
])):
max_
=
-
100000000
best
=
0
for
feature
in
range
(
data
.
shape
[
1
]):
if
feature
in
used_features
:
continue
clf
.
fit
(
data
[:,
used_features
+
[
feature
]],
labels
,
sample_weight
=
weights
)
score
=
clf
.
score
(
data
[:,
used_features
+
[
feature
]],
labels
,
sample_weight
=
weights
)
if
score
>
max_
:
best
=
feature
max_
=
score
used_features
.
append
(
best
)
return
np
.
array
(
used_features
)
def
feature_selection
(
self
,
data
,
labels
,
weights
,
num_features
,
method
):
"""Selects features for the model. see explain_instance_with_data to
understand the parameters."""
if
method
==
'none'
:
return
np
.
array
(
range
(
data
.
shape
[
1
]))
elif
method
==
'forward_selection'
:
return
self
.
forward_selection
(
data
,
labels
,
weights
,
num_features
)
elif
method
==
'highest_weights'
:
clf
=
Ridge
(
alpha
=
0.01
,
fit_intercept
=
True
,
random_state
=
self
.
random_state
)
clf
.
fit
(
data
,
labels
,
sample_weight
=
weights
)
coef
=
clf
.
coef_
if
sp
.
sparse
.
issparse
(
data
):
coef
=
sp
.
sparse
.
csr_matrix
(
clf
.
coef_
)
weighted_data
=
coef
.
multiply
(
data
[
0
])
# Note: most efficient to slice the data before reversing
sdata
=
len
(
weighted_data
.
data
)
argsort_data
=
np
.
abs
(
weighted_data
.
data
).
argsort
()
# Edge case where data is more sparse than requested number of feature importances
# In that case, we just pad with zero-valued features
if
sdata
<
num_features
:
nnz_indexes
=
argsort_data
[::
-
1
]
indices
=
weighted_data
.
indices
[
nnz_indexes
]
num_to_pad
=
num_features
-
sdata
indices
=
np
.
concatenate
((
indices
,
np
.
zeros
(
num_to_pad
,
dtype
=
indices
.
dtype
)))
indices_set
=
set
(
indices
)
pad_counter
=
0
for
i
in
range
(
data
.
shape
[
1
]):
if
i
not
in
indices_set
:
indices
[
pad_counter
+
sdata
]
=
i
pad_counter
+=
1
if
pad_counter
>=
num_to_pad
:
break
else
:
nnz_indexes
=
argsort_data
[
sdata
-
num_features
:
sdata
][::
-
1
]
indices
=
weighted_data
.
indices
[
nnz_indexes
]
return
indices
else
:
weighted_data
=
coef
*
data
[
0
]
feature_weights
=
sorted
(
zip
(
range
(
data
.
shape
[
1
]),
weighted_data
),
key
=
lambda
x
:
np
.
abs
(
x
[
1
]),
reverse
=
True
)
return
np
.
array
([
x
[
0
]
for
x
in
feature_weights
[:
num_features
]])
elif
method
==
'lasso_path'
:
weighted_data
=
((
data
-
np
.
average
(
data
,
axis
=
0
,
weights
=
weights
))
*
np
.
sqrt
(
weights
[:,
np
.
newaxis
]))
weighted_labels
=
((
labels
-
np
.
average
(
labels
,
weights
=
weights
))
*
np
.
sqrt
(
weights
))
nonzero
=
range
(
weighted_data
.
shape
[
1
])
_
,
coefs
=
self
.
generate_lars_path
(
weighted_data
,
weighted_labels
)
for
i
in
range
(
len
(
coefs
.
T
)
-
1
,
0
,
-
1
):
nonzero
=
coefs
.
T
[
i
].
nonzero
()[
0
]
if
len
(
nonzero
)
<=
num_features
:
break
used_features
=
nonzero
return
used_features
elif
method
==
'auto'
:
if
num_features
<=
6
:
n_method
=
'forward_selection'
else
:
n_method
=
'highest_weights'
return
self
.
feature_selection
(
data
,
labels
,
weights
,
num_features
,
n_method
)
def
explain_instance_with_data
(
self
,
neighborhood_data
,
neighborhood_labels
,
distances
,
label
,
num_features
,
feature_selection
=
'auto'
,
model_regressor
=
None
):
"""Takes perturbed data, labels and distances, returns explanation.
Args:
neighborhood_data: perturbed data, 2d array. first element is
assumed to be the original data point.
neighborhood_labels: corresponding perturbed labels. should have as
many columns as the number of possible labels.
distances: distances to original data point.
label: label for which we want an explanation
num_features: maximum number of features in explanation
feature_selection: how to select num_features. options are:
'forward_selection': iteratively add features to the model.
This is costly when num_features is high
'highest_weights': selects the features that have the highest
product of absolute weight * original data point when
learning with all the features
'lasso_path': chooses features based on the lasso
regularization path
'none': uses all features, ignores num_features
'auto': uses forward_selection if num_features <= 6, and
'highest_weights' otherwise.
model_regressor: sklearn regressor to use in explanation.
Defaults to Ridge regression if None. Must have
model_regressor.coef_ and 'sample_weight' as a parameter
to model_regressor.fit()
Returns:
(intercept, exp, score, local_pred):
intercept is a float.
exp is a sorted list of tuples, where each tuple (x,y) corresponds
to the feature id (x) and the local weight (y). The list is sorted
by decreasing absolute value of y.
score is the R^2 value of the returned explanation
local_pred is the prediction of the explanation model on the original instance
"""
weights
=
self
.
kernel_fn
(
distances
)
labels_column
=
neighborhood_labels
[:,
label
]
used_features
=
self
.
feature_selection
(
neighborhood_data
,
labels_column
,
weights
,
num_features
,
feature_selection
)
if
model_regressor
is
None
:
model_regressor
=
Ridge
(
alpha
=
1
,
fit_intercept
=
True
,
random_state
=
self
.
random_state
)
easy_model
=
model_regressor
easy_model
.
fit
(
neighborhood_data
[:,
used_features
],
labels_column
,
sample_weight
=
weights
)
prediction_score
=
easy_model
.
score
(
neighborhood_data
[:,
used_features
],
labels_column
,
sample_weight
=
weights
)
local_pred
=
easy_model
.
predict
(
neighborhood_data
[
0
,
used_features
].
reshape
(
1
,
-
1
))
if
self
.
verbose
:
print
(
'Intercept'
,
easy_model
.
intercept_
)
print
(
'Prediction_local'
,
local_pred
,)
print
(
'Right:'
,
neighborhood_labels
[
0
,
label
])
return
(
easy_model
.
intercept_
,
sorted
(
zip
(
used_features
,
easy_model
.
coef_
),
key
=
lambda
x
:
np
.
abs
(
x
[
1
]),
reverse
=
True
),
prediction_score
,
local_pred
)
class
ImageExplanation
(
object
):
def
__init__
(
self
,
image
,
segments
):
"""Init function.
Args:
image: 3d numpy array
segments: 2d numpy array, with the output from skimage.segmentation
"""
self
.
image
=
image
self
.
segments
=
segments
self
.
intercept
=
{}
self
.
local_exp
=
{}
self
.
local_pred
=
None
def
get_image_and_mask
(
self
,
label
,
positive_only
=
True
,
negative_only
=
False
,
hide_rest
=
False
,
num_features
=
5
,
min_weight
=
0.
):
"""Init function.
Args:
label: label to explain
positive_only: if True, only take superpixels that positively contribute to
the prediction of the label.
negative_only: if True, only take superpixels that negatively contribute to
the prediction of the label. If false, and so is positive_only, then both
negativey and positively contributions will be taken.
Both can't be True at the same time
hide_rest: if True, make the non-explanation part of the return
image gray
num_features: number of superpixels to include in explanation
min_weight: minimum weight of the superpixels to include in explanation
Returns:
(image, mask), where image is a 3d numpy array and mask is a 2d
numpy array that can be used with
skimage.segmentation.mark_boundaries
"""
if
label
not
in
self
.
local_exp
:
raise
KeyError
(
'Label not in explanation'
)
if
positive_only
&
negative_only
:
raise
ValueError
(
"Positive_only and negative_only cannot be true at the same time."
)
segments
=
self
.
segments
image
=
self
.
image
exp
=
self
.
local_exp
[
label
]
mask
=
np
.
zeros
(
segments
.
shape
,
segments
.
dtype
)
if
hide_rest
:
temp
=
np
.
zeros
(
self
.
image
.
shape
)
else
:
temp
=
self
.
image
.
copy
()
if
positive_only
:
fs
=
[
x
[
0
]
for
x
in
exp
if
x
[
1
]
>
0
and
x
[
1
]
>
min_weight
][:
num_features
]
if
negative_only
:
fs
=
[
x
[
0
]
for
x
in
exp
if
x
[
1
]
<
0
and
abs
(
x
[
1
])
>
min_weight
][:
num_features
]
if
positive_only
or
negative_only
:
for
f
in
fs
:
temp
[
segments
==
f
]
=
image
[
segments
==
f
].
copy
()
mask
[
segments
==
f
]
=
1
return
temp
,
mask
else
:
for
f
,
w
in
exp
[:
num_features
]:
if
np
.
abs
(
w
)
<
min_weight
:
continue
c
=
0
if
w
<
0
else
1
mask
[
segments
==
f
]
=
-
1
if
w
<
0
else
1
temp
[
segments
==
f
]
=
image
[
segments
==
f
].
copy
()
temp
[
segments
==
f
,
c
]
=
np
.
max
(
image
)
return
temp
,
mask
def
get_rendered_image
(
self
,
label
,
min_weight
=
0.005
):
"""
Args:
label: label to explain
min_weight:
Returns:
image, is a 3d numpy array
"""
if
label
not
in
self
.
local_exp
:
raise
KeyError
(
'Label not in explanation'
)
from
matplotlib
import
cm
segments
=
self
.
segments
image
=
self
.
image
exp
=
self
.
local_exp
[
label
]
temp
=
np
.
zeros_like
(
image
)
weight_max
=
abs
(
exp
[
0
][
1
])
exp
=
[(
f
,
w
/
weight_max
)
for
f
,
w
in
exp
]
exp
=
sorted
(
exp
,
key
=
lambda
x
:
x
[
1
],
reverse
=
True
)
# negatives are at last.
cmaps
=
cm
.
get_cmap
(
'Spectral'
)
# sigmoid_space = 1 / (1 + np.exp(-np.linspace(-20, 20, len(exp))))
colors
=
cmaps
(
np
.
linspace
(
0
,
1
,
len
(
exp
)))
colors
=
colors
[:,
:
3
]
for
i
,
(
f
,
w
)
in
enumerate
(
exp
):
if
np
.
abs
(
w
)
<
min_weight
:
continue
temp
[
segments
==
f
]
=
image
[
segments
==
f
].
copy
()
temp
[
segments
==
f
]
=
colors
[
i
]
*
255
return
temp
class
LimeImageExplainer
(
object
):
"""Explains predictions on Image (i.e. matrix) data.
For numerical features, perturb them by sampling from a Normal(0,1) and
doing the inverse operation of mean-centering and scaling, according to the
means and stds in the training data. For categorical features, perturb by
sampling according to the training distribution, and making a binary
feature that is 1 when the value is the same as the instance being
explained."""
def
__init__
(
self
,
kernel_width
=
.
25
,
kernel
=
None
,
verbose
=
False
,
feature_selection
=
'auto'
,
random_state
=
None
):
"""Init function.
Args:
kernel_width: kernel width for the exponential kernel.
If None, defaults to sqrt(number of columns) * 0.75.
kernel: similarity kernel that takes euclidean distances and kernel
width as input and outputs weights in (0,1). If None, defaults to
an exponential kernel.
verbose: if true, print local prediction values from linear model
feature_selection: feature selection method. can be
'forward_selection', 'lasso_path', 'none' or 'auto'.
See function 'explain_instance_with_data' in lime_base.py for
details on what each of the options does.
random_state: an integer or numpy.RandomState that will be used to
generate random numbers. If None, the random state will be
initialized using the internal numpy seed.
"""
kernel_width
=
float
(
kernel_width
)
if
kernel
is
None
:
def
kernel
(
d
,
kernel_width
):
return
np
.
sqrt
(
np
.
exp
(
-
(
d
**
2
)
/
kernel_width
**
2
))
kernel_fn
=
partial
(
kernel
,
kernel_width
=
kernel_width
)
self
.
random_state
=
check_random_state
(
random_state
)
self
.
feature_selection
=
feature_selection
self
.
base
=
LimeBase
(
kernel_fn
,
verbose
,
random_state
=
self
.
random_state
)
def
explain_instance
(
self
,
image
,
classifier_fn
,
labels
=
(
1
,),
hide_color
=
None
,
num_features
=
100000
,
num_samples
=
1000
,
batch_size
=
10
,
distance_metric
=
'cosine'
,
model_regressor
=
None
):
"""Generates explanations for a prediction.
First, we generate neighborhood data by randomly perturbing features
from the instance (see __data_inverse). We then learn locally weighted
linear models on this neighborhood data to explain each of the classes
in an interpretable way (see lime_base.py).
Args:
image: 3 dimension RGB image. If this is only two dimensional,
we will assume it's a grayscale image and call gray2rgb.
classifier_fn: classifier prediction probability function, which
takes a numpy array and outputs prediction probabilities. For
ScikitClassifiers , this is classifier.predict_proba.
labels: iterable with labels to be explained.
hide_color: TODO
num_features: maximum number of features present in explanation
num_samples: size of the neighborhood to learn the linear model
batch_size: TODO
distance_metric: the distance metric to use for weights.
model_regressor: sklearn regressor to use in explanation. Defaults
to Ridge regression in LimeBase. Must have model_regressor.coef_
and 'sample_weight' as a parameter to model_regressor.fit()
Returns:
An ImageExplanation object (see lime_image.py) with the corresponding
explanations.
"""
if
len
(
image
.
shape
)
==
2
:
image
=
gray2rgb
(
image
)
try
:
segments
=
quickshift
(
image
,
sigma
=
1
)
except
ValueError
as
e
:
raise
e
self
.
segments
=
segments
fudged_image
=
image
.
copy
()
if
hide_color
is
None
:
# if no hide_color, use the mean
for
x
in
np
.
unique
(
segments
):
mx
=
np
.
mean
(
image
[
segments
==
x
],
axis
=
0
)
fudged_image
[
segments
==
x
]
=
mx
elif
hide_color
==
'avg_from_neighbor'
:
from
scipy.spatial.distance
import
cdist
n_features
=
np
.
unique
(
segments
).
shape
[
0
]
regions
=
regionprops
(
segments
+
1
)
centroids
=
np
.
zeros
((
n_features
,
2
))
for
i
,
x
in
enumerate
(
regions
):
centroids
[
i
]
=
np
.
array
(
x
.
centroid
)
d
=
cdist
(
centroids
,
centroids
,
'sqeuclidean'
)
for
x
in
np
.
unique
(
segments
):
# print(np.argmin(d[x]))
a
=
[
image
[
segments
==
i
]
for
i
in
np
.
argsort
(
d
[
x
])[
1
:
6
]]
mx
=
np
.
mean
(
np
.
concatenate
(
a
),
axis
=
0
)
fudged_image
[
segments
==
x
]
=
mx
else
:
fudged_image
[:]
=
0
top
=
labels
data
,
labels
=
self
.
data_labels
(
image
,
fudged_image
,
segments
,
classifier_fn
,
num_samples
,
batch_size
=
batch_size
)
distances
=
sklearn
.
metrics
.
pairwise_distances
(
data
,
data
[
0
].
reshape
(
1
,
-
1
),
metric
=
distance_metric
).
ravel
()
ret_exp
=
ImageExplanation
(
image
,
segments
)
for
label
in
top
:
(
ret_exp
.
intercept
[
label
],
ret_exp
.
local_exp
[
label
],
ret_exp
.
score
,
ret_exp
.
local_pred
)
=
self
.
base
.
explain_instance_with_data
(
data
,
labels
,
distances
,
label
,
num_features
,
model_regressor
=
model_regressor
,
feature_selection
=
self
.
feature_selection
)
return
ret_exp
def
data_labels
(
self
,
image
,
fudged_image
,
segments
,
classifier_fn
,
num_samples
,
batch_size
=
10
):
"""Generates images and predictions in the neighborhood of this image.
Args:
image: 3d numpy array, the image
fudged_image: 3d numpy array, image to replace original image when
superpixel is turned off
segments: segmentation of the image
classifier_fn: function that takes a list of images and returns a
matrix of prediction probabilities
num_samples: size of the neighborhood to learn the linear model
batch_size: classifier_fn will be called on batches of this size.
Returns:
A tuple (data, labels), where:
data: dense num_samples * num_superpixels
labels: prediction probabilities matrix
"""
n_features
=
np
.
unique
(
segments
).
shape
[
0
]
data
=
self
.
random_state
.
randint
(
0
,
2
,
num_samples
*
n_features
)
\
.
reshape
((
num_samples
,
n_features
))
labels
=
[]
data
[
0
,
:]
=
1
imgs
=
[]
for
row
in
data
:
temp
=
copy
.
deepcopy
(
image
)
zeros
=
np
.
where
(
row
==
0
)[
0
]
mask
=
np
.
zeros
(
segments
.
shape
).
astype
(
bool
)
for
z
in
zeros
:
mask
[
segments
==
z
]
=
True
temp
[
mask
]
=
fudged_image
[
mask
]
imgs
.
append
(
temp
)
if
len
(
imgs
)
==
batch_size
:
preds
=
classifier_fn
(
np
.
array
(
imgs
))
labels
.
extend
(
preds
)
imgs
=
[]
if
len
(
imgs
)
>
0
:
preds
=
classifier_fn
(
np
.
array
(
imgs
))
labels
.
extend
(
preds
)
return
data
,
np
.
array
(
labels
)
paddlex/cv/models/explanation/visualize.py
0 → 100644
浏览文件 @
011cff21
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import
os
import
cv2
import
copy
import
os.path
as
osp
import
numpy
as
np
from
.core.explanation
import
Explanation
def
visualize
(
img_file
,
model
,
explanation_type
=
'lime'
,
num_samples
=
3000
,
batch_size
=
50
,
save_dir
=
'./'
):
model
.
arrange_transforms
(
transforms
=
model
.
test_transforms
,
mode
=
'test'
)
tmp_transforms
=
copy
.
deepcopy
(
model
.
test_transforms
)
tmp_transforms
.
transforms
=
tmp_transforms
.
transforms
[:
-
2
]
img
=
tmp_transforms
(
img_file
)[
0
]
img
=
np
.
around
(
img
).
astype
(
'uint8'
)
img
=
np
.
expand_dims
(
img
,
axis
=
0
)
explaier
=
None
if
explanation_type
==
'lime'
:
explaier
=
get_lime_explaier
(
img
,
model
,
num_samples
=
num_samples
,
batch_size
=
batch_size
)
else
:
raise
Exception
(
'The {} explanantion method is not supported yet!'
.
format
(
explanation_type
))
img_name
=
osp
.
splitext
(
osp
.
split
(
img_file
)[
-
1
])[
0
]
explaier
.
explain
(
img
,
save_dir
=
save_dir
)
def
get_lime_explaier
(
img
,
model
,
num_samples
=
3000
,
batch_size
=
50
):
def
predict_func
(
image
):
image
=
image
.
astype
(
'float32'
)
model
.
test_transforms
.
transforms
=
model
.
test_transforms
.
transforms
[
-
2
:]
out
=
model
.
explanation_predict
(
image
)
return
out
[
0
]
explaier
=
Explanation
(
'lime'
,
predict_func
,
num_samples
=
num_samples
,
batch_size
=
batch_size
)
return
explaier
\ No newline at end of file
paddlex/cv/nets/resnet.py
浏览文件 @
011cff21
...
@@ -120,6 +120,7 @@ class ResNet(object):
...
@@ -120,6 +120,7 @@ class ResNet(object):
self
.
num_classes
=
num_classes
self
.
num_classes
=
num_classes
self
.
lr_mult_list
=
lr_mult_list
self
.
lr_mult_list
=
lr_mult_list
self
.
curr_stage
=
0
self
.
curr_stage
=
0
self
.
features
=
[]
def
_conv_offset
(
self
,
def
_conv_offset
(
self
,
input
,
input
,
...
@@ -474,7 +475,9 @@ class ResNet(object):
...
@@ -474,7 +475,9 @@ class ResNet(object):
size
=
self
.
num_classes
,
size
=
self
.
num_classes
,
param_attr
=
fluid
.
param_attr
.
ParamAttr
(
param_attr
=
fluid
.
param_attr
.
ParamAttr
(
initializer
=
fluid
.
initializer
.
Uniform
(
-
stdv
,
stdv
)))
initializer
=
fluid
.
initializer
.
Uniform
(
-
stdv
,
stdv
)))
return
out
self
.
features
.
append
(
out
)
# out.persistable=True
return
out
,
self
.
features
return
OrderedDict
([(
'res{}_sum'
.
format
(
self
.
feature_maps
[
idx
]),
feat
)
return
OrderedDict
([(
'res{}_sum'
.
format
(
self
.
feature_maps
[
idx
]),
feat
)
for
idx
,
feat
in
enumerate
(
res_endpoints
)])
for
idx
,
feat
in
enumerate
(
res_endpoints
)])
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录