Commit 354ca296 authored by wanghaoshuang

Add sensitive API.

Parent a0d17e44
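The new paddleslim.analysis.sensitivity API prunes each listed parameter at increasing ratios, re-evaluates the model, and records the relative metric loss per ratio. A minimal usage sketch, assuming eval_program, scope, place and eval_func are prepared as in the unit test added by this commit:

# Minimal sketch (hypothetical setup): eval_program, scope, place and eval_func
# are assumed to be built as in the unit test below; eval_func(program, scope)
# must return a scalar metric such as top-1 accuracy.
from paddleslim.analysis import sensitivity

sensitivities = sensitivity(
    eval_program,               # program to evaluate after each pruning step
    scope,                      # scope holding the trained parameters
    place,                      # place used to restore the backed-up weights
    ["conv4_weights"],          # parameters whose sensitivity is measured
    eval_func,                  # callable: eval_func(program, scope) -> float
    sensitivities_file="./sensitivities_file")

# The result maps each parameter name to the ratios tried and the metric loss:
# {'conv4_weights': {'pruned_percent': [...], 'loss': [...], 'size': ...}}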
@@ -15,6 +15,9 @@ import flops as flops_module
from flops import *
import model_size as model_size_module
from model_size import *
import sensitive
from sensitive import *
__all__ = []
__all__ += flops_module.__all__
__all__ += model_size_module.__all__
__all__ += sensitive.__all__
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import sys
import pickle
import logging
import numpy as np
from ..core import GraphWrapper
from ..prune import Pruner  # pruner used to apply each trial ratio

_logger = logging.getLogger(__name__)

__all__ = ["sensitivity"]
def sensitivity(program,
                scope,
                place,
                param_names,
                eval_func,
                sensitivities_file=None,
                step_size=0.2):
    """
    Measure how sensitive each parameter in `param_names` is to pruning:
    prune it at increasing ratios, re-evaluate the model with
    `eval_func(program, scope)`, and record the relative metric loss.
    Results are cached in `sensitivities_file` so an interrupted run can
    be resumed.
    """
    graph = GraphWrapper(program)
    sensitivities = _load_sensitivities(sensitivities_file)

    for name in param_names:
        if name not in sensitivities:
            size = graph.var(name).shape()[0]
            sensitivities[name] = {
                'pruned_percent': [],
                'loss': [],
                'size': size
            }

    baseline = None
    for name in sensitivities:
        ratio = step_size
        while ratio < 1:
            ratio = round(ratio, 2)
            if ratio in sensitivities[name]['pruned_percent']:
                _logger.debug('{}, {} has been computed.'.format(name, ratio))
                ratio += step_size
                continue
            if baseline is None:
                baseline = eval_func(graph.program, scope)

            # Prune the current parameter at the current ratio; the original
            # values are backed up into `param_backup` so they can be restored
            # after evaluation.
            param_backup = {}
            pruner = Pruner()
            pruned_program = pruner.prune(
                program=graph.program,
                scope=scope,
                params=[name],
                ratios=[ratio],
                place=place,
                lazy=True,
                only_graph=False,
                param_backup=param_backup)
            pruned_metric = eval_func(pruned_program, scope)
            loss = (baseline - pruned_metric) / baseline
            _logger.info("pruned param: {}; ratio: {}; loss={}".format(
                name, ratio, loss))
            sensitivities[name]['pruned_percent'].append(ratio)
            sensitivities[name]['loss'].append(loss)
            _save_sensitivities(sensitivities, sensitivities_file)

            # Restore the pruned parameters from the backup before trying
            # the next ratio.
            for param_name in param_backup.keys():
                param_t = scope.find_var(param_name).get_tensor()
                param_t.set(param_backup[param_name], place)
            ratio += step_size
    return sensitivities
def _load_sensitivities(sensitivities_file):
    """
    Load cached sensitivities from file, if it exists.
    """
    sensitivities = {}
    if sensitivities_file and os.path.exists(sensitivities_file):
        with open(sensitivities_file, 'rb') as f:
            if sys.version_info < (3, 0):
                sensitivities = pickle.load(f)
            else:
                sensitivities = pickle.load(f, encoding='bytes')

    for param in sensitivities:
        sensitivities[param]['pruned_percent'] = [
            round(p, 2) for p in sensitivities[param]['pruned_percent']
        ]
    return sensitivities


def _save_sensitivities(sensitivities, sensitivities_file):
    """
    Persist sensitivities with pickle; skipped when no file is given.
    """
    if sensitivities_file is None:
        return
    with open(sensitivities_file, 'wb') as f:
        pickle.dump(sensitivities, f)
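Since the cache written by _save_sensitivities is a plain pickle of the nested dict, a saved file can be inspected directly. A small sketch, assuming a file produced at the path used by the test in this commit:

# Inspect a previously saved sensitivities cache (a plain pickle of the dict).
import pickle

with open("./sensitivities_file", 'rb') as f:
    sensitivities = pickle.load(f)
for name, info in sensitivities.items():
    print(name, list(zip(info['pruned_percent'], info['loss'])))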
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Controllers and controller server"""
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import sys
sys.path.append("../")
import unittest
import numpy
import paddle
import paddle.fluid as fluid
from paddleslim.analysis import sensitivity
from layers import conv_bn_layer
class TestSensitivity(unittest.TestCase):
    def test_sensitivity(self):
        main_program = fluid.Program()
        startup_program = fluid.Program()
        with fluid.program_guard(main_program, startup_program):
            input = fluid.data(name="image", shape=[None, 1, 28, 28])
            label = fluid.data(name="label", shape=[None, 1], dtype="int64")
            conv1 = conv_bn_layer(input, 8, 3, "conv1")
            conv2 = conv_bn_layer(conv1, 8, 3, "conv2")
            sum1 = conv1 + conv2
            conv3 = conv_bn_layer(sum1, 8, 3, "conv3")
            conv4 = conv_bn_layer(conv3, 8, 3, "conv4")
            sum2 = conv4 + sum1
            conv5 = conv_bn_layer(sum2, 8, 3, "conv5")
            conv6 = conv_bn_layer(conv5, 8, 3, "conv6")
            out = fluid.layers.fc(conv6, size=10, act='softmax')
            acc_top1 = fluid.layers.accuracy(input=out, label=label, k=1)
        eval_program = main_program.clone(for_test=True)

        place = fluid.CUDAPlace(0)
        exe = fluid.Executor(place)
        exe.run(startup_program)

        val_reader = paddle.batch(paddle.dataset.mnist.test(), batch_size=128)

        def eval_func(program, scope):
            feeder = fluid.DataFeeder(
                feed_list=['image', 'label'], place=place, program=program)
            acc_set = []
            for data in val_reader():
                acc_np = exe.run(program=program,
                                 scope=scope,
                                 feed=feeder.feed(data),
                                 fetch_list=[acc_top1])
                acc_set.append(float(acc_np[0]))
            acc_val_mean = numpy.array(acc_set).mean()
            print("acc_val_mean: {}".format(acc_val_mean))
            return acc_val_mean

        sensitivity(eval_program,
                    fluid.global_scope(), place, ["conv4_weights"], eval_func,
                    "./sensitivities_file")


if __name__ == '__main__':
    unittest.main()