提交 6c503961 编写于 作者: S sunyanfang01

modify explanation

上级 b15670c7
......@@ -28,6 +28,7 @@ from . import seg
from . import cls
from . import slim
from . import tools
from . import explanation
try:
import pycocotools
......
......@@ -60,10 +60,12 @@ class BaseClassifier(BaseAPI):
if mode != 'test':
label = fluid.data(dtype='int64', shape=[None, 1], name='label')
model = getattr(paddlex.cv.nets, str.lower(self.model_name))
net_out, feat = model(image, num_classes=self.num_classes)
softmax_out = fluid.layers.softmax(net_out, use_cudnn=False)
net_out = model(image, num_classes=self.num_classes)
softmax_out = fluid.layers.softmax(net_out['logits'], use_cudnn=False)
inputs = OrderedDict([('image', image)])
outputs = OrderedDict([('predict', softmax_out), ('net_out', feat[-1])])
outputs = net_out
outputs.update({'predict': softmax_out})
outputs.move_to_end('predict', last=False)
if mode != 'test':
cost = fluid.layers.cross_entropy(input=softmax_out, label=label)
avg_cost = fluid.layers.mean(cost)
......
import os
def imagenet_val_files_and_labels(dataset_directory):
classes = open(os.path.join(dataset_directory, 'imagenet_lsvrc_2015_synsets.txt')).readlines()
class_to_indx = {classes[i].split('\n')[0]: i for i in range(len(classes))}
images_path = os.path.join(dataset_directory, 'val')
filenames = []
labels = []
lines = open(os.path.join(dataset_directory, 'imagenet_2012_validation_synset_labels.txt'), 'r').readlines()
for i, line in enumerate(lines):
class_name = line.split('\n')[0]
a = 'ILSVRC2012_val_%08d.JPEG' % (i + 1)
filenames.append(f'{images_path}/{a}')
labels.append(class_to_indx[class_name])
# print(filenames[-1], labels[-1])
return filenames, labels
#copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
#Licensed under the Apache License, Version 2.0 (the "License");
#you may not use this file except in compliance with the License.
#You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
#Unless required by applicable law or agreed to in writing, software
#distributed under the License is distributed on an "AS IS" BASIS,
#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#See the License for the specific language governing permissions and
#limitations under the License.
import os
def _find_classes(dir):
# Faster and available in Python 3.5 and above
classes = [d.name for d in os.scandir(dir) if d.is_dir()]
classes.sort()
class_to_idx = {classes[i]: i for i in range(len(classes))}
return classes, class_to_idx
\ No newline at end of file
return classes, class_to_idx
#copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
#Licensed under the Apache License, Version 2.0 (the "License");
#you may not use this file except in compliance with the License.
#You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
#Unless required by applicable law or agreed to in writing, software
#distributed under the License is distributed on an "AS IS" BASIS,
#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#See the License for the specific language governing permissions and
#limitations under the License.
import os
import sys; sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
import sys
import cv2
import numpy as np
import six
import glob
from as_data_reader.data_path_utils import _find_classes
from .data_path_utils import _find_classes
from PIL import Image
......
#copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
#Licensed under the Apache License, Version 2.0 (the "License");
#you may not use this file except in compliance with the License.
#You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
#Unless required by applicable law or agreed to in writing, software
#distributed under the License is distributed on an "AS IS" BASIS,
#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#See the License for the specific language governing permissions and
#limitations under the License.
import os
import paddle.fluid as fluid
import numpy as np
......
#copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
#Licensed under the Apache License, Version 2.0 (the "License");
#you may not use this file except in compliance with the License.
#You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
#Unless required by applicable law or agreed to in writing, software
#distributed under the License is distributed on an "AS IS" BASIS,
#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#See the License for the specific language governing permissions and
#limitations under the License.
from .explanation_algorithms import CAM, LIME, NormLIME
......
#copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
#Licensed under the Apache License, Version 2.0 (the "License");
#you may not use this file except in compliance with the License.
#You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
#Unless required by applicable law or agreed to in writing, software
#distributed under the License is distributed on an "AS IS" BASIS,
#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#See the License for the specific language governing permissions and
#limitations under the License.
import os
import numpy as np
import time
......
"""
Contains abstract functionality for learning locally linear sparse model.
"""
#copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
#Licensed under the Apache License, Version 2.0 (the "License");
#you may not use this file except in compliance with the License.
#You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
#Unless required by applicable law or agreed to in writing, software
#distributed under the License is distributed on an "AS IS" BASIS,
#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#See the License for the specific language governing permissions and
#limitations under the License.
from __future__ import print_function
import numpy as np
import scipy as sp
......
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
#Licensed under the Apache License, Version 2.0 (the "License");
#you may not use this file except in compliance with the License.
#You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
#Unless required by applicable law or agreed to in writing, software
#distributed under the License is distributed on an "AS IS" BASIS,
#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#See the License for the specific language governing permissions and
#limitations under the License.
import os
import cv2
......
......@@ -18,6 +18,7 @@ from __future__ import print_function
import six
import math
from collections import OrderedDict
from paddle import fluid
from paddle.fluid.param_attr import ParamAttr
......@@ -182,6 +183,6 @@ class DarkNet(object):
initializer=fluid.initializer.Uniform(-stdv, stdv),
name='fc_weights'),
bias_attr=ParamAttr(name='fc_offset'))
return out
return OrderedDict([('logits', out)])
return blocks
......@@ -15,6 +15,7 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from collections import OrderedDict
import paddle
import paddle.fluid as fluid
import math
......@@ -96,7 +97,7 @@ class DenseNet(object):
initializer=fluid.initializer.Uniform(-stdv, stdv),
name="fc_weights"),
bias_attr=ParamAttr(name='fc_offset'))
return out
return OrderedDict([('logits', out)])
def make_transition(self, input, num_output_features, name=None):
bn_ac = fluid.layers.batch_norm(
......
......@@ -16,6 +16,7 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from collections import OrderedDict
from paddle import fluid
from paddle.fluid.param_attr import ParamAttr
from paddle.fluid.regularizer import L2Decay
......@@ -196,7 +197,7 @@ class MobileNetV1(object):
param_attr=ParamAttr(
initializer=fluid.initializer.MSRA(), name="fc7_weights"),
bias_attr=ParamAttr(name="fc7_offset"))
return output
return OrderedDict([('logits', out)])
if not self.with_extra_blocks:
return blocks
......
......@@ -14,6 +14,7 @@
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from collections import OrderedDict
import paddle.fluid as fluid
from paddle.fluid.param_attr import ParamAttr
......@@ -109,6 +110,7 @@ class MobileNetV2:
size=self.num_classes,
param_attr=ParamAttr(name='fc10_weights'),
bias_attr=ParamAttr(name='fc10_offset'))
return OrderedDict([('logits', output)])
return output
def modify_bottle_params(self, output_stride=None):
......
......@@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from collections import OrderedDict
import paddle.fluid as fluid
from paddle.fluid.param_attr import ParamAttr
from paddle.fluid.regularizer import L2Decay
......@@ -327,7 +328,7 @@ class MobileNetV3():
size=self.num_classes,
param_attr=ParamAttr(name='fc_weights'),
bias_attr=ParamAttr(name='fc_offset'))
return out
return OrderedDict([('logits', out)])
if not self.with_extra_blocks:
return blocks
......
......@@ -120,7 +120,6 @@ class ResNet(object):
self.num_classes = num_classes
self.lr_mult_list = lr_mult_list
self.curr_stage = 0
self.features = []
def _conv_offset(self,
input,
......@@ -475,8 +474,7 @@ class ResNet(object):
size=self.num_classes,
param_attr=fluid.param_attr.ParamAttr(
initializer=fluid.initializer.Uniform(-stdv, stdv)))
self.features.append(out)
return out, self.features
return OrderedDict([('logits', out)])
return OrderedDict([('res{}_sum'.format(self.feature_maps[idx]), feat)
for idx, feat in enumerate(res_endpoints)])
......
......@@ -15,6 +15,7 @@
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from collections import OrderedDict
import paddle.fluid as fluid
from paddle.fluid.initializer import MSRA
from paddle.fluid.param_attr import ParamAttr
......@@ -101,6 +102,7 @@ class ShuffleNetV2():
size=self.num_classes,
param_attr=ParamAttr(initializer=MSRA(), name='fc6_weights'),
bias_attr=ParamAttr(name='fc6_offset'))
return OrderedDict([('logits', output)])
return output
def conv_bn_layer(self,
......
......@@ -19,6 +19,7 @@ from __future__ import print_function
import contextlib
import paddle
import math
from collections import OrderedDict
import paddle.fluid as fluid
from .segmentation.model_utils.libs import scope, name_scope
from .segmentation.model_utils.libs import bn, bn_relu, relu
......@@ -104,7 +105,7 @@ class Xception():
initializer=fluid.initializer.Uniform(-stdv, stdv)),
bias_attr=fluid.param_attr.ParamAttr(name='bias'))
return out
return OrderedDict([('logits', out)])
else:
return data
......
# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from .cv.models.explanation import visualize
visualize = visualize.visualize
\ No newline at end of file
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册