Commit de55821b authored by ceci3

Merge branch 'develop' of ssh://gitlab.baidu.com:8022/PaddlePaddle/PaddleSlim into fix_nas

......@@ -195,11 +195,12 @@ def compress(args):
server_addr=("", 0),
init_temperature=100,
reduce_rate=0.85,
-            max_try_number=300,
+            max_try_times=300,
max_client_num=10,
search_steps=100,
max_ratios=0.9,
min_ratios=0.,
is_server=True,
key="auto_pruner")
while True:
......
import os
import sys
import logging
import paddle
import argparse
import functools
import math
import time
import numpy as np
import paddle.fluid as fluid
from paddleslim.prune import SensitivePruner
from paddleslim.common import get_logger
from paddleslim.analysis import flops
sys.path.append(sys.path[0] + "/../")
import models
from utility import add_arguments, print_arguments
_logger = get_logger(__name__, level=logging.INFO)
parser = argparse.ArgumentParser(description=__doc__)
add_arg = functools.partial(add_arguments, argparser=parser)
# yapf: disable
add_arg('batch_size', int, 64 * 4, "Minibatch size.")
add_arg('use_gpu', bool, True, "Whether to use GPU or not.")
add_arg('model', str, "MobileNet", "The target model.")
add_arg('pretrained_model', str, "../pretrained_model/MobileNetV1_pretained", "The path of the pretrained model.")
add_arg('lr', float, 0.1, "The learning rate used to fine-tune pruned model.")
add_arg('lr_strategy', str, "piecewise_decay", "The learning rate decay strategy.")
add_arg('l2_decay', float, 3e-5, "The l2_decay parameter.")
add_arg('momentum_rate', float, 0.9, "The value of momentum_rate.")
add_arg('num_epochs', int, 120, "The number of total epochs.")
add_arg('total_images', int, 1281167, "The number of total training images.")
parser.add_argument('--step_epochs', nargs='+', type=int, default=[30, 60, 90], help="piecewise decay step")
add_arg('config_file', str, None, "The config file for compression with yaml format.")
add_arg('data', str, "mnist", "Which data to use. 'mnist' or 'imagenet'")
add_arg('log_period', int, 10, "Log period in batches.")
add_arg('test_period', int, 10, "Test period in epochs.")
add_arg('checkpoints', str, "./checkpoints", "Checkpoints path.")
# yapf: enable
model_list = [m for m in dir(models) if "__" not in m]
def piecewise_decay(args):
step = int(math.ceil(float(args.total_images) / args.batch_size))
bd = [step * e for e in args.step_epochs]
lr = [args.lr * (0.1**i) for i in range(len(bd) + 1)]
learning_rate = fluid.layers.piecewise_decay(boundaries=bd, values=lr)
optimizer = fluid.optimizer.Momentum(
learning_rate=learning_rate,
momentum=args.momentum_rate,
regularization=fluid.regularizer.L2Decay(args.l2_decay))
return optimizer
def cosine_decay(args):
step = int(math.ceil(float(args.total_images) / args.batch_size))
learning_rate = fluid.layers.cosine_decay(
learning_rate=args.lr, step_each_epoch=step, epochs=args.num_epochs)
optimizer = fluid.optimizer.Momentum(
learning_rate=learning_rate,
momentum=args.momentum_rate,
regularization=fluid.regularizer.L2Decay(args.l2_decay))
return optimizer
def create_optimizer(args):
if args.lr_strategy == "piecewise_decay":
return piecewise_decay(args)
elif args.lr_strategy == "cosine_decay":
return cosine_decay(args)
def compress(args):
train_reader = None
test_reader = None
if args.data == "mnist":
import paddle.dataset.mnist as reader
train_reader = reader.train()
val_reader = reader.test()
class_dim = 10
image_shape = "1,28,28"
elif args.data == "imagenet":
import imagenet_reader as reader
train_reader = reader.train()
val_reader = reader.val()
class_dim = 1000
image_shape = "3,224,224"
else:
raise ValueError("{} is not supported.".format(args.data))
image_shape = [int(m) for m in image_shape.split(",")]
assert args.model in model_list, "{} is not in lists: {}".format(
args.model, model_list)
image = fluid.layers.data(name='image', shape=image_shape, dtype='float32')
label = fluid.layers.data(name='label', shape=[1], dtype='int64')
# model definition
model = models.__dict__[args.model]()
out = model.net(input=image, class_dim=class_dim)
cost = fluid.layers.cross_entropy(input=out, label=label)
avg_cost = fluid.layers.mean(x=cost)
acc_top1 = fluid.layers.accuracy(input=out, label=label, k=1)
acc_top5 = fluid.layers.accuracy(input=out, label=label, k=5)
val_program = fluid.default_main_program().clone(for_test=True)
opt = create_optimizer(args)
opt.minimize(avg_cost)
place = fluid.CUDAPlace(0) if args.use_gpu else fluid.CPUPlace()
exe = fluid.Executor(place)
exe.run(fluid.default_startup_program())
if args.pretrained_model:
def if_exist(var):
return os.path.exists(
os.path.join(args.pretrained_model, var.name))
fluid.io.load_vars(exe, args.pretrained_model, predicate=if_exist)
val_reader = paddle.batch(val_reader, batch_size=args.batch_size)
train_reader = paddle.batch(
train_reader, batch_size=args.batch_size, drop_last=True)
train_feeder = fluid.DataFeeder([image, label], place)
val_feeder = fluid.DataFeeder(
    [image, label], place, program=val_program)
def test(epoch, program):
batch_id = 0
acc_top1_ns = []
acc_top5_ns = []
for data in val_reader():
start_time = time.time()
acc_top1_n, acc_top5_n = exe.run(
program,
feed=val_feeder.feed(data),
fetch_list=[acc_top1.name, acc_top5.name])
end_time = time.time()
if batch_id % args.log_period == 0:
_logger.info(
"Eval epoch[{}] batch[{}] - acc_top1: {:.3f}; acc_top5: {:.3f}; time: {:.3f}".
format(epoch, batch_id,
np.mean(acc_top1_n),
np.mean(acc_top5_n), end_time - start_time))
acc_top1_ns.append(np.mean(acc_top1_n))
acc_top5_ns.append(np.mean(acc_top5_n))
batch_id += 1
_logger.info(
"Final eval epoch[{}] - acc_top1: {:.3f}; acc_top5: {:.3f}".format(
epoch,
np.mean(np.array(acc_top1_ns)), np.mean(
np.array(acc_top5_ns))))
return np.mean(np.array(acc_top1_ns))
def train(epoch, program):
build_strategy = fluid.BuildStrategy()
exec_strategy = fluid.ExecutionStrategy()
train_program = fluid.compiler.CompiledProgram(
program).with_data_parallel(
loss_name=avg_cost.name,
build_strategy=build_strategy,
exec_strategy=exec_strategy)
batch_id = 0
for data in train_reader():
start_time = time.time()
loss_n, acc_top1_n, acc_top5_n = exe.run(
train_program,
feed=train_feeder.feed(data),
fetch_list=[avg_cost.name, acc_top1.name, acc_top5.name])
end_time = time.time()
loss_n = np.mean(loss_n)
acc_top1_n = np.mean(acc_top1_n)
acc_top5_n = np.mean(acc_top5_n)
if batch_id % args.log_period == 0:
_logger.info(
"epoch[{}]-batch[{}] - loss: {:.3f}; acc_top1: {:.3f}; acc_top5: {:.3f}; time: {:.3f}".
format(epoch, batch_id, loss_n, acc_top1_n, acc_top5_n,
end_time - start_time))
batch_id += 1
params = []
for param in fluid.default_main_program().global_block().all_parameters():
if "_sep_weights" in param.name:
params.append(param.name)
def eval_func(program):
return test(0, program)
if args.data == "mnist":
train(0, fluid.default_main_program())
pruner = SensitivePruner(place, eval_func, checkpoints=args.checkpoints)
pruned_program, pruned_val_program, iter = pruner.restore()
if pruned_program is None:
pruned_program = fluid.default_main_program()
if pruned_val_program is None:
pruned_val_program = val_program
start = iter
end = 6
for iter in range(start, end):
pruned_program, pruned_val_program = pruner.prune(
pruned_program, pruned_val_program, params, 0.1)
train(iter, pruned_program)
test(iter, pruned_val_program)
pruner.save_checkpoint(pruned_program, pruned_val_program)
print("before flops: {}".format(flops(fluid.default_main_program())))
print("after flops: {}".format(flops(pruned_val_program)))
def main():
args = parser.parse_args()
print_arguments(args)
compress(args)
if __name__ == '__main__':
main()
......@@ -17,6 +17,7 @@ import os
import logging
import pickle
import numpy as np
+ import paddle.fluid as fluid
from ..core import GraphWrapper
from ..common import get_logger
from ..prune import Pruner
......@@ -27,13 +28,12 @@ __all__ = ["sensitivity"]
def sensitivity(program,
-                scope,
place,
param_names,
eval_func,
sensitivities_file=None,
step_size=0.2):
+    scope = fluid.global_scope()
graph = GraphWrapper(program)
sensitivities = _load_sensitivities(sensitivities_file)
......@@ -55,7 +55,7 @@ def sensitivity(program,
ratio += step_size
continue
if baseline is None:
-            baseline = eval_func(graph.program, scope)
+            baseline = eval_func(graph.program)
param_backup = {}
pruner = Pruner()
......@@ -68,7 +68,7 @@ def sensitivity(program,
lazy=True,
only_graph=False,
param_backup=param_backup)
-        pruned_metric = eval_func(pruned_program, scope)
+        pruned_metric = eval_func(pruned_program)
loss = (baseline - pruned_metric) / baseline
_logger.info("pruned param: {}; {}; loss={}".format(name, ratio,
loss))
......@@ -81,7 +81,7 @@ def sensitivity(program,
param_t = scope.find_var(param_name).get_tensor()
param_t.set(param_backup[param_name], place)
ratio += step_size
return sensitivities
def _load_sensitivities(sensitivities_file):
......
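For intuition, the sensitivity recorded by this function is the relative drop of the evaluation metric after pruning a parameter at a given ratio; a tiny worked example with hypothetical scores:

    baseline, pruned_metric = 0.70, 0.63  # hypothetical eval scores
    loss = (baseline - pruned_metric) / baseline  # 0.1, i.e. a 10% relative drop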
......@@ -23,6 +23,8 @@ import controller_client
from controller_client import *
import lock_utils
from lock_utils import *
+ import cached_reader as cached_reader_module
+ from cached_reader import *
__all__ = []
__all__ += controller.__all__
......@@ -30,3 +32,4 @@ __all__ += sa_controller.__all__
__all__ += controller_server.__all__
__all__ += controller_client.__all__
__all__ += lock_utils.__all__
+ __all__ += cached_reader_module.__all__
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import logging
import numpy as np
from .log_helper import get_logger
__all__ = ['cached_reader']
_logger = get_logger(__name__, level=logging.INFO)
def cached_reader(reader, sampled_rate, cache_path, cached_id):
"""
Sample partial data from `reader` and cache it in the local file system.
Args:
    reader: An iterative data source.
    sampled_rate(float): The rate at which batches are sampled from `reader` for evaluation. None means using all the data. Default: None.
    cache_path(str): The path used to cache the sampled data.
    cached_id(int): The id of the sampled dataset. Evaluations with the same cached_id use the same sampled dataset. Default: 0.
"""
np.random.seed(cached_id)
cache_path = os.path.join(cache_path, str(cached_id))
_logger.debug('read data from: {}'.format(cache_path))
def s_reader():
if os.path.isdir(cache_path):
for file_name in open(os.path.join(cache_path, "list")):
yield np.load(
os.path.join(cache_path, file_name.strip()),
allow_pickle=True)
else:
os.makedirs(cache_path)
list_file = open(os.path.join(cache_path, "list"), 'w')
batch = 0
dtype = None
for data in reader():
if batch == 0 or (np.random.uniform() < sampled_rate):
np.save(
os.path.join(cache_path, 'batch' + str(batch)), data)
list_file.write('batch' + str(batch) + '.npy\n')
batch += 1
yield data
return s_reader
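For reference, a minimal usage sketch of cached_reader; the toy reader and paths below are hypothetical, and the import path follows the common package exports shown above:

    import numpy as np
    from paddleslim.common import cached_reader

    # A toy iterative data source standing in for a real eval reader (hypothetical).
    def toy_reader():
        for i in range(100):
            yield np.array([i], dtype='float32')

    # Sample roughly 20% of the batches and cache them under ./cache/0.
    # A later call with the same cached_id replays the cached .npy files
    # instead of re-sampling.
    eval_reader = cached_reader(toy_reader, sampled_rate=0.2,
                                cache_path='./cache', cached_id=0)
    for data in eval_reader():
        pass  # evaluate on the sampled batch here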
......@@ -38,7 +38,7 @@ class ControllerClient(object):
self.socket_client = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
self._key = key
-    def update(self, tokens, reward):
+    def update(self, tokens, reward, iter):
"""
Update the controller according to latest tokens and reward.
Args:
......@@ -48,8 +48,8 @@ class ControllerClient(object):
socket_client = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
socket_client.connect((self.server_ip, self.server_port))
tokens = ",".join([str(token) for token in tokens])
-        socket_client.send("{}\t{}\t{}".format(self._key, tokens, reward)
-                           .encode())
+        socket_client.send("{}\t{}\t{}\t{}".format(self._key, tokens, reward,
+                                                   iter).encode())
response = socket_client.recv(1024).decode()
if response.strip('\n') == "ok":  # the server replies with the single token "ok"
return True
......
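To make the protocol above concrete: each update is a single tab-separated request, now carrying four fields instead of three. A sketch with hypothetical values:

    # key="auto_pruner", tokens=[3, 1, 4], reward=0.91, iter=7 (hypothetical):
    message = "{}\t{}\t{}\t{}".format("auto_pruner", "3,1,4", 0.91, 7)
    # message == "auto_pruner\t3,1,4\t0.91\t7"; the server answers "ok" on success.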
......@@ -51,23 +51,8 @@ class ControllerServer(object):
self._port = address[1]
self._ip = address[0]
self._key = key
-        self._socket_file = "./controller_server.socket"
def start(self):
-        open(self._socket_file, 'a').close()
-        socket_file = open(self._socket_file, 'r+')
-        lock(socket_file)
-        tid = socket_file.readline()
-        if tid == '':
-            _logger.info("start controller server...")
-            tid = self._start()
-            socket_file.write("tid: {}\nip: {}\nport: {}\n".format(
-                tid, self._ip, self._port))
-        _logger.info("started controller server...")
-        unlock(socket_file)
-        socket_file.close()
-    def _start(self):
self._socket_server = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
self._socket_server.bind(self._address)
self._socket_server.listen(self._max_client_num)
......@@ -82,7 +67,6 @@ class ControllerServer(object):
def close(self):
"""Close the server."""
self._closed = True
-        os.remove(self._socket_file)
_logger.info("server closed!")
def port(self):
......@@ -109,14 +93,15 @@ class ControllerServer(object):
_logger.debug("recv message from {}: [{}]".format(addr,
message))
messages = message.strip('\n').split("\t")
-                    if (len(messages) < 3) or (messages[0] != self._key):
+                    if (len(messages) < 4) or (messages[0] != self._key):
_logger.debug("recv noise from {}: [{}]".format(
addr, message))
continue
tokens = messages[1]
reward = messages[2]
+                    iter = messages[3]
tokens = [int(token) for token in tokens.split(",")]
-                    self._controller.update(tokens, float(reward))
+                    self._controller.update(tokens, float(reward), int(iter))
response = "ok"
conn.send(response.encode())
_logger.debug("send message to {}: [{}]".format(addr,
......
......@@ -32,7 +32,7 @@ class SAController(EvolutionaryController):
range_table=None,
reduce_rate=0.85,
init_temperature=1024,
-                 max_iter_number=300,
+                 max_try_times=None,
init_tokens=None,
constrain_func=None):
"""Initialize.
......@@ -40,7 +40,7 @@ class SAController(EvolutionaryController):
range_table(list<int>): Range table.
reduce_rate(float): The decay rate of temperature.
init_temperature(float): Init temperature.
-            max_iter_number(int): max iteration number.
+            max_try_times(int): The max number of tries before getting legal tokens.
init_tokens(list<int>): The initial tokens.
constrain_func(function): The callback function used to check whether the tokens meet constraint. None means there is no constraint. Default: None.
"""
......@@ -50,7 +50,7 @@ class SAController(EvolutionaryController):
len(self._range_table) == 2)
self._reduce_rate = reduce_rate
self._init_temperature = init_temperature
-        self._max_iter_number = max_iter_number
+        self._max_try_times = max_try_times
self._reward = -1
self._tokens = init_tokens
self._constrain_func = constrain_func
......@@ -65,15 +65,17 @@ class SAController(EvolutionaryController):
d[key] = self.__dict__[key]
return d
-    def update(self, tokens, reward):
+    def update(self, tokens, reward, iter):
"""
Update the controller according to latest tokens and reward.
Args:
tokens(list<int>): The tokens generated in last step.
reward(float): The reward of tokens.
"""
-        self._iter += 1
-        temperature = self._init_temperature * self._reduce_rate**self._iter
+        iter = int(iter)
+        if iter > self._iter:
+            self._iter = iter
+        temperature = self._init_temperature * self._reduce_rate**self._iter
if (reward > self._reward) or (np.random.random() <= math.exp(
(reward - self._reward) / temperature)):
self._reward = reward
......@@ -99,9 +101,9 @@ class SAController(EvolutionaryController):
self._range_table[1][index] + 1)
_logger.debug("change index[{}] from {} to {}".format(index, tokens[
index], new_tokens[index]))
-        if self._constrain_func is None:
+        if self._constrain_func is None or self._max_try_times is None:
return new_tokens
-        for _ in range(self._max_iter_number):
+        for _ in range(self._max_try_times):
if not self._constrain_func(new_tokens):
index = int(len(self._range_table[0]) * np.random.random())
new_tokens = tokens[:]
......
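The acceptance test above is the standard simulated-annealing criterion: a worse reward is still accepted with probability exp((reward - best_reward) / temperature), and the temperature decays geometrically with the iteration count. A self-contained sketch of that rule, with hypothetical function and argument names:

    import math
    import numpy as np

    def accept(new_reward, best_reward, init_temperature, reduce_rate, it):
        # Mirrors: temperature = init_temperature * reduce_rate ** iter
        temperature = init_temperature * reduce_rate**it
        if new_reward > best_reward:
            return True
        # Worse rewards are accepted with a probability that shrinks as
        # the temperature anneals toward zero.
        return np.random.random() <= math.exp(
            (new_reward - best_reward) / temperature)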
......@@ -15,6 +15,7 @@
import socket
import logging
import numpy as np
+ import hashlib
import paddle.fluid as fluid
from ..core import VarWrapper, OpWrapper, GraphWrapper
from ..common import SAController
......@@ -58,38 +59,40 @@ class SANAS(object):
self._reduce_rate = reduce_rate
self._init_temperature = init_temperature
self._is_server = is_server
+        self._configs = configs
-        factory = SearchSpaceFactory()
-        self._search_space = factory.get_search_space(configs)
-        init_tokens = self._search_space.init_tokens()
-        range_table = self._search_space.range_table()
-        range_table = (len(range_table) * [0], range_table)
-        _logger.info("range table: {}".format(range_table))
-        controller = SAController(range_table, self._reduce_rate,
-                                  self._init_temperature, self._max_try_number,
-                                  init_tokens, self._constrain_func)
+        self._key = hashlib.md5(str(self._configs).encode()).hexdigest()
        server_ip, server_port = server_addr
        if server_ip == None or server_ip == "":
            server_ip = self._get_host_ip()
-        max_client_num = 100
-        self._controller_server = ControllerServer(
-            controller=controller,
-            address=(server_ip, server_port),
-            max_client_num=max_client_num,
-            search_steps=search_steps,
-            key=key)
+        # create controller server
+        if self._is_server:
+            factory = SearchSpaceFactory()
+            self._search_space = factory.get_search_space(configs)
+            init_tokens = self._search_space.init_tokens()
+            range_table = self._search_space.range_table()
+            range_table = (len(range_table) * [0], range_table)
+            _logger.info("range table: {}".format(range_table))
+            controller = SAController(
+                range_table,
+                self._reduce_rate,
+                self._init_temperature,
+                max_try_times=None,
+                init_tokens=init_tokens,
+                constrain_func=None)
+            max_client_num = 100
+            self._controller_server = ControllerServer(
+                controller=controller,
+                address=(server_ip, server_port),
+                max_client_num=max_client_num,
+                search_steps=search_steps,
+                key=self._key)
+            self._controller_server.start()
        self._controller_client = ControllerClient(
-            self._controller_server.ip(),
-            self._controller_server.port(),
-            key=key)
+            server_ip, server_port, key=self._key)
self._iter = 0
......@@ -115,4 +118,5 @@ class SANAS(object):
bool: True means updating successfully while false means failure.
"""
self._iter += 1
-        return self._controller_client.update(self._current_tokens, score)
+        return self._controller_client.update(self._current_tokens, score,
+                                              self._iter)
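The controller key used in the constructor above is derived from the search configs, so a server and its clients constructed with the same configs agree on the key without extra coordination; a minimal sketch of that derivation, with a hypothetical configs value:

    import hashlib

    configs = [('MobileNetV2Space', {'input_size': 224})]  # hypothetical configs
    key = hashlib.md5(str(configs).encode()).hexdigest()  # encode() is required on Python 3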
......@@ -19,9 +19,12 @@ import controller_server
from controller_server import *
import controller_client
from controller_client import *
+ import sensitive_pruner
+ from sensitive_pruner import *
__all__ = []
__all__ += pruner.__all__
__all__ += auto_pruner.__all__
__all__ += controller_server.__all__
__all__ += controller_client.__all__
+ __all__ += sensitive_pruner.__all__
......@@ -42,7 +42,7 @@ class AutoPruner(object):
server_addr=("", 0),
init_temperature=100,
reduce_rate=0.85,
-                 max_try_number=300,
+                 max_try_times=300,
max_client_num=10,
search_steps=300,
max_ratios=[0.9],
......@@ -66,7 +66,7 @@ class AutoPruner(object):
server_addr(tuple): A tuple of server ip and server port for controller server.
init_temperature(float): The init temperature used in simulated annealing search strategy.
reduce_rate(float): The decay rate used in simulated annealing search strategy.
-            max_try_number(int): The max number of trying to generate legal tokens.
+            max_try_times(int): The max number of tries to generate legal tokens.
max_client_num(int): The max number of connections of controller server.
search_steps(int): The steps of searching.
max_ratios(float|list<float>): Max ratios used to pruned parameters in `params`. List means max ratios for each parameter in `params`.
......@@ -88,7 +88,7 @@ class AutoPruner(object):
self._pruned_latency = pruned_latency
self._reduce_rate = reduce_rate
self._init_temperature = init_temperature
-        self._max_try_number = max_try_number
+        self._max_try_times = max_try_times
self._is_server = is_server
self._range_table = self._get_range_table(min_ratios, max_ratios)
......@@ -110,7 +110,7 @@ class AutoPruner(object):
init_tokens = self._ratios2tokens(self._init_ratios)
_logger.info("range table: {}".format(self._range_table))
controller = SAController(self._range_table, self._reduce_rate,
-                                  self._init_temperature, self._max_try_number,
+                                  self._init_temperature, self._max_try_times,
init_tokens, self._constrain_func)
server_ip, server_port = server_addr
......@@ -212,7 +212,7 @@ class AutoPruner(object):
self._restore(self._scope)
self._param_backup = {}
tokens = self._ratios2tokens(self._current_ratios)
-        self._controller_client.update(tokens, score)
+        self._controller_client.update(tokens, score, self._iter)
self._iter += 1
def _restore(self, scope):
......
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import logging
import copy
from scipy.optimize import leastsq
import numpy as np
import paddle.fluid as fluid
from ..common import get_logger
from ..analysis import sensitivity
from ..analysis import flops
from .pruner import Pruner
__all__ = ["SensitivePruner"]
_logger = get_logger(__name__, level=logging.INFO)
class SensitivePruner(object):
def __init__(self, place, eval_func, scope=None, checkpoints=None):
"""
Pruner used to prune parameters iteratively according to the sensitivities of parameters in each step.
Args:
    place(fluid.CUDAPlace|fluid.CPUPlace): The device place where the program executes.
    eval_func(function): A callback function used to evaluate the pruned program. Its argument is the pruned program, and it returns a score for the given program.
    scope(fluid.Scope): The scope used to execute the program.
    checkpoints(str): The directory used to save and load checkpoints. Default: None.
"""
self._eval_func = eval_func
self._iter = 0
self._place = place
self._scope = fluid.global_scope() if scope is None else scope
self._pruner = Pruner()
self._checkpoints = checkpoints
def save_checkpoint(self, train_program, eval_program):
checkpoint = os.path.join(self._checkpoints, str(self._iter - 1))
exe = fluid.Executor(self._place)
fluid.io.save_persistables(
exe, checkpoint, main_program=train_program, filename="__params__")
with open(checkpoint + "/main_program", "wb") as f:
f.write(train_program.desc.serialize_to_string())
with open(checkpoint + "/eval_program", "wb") as f:
f.write(eval_program.desc.serialize_to_string())
def restore(self, checkpoints=None):
exe = fluid.Executor(self._place)
checkpoints = self._checkpoints if checkpoints is None else checkpoints
print("check points: {}".format(checkpoints))
main_program = None
eval_program = None
if checkpoints is not None:
cks = os.listdir(checkpoints)
if len(cks) > 0:
latest = max([int(ck) for ck in cks])
latest_ck_path = os.path.join(checkpoints, str(latest))
self._iter += 1
with open(latest_ck_path + "/main_program", "rb") as f:
program_desc_str = f.read()
main_program = fluid.Program.parse_from_string(
program_desc_str)
print(main_program)
with open(latest_ck_path + "/eval_program", "rb") as f:
program_desc_str = f.read()
eval_program = fluid.Program.parse_from_string(
program_desc_str)
with fluid.scope_guard(self._scope):
fluid.io.load_persistables(exe, latest_ck_path,
main_program, "__params__")
print("load checkpoint from: {}".format(latest_ck_path))
print("flops of eval program: {}".format(flops(eval_program)))
return main_program, eval_program, self._iter
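Taken together, save_checkpoint and restore round-trip one directory per iteration: <checkpoints>/<iter> holds the two serialized program descs plus a single __params__ file. A hedged usage sketch, assuming pruner and the programs exist as in the demo script above:

    # After an iteration of pruning and fine-tuning:
    pruner.save_checkpoint(pruned_program, pruned_val_program)

    # In a fresh process, resume from the newest checkpoint directory:
    pruned_program, pruned_val_program, start_iter = pruner.restore()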
def prune(self, train_program, eval_program, params, pruned_flops):
"""
Pruning parameters of training and evaluation network by sensitivities in current step.
Args:
train_program(fluid.Program): The training program to be pruned.
eval_program(fluid.Program): The evaluation program to be pruned. And it is also used to calculate sensitivities of parameters.
params(list<str>): The parameters to be pruned.
pruned_flops(float): The ratio of FLOPS to be pruned in current step.
Return:
tuple: A tuple of pruned training program and pruned evaluation program.
"""
_logger.info("Pruning: {}".format(params))
sensitivities_file = "sensitivities_iter{}.data".format(self._iter)
with fluid.scope_guard(self._scope):
sensitivities = sensitivity(
eval_program,
self._place,
params,
self._eval_func,
sensitivities_file=sensitivities_file,
step_size=0.1)
print(sensitivities)
_, ratios = self._get_ratios_by_sensitive(sensitivities, pruned_flops,
eval_program)
pruned_program = self._pruner.prune(
train_program,
self._scope,
params,
ratios,
place=self._place,
only_graph=False)
pruned_val_program = None
if eval_program is not None:
pruned_val_program = self._pruner.prune(
eval_program,
self._scope,
params,
ratios,
place=self._place,
only_graph=True)
self._iter += 1
return pruned_program, pruned_val_program
def _get_ratios_by_sensitive(self, sensitivities, pruned_flops,
eval_program):
"""
Search a group of ratios for pruning target flops.
"""
def func(params, x):
a, b, c, d = params
return a * x * x * x + b * x * x + c * x + d
def error(params, x, y):
return func(params, x) - y
def solve_coefficient(x, y):
    init_coefficient = [10, 10, 10, 10]
    # leastsq returns (solution, status flag); only the solution is needed
    coefficient, _ = leastsq(error, init_coefficient, args=(x, y))
    return coefficient
min_loss = 0.
max_loss = 0.
# step 1: fit curve by sensitivities
coefficients = {}
for param in sensitivities:
losses = np.array([0] * 5 + sensitivities[param]['loss'])
percents = np.array([0] * 5 + sensitivities[param]['pruned_percent'])
coefficients[param] = solve_coefficient(percents, losses)
loss = np.max(losses)
max_loss = np.max([max_loss, loss])
# step 2: Find a group of ratios by binary searching.
base_flops = flops(eval_program)
ratios = []
max_times = 20
while min_loss < max_loss and max_times > 0:
loss = (max_loss + min_loss) / 2
_logger.info(
'-----------Try pruned ratios while acc loss={}-----------'.
format(loss))
ratios = []
# step 2.1: Get ratios according to current loss
for param in sensitivities:
coefficient = copy.deepcopy(coefficients[param])
coefficient[-1] = coefficient[-1] - loss
roots = np.roots(coefficient)
# keep the smallest real root in (0, 1) as the ratio for this param;
# fall back to 1 if no root is valid
min_root = 1
for root in roots:
    if np.isreal(root) and root > 0 and root < 1:
        min_root = min(root.real, min_root)
ratios.append(min_root)
_logger.info('Pruned ratios={}'.format(
[round(ratio, 3) for ratio in ratios]))
# step 2.2: Pruning by current ratios
param_shape_backup = {}
pruned_program = self._pruner.prune(
eval_program,
None, # scope
sensitivities.keys(),
ratios,
None, # place
only_graph=True)
pruned_ratio = 1 - (float(flops(pruned_program)) / base_flops)
_logger.info('Pruned flops: {:.4f}'.format(pruned_ratio))
# step 2.3: Check whether current ratios is enough
if abs(pruned_ratio - pruned_flops) < 0.015:
break
if pruned_ratio > pruned_flops:
max_loss = loss
else:
min_loss = loss
max_times -= 1
return sensitivities.keys(), ratios
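Step 2.1 above inverts each fitted cubic loss(ratio) curve for a trial loss by shifting the constant term and keeping the smallest real root in (0, 1); a standalone sketch of that inversion with hypothetical coefficients:

    import numpy as np

    coefficient = np.array([0.5, -0.2, 0.8, 0.0])  # hypothetical a, b, c, d from leastsq
    target_loss = 0.05
    shifted = coefficient.copy()
    shifted[-1] -= target_loss  # solve a*x^3 + b*x^2 + c*x + d = target_loss
    roots = np.roots(shifted)
    valid = [r.real for r in roots if np.isreal(r) and 0 < r.real < 1]
    ratio = min(valid) if valid else None  # smallest pruning ratio reaching the loss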
......@@ -20,6 +20,7 @@ from paddle.fluid.contrib.slim.quantization import QuantizationTransformPass
from paddle.fluid.contrib.slim.quantization import QuantizationFreezePass
from paddle.fluid.contrib.slim.quantization import ConvertToInt8Pass
from paddle.fluid.contrib.slim.quantization import TransformForMobilePass
+ from paddle.fluid.contrib.slim.quantization import PostTrainingQuantization
from paddle.fluid.contrib.slim.quantization import AddQuantDequantPass
from paddle.fluid import core
......@@ -186,19 +187,68 @@ def quant_aware(program, place, config, scope=None, for_test=False):
return quant_program
- def quant_post(program, place, config, scope=None):
+ def quant_post(executor,
+                model_dir,
+                quantize_model_path,
+                sample_generator,
+                model_filename=None,
+                params_filename=None,
+                batch_size=16,
+                batch_nums=None,
+                scope=None,
+                algo='KL',
+                quantizable_op_type=["conv2d", "depthwise_conv2d", "mul"]):
"""
-    add quantization ops in program. the program returned is not trainable.
+    The function utilizes the post-training quantization method to quantize
+    the fp32 model. It uses calibration data to calculate the scale factors
+    of the quantized variables, and inserts fake quant/dequant ops to obtain
+    the quantized model.
Args:
-        program(fluid.Program): program
-        scope(fluid.Scope): the scope to store var, it's should be the value of program's scope, usually it's fluid.global_scope().
-        place(fluid.CPUPlace or fluid.CUDAPlace): place
-        config(dict): configs for quantization, default values are in quant_config_default dict.
-        for_test: is for test program.
-    Return:
-        fluid.Program: the quantization program is not trainable.
+        executor(fluid.Executor): The executor used to load, run and save the
+            quantized model.
+        model_dir(str): The path of the fp32 model to be quantized; the model
+            and params saved by fluid.io.save_inference_model are under this path.
+        quantize_model_path(str): The path where the quantized model is saved
+            via fluid.io.save_inference_model.
+        sample_generator(Python Generator): The sample generator provides
+            calibration data for the DataLoader, and it returns one sample at a time.
+        model_filename(str, optional): The name of the model file. If the model
+            and parameters were saved in separate files, set it to 'None'. Default is 'None'.
+        params_filename(str, optional): The name of the params file.
+            When all parameters are saved in a single file, set it
+            to that filename. If parameters are saved in separate files,
+            set it to 'None'. Default is 'None'.
+        batch_size(int, optional): The batch size of the DataLoader. Default is 16.
+        batch_nums(int, optional): If batch_nums is not None, the number of
+            calibration samples is batch_size * batch_nums. If batch_nums is None,
+            all data generated by sample_generator is used for calibration.
+        scope(fluid.Scope, optional): The scope used to run the program; variables
+            are loaded and saved in it. If scope is None, fluid.global_scope() is used.
+        algo(str, optional): If algo='KL', use the KL-divergence method to obtain
+            a more precise scale factor. If algo='direct', use the abs_max method
+            to get the scale factor. Default is 'KL'.
+        quantizable_op_type(list[str], optional): The list of op types
+            that will be quantized. Default is ["conv2d", "depthwise_conv2d",
+            "mul"].
+    Returns:
+        None
"""
-    pass
+    post_training_quantization = PostTrainingQuantization(
+        executor=executor,
+        sample_generator=sample_generator,
+        model_dir=model_dir,
+        model_filename=model_filename,
+        params_filename=params_filename,
+        batch_size=batch_size,
+        batch_nums=batch_nums,
+        scope=scope,
+        algo=algo,
+        quantizable_op_type=quantizable_op_type,
+        is_full_quantize=False)
+    post_training_quantization.quantize()
+    post_training_quantization.save_quantized_model(quantize_model_path)
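A hedged usage sketch of quant_post as defined above; the paths and the calibration generator are hypothetical, and the import path assumes the function is exposed as paddleslim.quant.quant_post:

    import numpy as np
    import paddle.fluid as fluid
    from paddleslim.quant import quant_post

    place = fluid.CPUPlace()
    exe = fluid.Executor(place)

    def sample_generator():
        # Yields one calibration sample at a time, matching the model's input
        # (a single-input image model is assumed here).
        for _ in range(160):
            yield np.random.random((3, 224, 224)).astype('float32')

    quant_post(
        executor=exe,
        model_dir='./fp32_infer_model',  # saved by fluid.io.save_inference_model
        quantize_model_path='./quant_model',
        sample_generator=sample_generator,
        batch_size=16,
        batch_nums=10,  # calibrate on 16 * 10 samples
        algo='KL')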
def convert(program, place, config, scope=None, save_int8=False):
......