Commit 99599e57 authored by itminner

Merge remote-tracking branch 'upstream/develop' into develop

@@ -195,11 +195,12 @@ def compress(args):
         server_addr=("", 0),
         init_temperature=100,
         reduce_rate=0.85,
-        max_try_number=300,
+        max_try_times=300,
         max_client_num=10,
         search_steps=100,
         max_ratios=0.9,
         min_ratios=0.,
+        is_server=True,
         key="auto_pruner")
     while True:
......
@@ -39,7 +39,7 @@ def init_sa_nas(config):
     search_steps = 10000000
     ### start a server and a client
-    sa_nas = SANAS(config, max_flops=base_flops, search_steps=search_steps)
+    sa_nas = SANAS(config, search_steps=search_steps, is_server=True)
     ### start a client, server_addr is server address
     #sa_nas = SANAS(config, max_flops = base_flops, server_addr=("10.255.125.38", 18607), search_steps = search_steps, is_server=False)
......
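Note: with this change the demo no longer passes max_flops; any FLOPs constraint is handled by the caller, and the server/client split is controlled by is_server. A minimal sketch of the two roles, assuming the SANAS import path and an illustrative search-space config:

from paddleslim.nas import SANAS

config = [('MobileNetV2Space', {'input_size': 224, 'output_size': 7, 'block_num': 5})]
# On the controller host: start the server (and a local client) on port 8881.
sa_nas = SANAS(config, server_addr=("", 8881), search_steps=300, is_server=True)
# On each worker host: connect as a client to the server's address.
# sa_nas = SANAS(config, server_addr=("10.255.125.38", 8881), search_steps=300, is_server=False)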
import os
import sys
import logging
import paddle
import argparse
import functools
import math
import time
import numpy as np
import paddle.fluid as fluid
from paddleslim.prune import SensitivePruner
from paddleslim.common import get_logger
from paddleslim.analysis import flops
sys.path.append(sys.path[0] + "/../")
import models
from utility import add_arguments, print_arguments

_logger = get_logger(__name__, level=logging.INFO)

parser = argparse.ArgumentParser(description=__doc__)
add_arg = functools.partial(add_arguments, argparser=parser)
# yapf: disable
add_arg('batch_size', int, 64 * 4, "Minibatch size.")
add_arg('use_gpu', bool, True, "Whether to use GPU or not.")
add_arg('model', str, "MobileNet", "The target model.")
add_arg('pretrained_model', str, "../pretrained_model/MobileNetV1_pretained", "The path of pretrained model.")
add_arg('lr', float, 0.1, "The learning rate used to fine-tune pruned model.")
add_arg('lr_strategy', str, "piecewise_decay", "The learning rate decay strategy.")
add_arg('l2_decay', float, 3e-5, "The l2_decay parameter.")
add_arg('momentum_rate', float, 0.9, "The value of momentum_rate.")
add_arg('num_epochs', int, 120, "The number of total epochs.")
add_arg('total_images', int, 1281167, "The number of total training images.")
parser.add_argument('--step_epochs', nargs='+', type=int, default=[30, 60, 90], help="piecewise decay step")
add_arg('config_file', str, None, "The config file for compression with yaml format.")
add_arg('data', str, "mnist", "Which data to use. 'mnist' or 'imagenet'")
add_arg('log_period', int, 10, "Log period in batches.")
add_arg('test_period', int, 10, "Test period in epochs.")
add_arg('checkpoints', str, "./checkpoints", "Checkpoints path.")
# yapf: enable

model_list = [m for m in dir(models) if "__" not in m]


def piecewise_decay(args):
    step = int(math.ceil(float(args.total_images) / args.batch_size))
    bd = [step * e for e in args.step_epochs]
    lr = [args.lr * (0.1**i) for i in range(len(bd) + 1)]
    learning_rate = fluid.layers.piecewise_decay(boundaries=bd, values=lr)
    optimizer = fluid.optimizer.Momentum(
        learning_rate=learning_rate,
        momentum=args.momentum_rate,
        regularization=fluid.regularizer.L2Decay(args.l2_decay))
    return optimizer


def cosine_decay(args):
    step = int(math.ceil(float(args.total_images) / args.batch_size))
    learning_rate = fluid.layers.cosine_decay(
        learning_rate=args.lr, step_each_epoch=step, epochs=args.num_epochs)
    optimizer = fluid.optimizer.Momentum(
        learning_rate=learning_rate,
        momentum=args.momentum_rate,
        regularization=fluid.regularizer.L2Decay(args.l2_decay))
    return optimizer


def create_optimizer(args):
    if args.lr_strategy == "piecewise_decay":
        return piecewise_decay(args)
    elif args.lr_strategy == "cosine_decay":
        return cosine_decay(args)


def compress(args):
    train_reader = None
    test_reader = None
    if args.data == "mnist":
        import paddle.dataset.mnist as reader
        train_reader = reader.train()
        val_reader = reader.test()
        class_dim = 10
        image_shape = "1,28,28"
    elif args.data == "imagenet":
        import imagenet_reader as reader
        train_reader = reader.train()
        val_reader = reader.val()
        class_dim = 1000
        image_shape = "3,224,224"
    else:
        raise ValueError("{} is not supported.".format(args.data))
    image_shape = [int(m) for m in image_shape.split(",")]
    assert args.model in model_list, "{} is not in lists: {}".format(
        args.model, model_list)
    image = fluid.layers.data(name='image', shape=image_shape, dtype='float32')
    label = fluid.layers.data(name='label', shape=[1], dtype='int64')
    # model definition
    model = models.__dict__[args.model]()
    out = model.net(input=image, class_dim=class_dim)
    cost = fluid.layers.cross_entropy(input=out, label=label)
    avg_cost = fluid.layers.mean(x=cost)
    acc_top1 = fluid.layers.accuracy(input=out, label=label, k=1)
    acc_top5 = fluid.layers.accuracy(input=out, label=label, k=5)
    val_program = fluid.default_main_program().clone(for_test=True)
    opt = create_optimizer(args)
    opt.minimize(avg_cost)
    place = fluid.CUDAPlace(0) if args.use_gpu else fluid.CPUPlace()
    exe = fluid.Executor(place)
    exe.run(fluid.default_startup_program())

    if args.pretrained_model:

        def if_exist(var):
            return os.path.exists(
                os.path.join(args.pretrained_model, var.name))

        fluid.io.load_vars(exe, args.pretrained_model, predicate=if_exist)

    val_reader = paddle.batch(val_reader, batch_size=args.batch_size)
    train_reader = paddle.batch(
        train_reader, batch_size=args.batch_size, drop_last=True)

    train_feeder = fluid.DataFeeder([image, label], place)
    val_feeder = fluid.DataFeeder([image, label], place, program=val_program)

    def test(epoch, program):
        batch_id = 0
        acc_top1_ns = []
        acc_top5_ns = []
        for data in val_reader():
            start_time = time.time()
            acc_top1_n, acc_top5_n = exe.run(
                program,
                feed=val_feeder.feed(data),
                fetch_list=[acc_top1.name, acc_top5.name])
            end_time = time.time()
            if batch_id % args.log_period == 0:
                _logger.info(
                    "Eval epoch[{}] batch[{}] - acc_top1: {:.3f}; acc_top5: {:.3f}; time: {:.3f}".
                    format(epoch, batch_id,
                           np.mean(acc_top1_n),
                           np.mean(acc_top5_n), end_time - start_time))
            acc_top1_ns.append(np.mean(acc_top1_n))
            acc_top5_ns.append(np.mean(acc_top5_n))
            batch_id += 1
        _logger.info(
            "Final eval epoch[{}] - acc_top1: {:.3f}; acc_top5: {:.3f}".format(
                epoch,
                np.mean(np.array(acc_top1_ns)), np.mean(
                    np.array(acc_top5_ns))))
        return np.mean(np.array(acc_top1_ns))

    def train(epoch, program):
        build_strategy = fluid.BuildStrategy()
        exec_strategy = fluid.ExecutionStrategy()
        train_program = fluid.compiler.CompiledProgram(
            program).with_data_parallel(
                loss_name=avg_cost.name,
                build_strategy=build_strategy,
                exec_strategy=exec_strategy)
        batch_id = 0
        for data in train_reader():
            start_time = time.time()
            loss_n, acc_top1_n, acc_top5_n = exe.run(
                train_program,
                feed=train_feeder.feed(data),
                fetch_list=[avg_cost.name, acc_top1.name, acc_top5.name])
            end_time = time.time()
            loss_n = np.mean(loss_n)
            acc_top1_n = np.mean(acc_top1_n)
            acc_top5_n = np.mean(acc_top5_n)
            if batch_id % args.log_period == 0:
                _logger.info(
                    "epoch[{}]-batch[{}] - loss: {:.3f}; acc_top1: {:.3f}; acc_top5: {:.3f}; time: {:.3f}".
                    format(epoch, batch_id, loss_n, acc_top1_n, acc_top5_n,
                           end_time - start_time))
            batch_id += 1

    params = []
    for param in fluid.default_main_program().global_block().all_parameters():
        if "_sep_weights" in param.name:
            params.append(param.name)

    def eval_func(program):
        return test(0, program)

    if args.data == "mnist":
        train(0, fluid.default_main_program())

    pruner = SensitivePruner(place, eval_func, checkpoints=args.checkpoints)
    pruned_program, pruned_val_program, start_iter = pruner.restore()
    if pruned_program is None:
        pruned_program = fluid.default_main_program()
    if pruned_val_program is None:
        pruned_val_program = val_program

    end_iter = 6
    for it in range(start_iter, end_iter):
        pruned_program, pruned_val_program = pruner.prune(
            pruned_program, pruned_val_program, params, 0.1)
        train(it, pruned_program)
        test(it, pruned_val_program)
        pruner.save_checkpoint(pruned_program, pruned_val_program)

    print("before flops: {}".format(flops(fluid.default_main_program())))
    print("after flops: {}".format(flops(pruned_val_program)))


def main():
    args = parser.parse_args()
    print_arguments(args)
    compress(args)


if __name__ == '__main__':
    main()
@@ -15,9 +15,6 @@
 import flops as flops_module
 from flops import *
 import model_size as model_size_module
 from model_size import *
-import sensitive
-from sensitive import *
 __all__ = []
 __all__ += flops_module.__all__
 __all__ += model_size_module.__all__
-__all__ += sensitive.__all__
@@ -23,6 +23,8 @@
 import controller_client
 from controller_client import *
 import lock_utils
 from lock_utils import *
+import cached_reader as cached_reader_module
+from cached_reader import *
 __all__ = []
 __all__ += controller.__all__
@@ -30,3 +32,4 @@
 __all__ += sa_controller.__all__
 __all__ += controller_server.__all__
 __all__ += controller_client.__all__
 __all__ += lock_utils.__all__
+__all__ += cached_reader_module.__all__
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import logging
import numpy as np
from .log_helper import get_logger

__all__ = ['cached_reader']

_logger = get_logger(__name__, level=logging.INFO)


def cached_reader(reader, sampled_rate, cache_path, cached_id):
    """
    Sample partial data from reader and cache them into local file system.
    Args:
        reader: Iterative data source.
        sampled_rate(float): The rate used to sample partial data for evaluation.
        cache_path(str): The path to cache the sampled data.
        cached_id(int): The id of dataset sampled. Evaluations with same cached_id use the same sampled dataset. default: 0.
    """
    np.random.seed(cached_id)
    cache_path = os.path.join(cache_path, str(cached_id))
    _logger.debug('read data from: {}'.format(cache_path))

    def s_reader():
        if os.path.isdir(cache_path):
            for file_name in open(os.path.join(cache_path, "list")):
                yield np.load(
                    os.path.join(cache_path, file_name.strip()),
                    allow_pickle=True)
        else:
            os.makedirs(cache_path)
            list_file = open(os.path.join(cache_path, "list"), 'w')
            batch = 0
            for data in reader():
                if batch == 0 or (np.random.uniform() < sampled_rate):
                    np.save(
                        os.path.join(cache_path, 'batch' + str(batch)), data)
                    list_file.write('batch' + str(batch) + '.npy\n')
                batch += 1
                yield data
            list_file.close()

    return s_reader
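Note: a small usage sketch of the new cached_reader (the toy reader and paths are illustrative). The first full pass samples roughly sampled_rate of the batches into cache_path/<cached_id> while still yielding every batch; readers built later with the same cached_id replay only the cached subset:

import numpy as np
from paddleslim.common import cached_reader

def toy_reader():
    for i in range(100):
        yield np.full((4, 8), i, dtype='float32')  # one batch per step

eval_reader = cached_reader(
    toy_reader, sampled_rate=0.3, cache_path='./cache', cached_id=0)
for batch in eval_reader():  # caches on the first run, replays afterwards
    pass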
@@ -38,7 +38,7 @@ class ControllerClient(object):
         self.socket_client = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
         self._key = key
 
-    def update(self, tokens, reward):
+    def update(self, tokens, reward, iter):
         """
         Update the controller according to latest tokens and reward.
         Args:
@@ -48,8 +48,8 @@ class ControllerClient(object):
         socket_client = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
         socket_client.connect((self.server_ip, self.server_port))
         tokens = ",".join([str(token) for token in tokens])
-        socket_client.send("{}\t{}\t{}".format(self._key, tokens, reward)
-                           .encode())
+        socket_client.send("{}\t{}\t{}\t{}".format(self._key, tokens, reward,
+                                                   iter).encode())
         response = socket_client.recv(1024).decode()
         if response.strip('\n').split("\t") == "ok":
             return True
......
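Note: the client/server protocol is a plain tab-separated string, now carrying four fields. Roughly, the parsing added on the server side amounts to the following (values made up for illustration):

message = "sa_nas\t3,1,4\t0.721\t12"  # key \t tokens \t reward \t iter
key, tokens, reward, it = message.strip('\n').split("\t")
tokens = [int(token) for token in tokens.split(",")]  # [3, 1, 4]
# The server then calls: controller.update(tokens, float(reward), int(it))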
@@ -51,23 +51,8 @@ class ControllerServer(object):
         self._port = address[1]
         self._ip = address[0]
         self._key = key
-        self._socket_file = "./controller_server.socket"
 
     def start(self):
-        open(self._socket_file, 'a').close()
-        socket_file = open(self._socket_file, 'r+')
-        lock(socket_file)
-        tid = socket_file.readline()
-        if tid == '':
-            _logger.info("start controller server...")
-            tid = self._start()
-            socket_file.write("tid: {}\nip: {}\nport: {}\n".format(
-                tid, self._ip, self._port))
-            _logger.info("started controller server...")
-        unlock(socket_file)
-        socket_file.close()
-
-    def _start(self):
         self._socket_server = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
         self._socket_server.bind(self._address)
         self._socket_server.listen(self._max_client_num)
@@ -82,7 +67,6 @@ class ControllerServer(object):
     def close(self):
         """Close the server."""
         self._closed = True
-        os.remove(self._socket_file)
         _logger.info("server closed!")
 
     def port(self):
@@ -109,14 +93,15 @@ class ControllerServer(object):
                     _logger.debug("recv message from {}: [{}]".format(addr,
                                                                       message))
                     messages = message.strip('\n').split("\t")
-                    if (len(messages) < 3) or (messages[0] != self._key):
+                    if (len(messages) < 4) or (messages[0] != self._key):
                         _logger.debug("recv noise from {}: [{}]".format(
                             addr, message))
                         continue
                     tokens = messages[1]
                     reward = messages[2]
+                    iter = messages[3]
                     tokens = [int(token) for token in tokens.split(",")]
-                    self._controller.update(tokens, float(reward))
+                    self._controller.update(tokens, float(reward), int(iter))
                     response = "ok"
                     conn.send(response.encode())
                     _logger.debug("send message to {}: [{}]".format(addr,
......
@@ -32,7 +32,7 @@ class SAController(EvolutionaryController):
                  range_table=None,
                  reduce_rate=0.85,
                  init_temperature=1024,
-                 max_iter_number=300,
+                 max_try_times=None,
                  init_tokens=None,
                  constrain_func=None):
         """Initialize.
@@ -40,7 +40,7 @@ class SAController(EvolutionaryController):
             range_table(list<int>): Range table.
             reduce_rate(float): The decay rate of temperature.
             init_temperature(float): Init temperature.
-            max_iter_number(int): max iteration number.
+            max_try_times(int): The max number of attempts to generate legal tokens.
             init_tokens(list<int>): The initial tokens.
             constrain_func(function): The callback function used to check whether the tokens meet constraint. None means there is no constraint. Default: None.
         """
@@ -50,7 +50,7 @@ class SAController(EvolutionaryController):
                 len(self._range_table) == 2)
         self._reduce_rate = reduce_rate
         self._init_temperature = init_temperature
-        self._max_iter_number = max_iter_number
+        self._max_try_times = max_try_times
         self._reward = -1
         self._tokens = init_tokens
         self._constrain_func = constrain_func
@@ -65,14 +65,16 @@ class SAController(EvolutionaryController):
                 d[key] = self.__dict__[key]
         return d
 
-    def update(self, tokens, reward):
+    def update(self, tokens, reward, iter):
         """
         Update the controller according to latest tokens and reward.
         Args:
             tokens(list<int>): The tokens generated in last step.
             reward(float): The reward of tokens.
         """
-        self._iter += 1
+        iter = int(iter)
+        if iter > self._iter:
+            self._iter = iter
         temperature = self._init_temperature * self._reduce_rate**self._iter
         if (reward > self._reward) or (np.random.random() <= math.exp(
                 (reward - self._reward) / temperature)):
@@ -99,9 +101,9 @@ class SAController(EvolutionaryController):
                 self._range_table[1][index] + 1)
         _logger.debug("change index[{}] from {} to {}".format(index, tokens[
             index], new_tokens[index]))
-        if self._constrain_func is None:
+        if self._constrain_func is None or self._max_try_times is None:
             return new_tokens
-        for _ in range(self._max_iter_number):
+        for _ in range(self._max_try_times):
             if not self._constrain_func(new_tokens):
                 index = int(len(self._range_table[0]) * np.random.random())
                 new_tokens = tokens[:]
......
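Note: with iterations reported by clients, the temperature now follows the largest iteration seen across all clients rather than a local counter; the acceptance rule itself is unchanged. A standalone sketch of that logic for reference:

import math
import numpy as np

def sa_accept(reward, best_reward, init_temperature, reduce_rate, cur_iter):
    # Temperature decays geometrically with the global iteration count.
    temperature = init_temperature * reduce_rate**cur_iter
    # Always keep improvements; keep worse tokens with a probability that
    # shrinks as the temperature drops.
    return (reward > best_reward) or (
        np.random.random() <= math.exp((reward - best_reward) / temperature))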
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np
import paddle.fluid as fluid


def merge(teacher_program,
          student_program,
          data_name_map,
          place,
          teacher_scope=fluid.global_scope(),
          student_scope=fluid.global_scope(),
          name_prefix='teacher_'):
    """
    Merge teacher program into student program and add a uniform prefix to the
    names of all vars in teacher program.
    Args:
        teacher_program(Program): The input teacher model paddle program
        student_program(Program): The input student model paddle program
        data_name_map(dict): Describe the mapping between the teacher var name
                             and the student var name
        place(fluid.CPUPlace()|fluid.CUDAPlace(N)): The device on which
                                                    paddle runs.
        student_scope(Scope): The input student scope
        teacher_scope(Scope): The input teacher scope
        name_prefix(str): Name prefix added for all vars of the teacher program.
    Return(Program): Merged program.
    """
    teacher_program = teacher_program.clone(for_test=True)
    for teacher_var in teacher_program.list_vars():
        skip_rename = False
        if teacher_var.name != 'fetch' and teacher_var.name != 'feed':
            if teacher_var.name in data_name_map.keys():
                new_name = data_name_map[teacher_var.name]
                if new_name == teacher_var.name:
                    skip_rename = True
            else:
                new_name = name_prefix + teacher_var.name
            if not skip_rename:
                # scope var rename
                scope_var = teacher_scope.var(teacher_var.name).get_tensor()
                renamed_scope_var = teacher_scope.var(new_name).get_tensor()
                renamed_scope_var.set(np.array(scope_var), place)
                # program var rename
                renamed_var = teacher_program.global_block()._rename_var(
                    teacher_var.name, new_name)

    for teacher_var in teacher_program.list_vars():
        if teacher_var.name != 'fetch' and teacher_var.name != 'feed':
            # student scope add var
            student_scope_var = student_scope.var(teacher_var.name).get_tensor()
            teacher_scope_var = teacher_scope.var(teacher_var.name).get_tensor()
            student_scope_var.set(np.array(teacher_scope_var), place)
            # student program add var
            new_var = student_program.global_block()._clone_variable(
                teacher_var, force_persistable=False)
            new_var.stop_gradient = True

    for block in teacher_program.blocks:
        for op in block.ops:
            if op.type != 'feed' and op.type != 'fetch':
                inputs = {}
                outputs = {}
                attrs = {}
                for input_name in op.input_names:
                    inputs[input_name] = [
                        block.var(in_var_name)
                        for in_var_name in op.input(input_name)
                    ]
                for output_name in op.output_names:
                    outputs[output_name] = [
                        block.var(out_var_name)
                        for out_var_name in op.output(output_name)
                    ]
                for attr_name in op.attr_names:
                    attrs[attr_name] = op.attr(attr_name)
                student_program.global_block().append_op(
                    type=op.type, inputs=inputs, outputs=outputs, attrs=attrs)
    return student_program


def fsp_loss(teacher_var1_name, teacher_var2_name, student_var1_name,
             student_var2_name, program):
    """
    Combine variables from student model and teacher model by fsp-loss.
    Args:
        teacher_var1_name(str): The name of teacher_var1.
        teacher_var2_name(str): The name of teacher_var2. Except for the
                                second dimension, all other dimensions should
                                be consistent with teacher_var1.
        student_var1_name(str): The name of student_var1.
        student_var2_name(str): The name of student_var2. Except for the
                                second dimension, all other dimensions should
                                be consistent with student_var1.
        program(Program): The input distiller program.
    Return(Variable): fsp distiller loss.
    """
    teacher_var1 = program.global_block().var(teacher_var1_name)
    teacher_var2 = program.global_block().var(teacher_var2_name)
    student_var1 = program.global_block().var(student_var1_name)
    student_var2 = program.global_block().var(student_var2_name)
    teacher_fsp_matrix = fluid.layers.fsp_matrix(teacher_var1, teacher_var2)
    student_fsp_matrix = fluid.layers.fsp_matrix(student_var1, student_var2)
    fsp_loss = fluid.layers.reduce_mean(
        fluid.layers.square(student_fsp_matrix - teacher_fsp_matrix))
    return fsp_loss


def l2_loss(teacher_var_name, student_var_name, program):
    """
    Combine variables from student model and teacher model by l2-loss.
    Args:
        teacher_var_name(str): The name of teacher_var.
        student_var_name(str): The name of student_var.
        program(Program): The input distiller program.
    Return(Variable): l2 distiller loss.
    """
    student_var = program.global_block().var(student_var_name)
    teacher_var = program.global_block().var(teacher_var_name)
    l2_loss = fluid.layers.reduce_mean(
        fluid.layers.square(student_var - teacher_var))
    return l2_loss


def soft_label_loss(teacher_var_name,
                    student_var_name,
                    program,
                    teacher_temperature=1.,
                    student_temperature=1.):
    """
    Combine variables from student model and teacher model by soft-label-loss.
    Args:
        teacher_var_name(str): The name of teacher_var.
        student_var_name(str): The name of student_var.
        program(Program): The input distiller program.
        teacher_temperature(float): Temperature used to divide
            teacher_feature_map before softmax. default: 1.0
        student_temperature(float): Temperature used to divide
            student_feature_map before softmax. default: 1.0
    Return(Variable): soft label distiller loss.
    """
    student_var = program.global_block().var(student_var_name)
    teacher_var = program.global_block().var(teacher_var_name)
    student_var = fluid.layers.softmax(student_var / student_temperature)
    teacher_var = fluid.layers.softmax(teacher_var / teacher_temperature)
    teacher_var.stop_gradient = True
    soft_label_loss = fluid.layers.reduce_mean(
        fluid.layers.cross_entropy(
            student_var, teacher_var, soft_label=True))
    return soft_label_loss


def loss(program, loss_func, **kwargs):
    """
    Combine variables from student model and teacher model by self defined loss.
    Args:
        program(Program): The input distiller program.
        loss_func(function): The user self defined loss function.
        **kwargs: Parameters for loss_func; string values are resolved to
                  variables of that name in program.
    Return(Variable): self defined distiller loss.
    """
    func_parameters = {}
    for item in kwargs.items():
        if isinstance(item[1], str):
            func_parameters.setdefault(item[0],
                                       program.global_block().var(item[1]))
        else:
            func_parameters.setdefault(item[0], item[1])
    loss = loss_func(**func_parameters)
    return loss
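Note: a compressed sketch of how these helpers compose (programs and variable names are illustrative; real demos build full teacher and student networks first). merge() copies the renamed teacher graph into the student program, after which the loss helpers can pair any teacher/student variables:

import paddle.fluid as fluid
from paddleslim.dist import merge, l2_loss

# Assume teacher_program and student_program were built beforehand and both
# read an input variable named 'image'.
place = fluid.CPUPlace()
data_name_map = {'image': 'image'}  # teacher feed var -> student feed var
merged = merge(teacher_program, student_program, data_name_map, place)
with fluid.program_guard(merged):
    # Teacher vars now carry the 'teacher_' prefix in the merged program;
    # the variable names here are hypothetical.
    distill_loss = l2_loss('teacher_fc_0.tmp_0', 'fc_0.tmp_0', merged)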
@@ -15,6 +15,7 @@
 import socket
 import logging
 import numpy as np
+import hashlib
 import paddle.fluid as fluid
 from ..core import VarWrapper, OpWrapper, GraphWrapper
 from ..common import SAController
@@ -33,98 +34,71 @@
 _logger = get_logger(__name__, level=logging.INFO)
 
 class SANAS(object):
     def __init__(self,
                  configs,
-                 max_flops=None,
-                 max_latency=None,
-                 server_addr=("", 0),
+                 server_addr=("", 8881),
                  init_temperature=100,
                  reduce_rate=0.85,
-                 max_try_number=300,
-                 max_client_num=10,
                  search_steps=300,
                  key="sa_nas",
-                 is_server=True):
+                 is_server=False):
         """
         Search a group of ratios used to prune program.
         Args:
             configs(list<tuple>): A list of search space configuration with format (key, input_size, output_size, block_num).
                                   `key` is the name of search space with data type str. `input_size` and `output_size` are
                                   input size and output size of searched sub-network. `block_num` is the number of blocks in searched network.
-            max_flops(int): The max flops of searched network. None means no constrains. Default: None.
-            max_latency(float): The max latency of searched network. None means no constrains. Default: None.
             server_addr(tuple): A tuple of server ip and server port for controller server.
             init_temperature(float): The init temperature used in simulated annealing search strategy.
             reduce_rate(float): The decay rate used in simulated annealing search strategy.
-            max_try_number(int): The max number of trying to generate legal tokens.
-            max_client_num(int): The max number of connections of controller server.
             search_steps(int): The steps of searching.
             key(str): Identity used in communication between controller server and clients.
             is_server(bool): Whether current host is controller server. Default: True.
         """
+        if not is_server:
+            assert server_addr[
+                0] != "", "You should set the IP and port of server when is_server is False."
         self._reduce_rate = reduce_rate
         self._init_temperature = init_temperature
-        self._max_try_number = max_try_number
         self._is_server = is_server
-        self._max_flops = max_flops
-        self._max_latency = max_latency
         self._configs = configs
+        self._key = hashlib.md5(str(self._configs)).hexdigest()
+
+        server_ip, server_port = server_addr
+        if server_ip == None or server_ip == "":
+            server_ip = self._get_host_ip()
+
+        # create controller server
+        if self._is_server:
             factory = SearchSpaceFactory()
             self._search_space = factory.get_search_space(configs)
             init_tokens = self._search_space.init_tokens()
             range_table = self._search_space.range_table()
             range_table = (len(range_table) * [0], range_table)
-        _logger.info("range table: {}".format(range_table))
-        print range_table
-
-        controller = SAController(range_table, self._reduce_rate,
-                                  self._init_temperature, self._max_try_number,
-                                  init_tokens, self._constrain_func)
-
-        server_ip, server_port = server_addr
-        if server_ip == None or server_ip == "":
-            server_ip = self._get_host_ip()
+            controller = SAController(
+                range_table,
+                self._reduce_rate,
+                self._init_temperature,
+                max_try_times=None,
+                init_tokens=init_tokens,
+                constrain_func=None)
+
+            max_client_num = 100
             self._controller_server = ControllerServer(
                 controller=controller,
                 address=(server_ip, server_port),
                 max_client_num=max_client_num,
                 search_steps=search_steps,
-                key=key)
-        # create controller server
-        if self._is_server:
+                key=self._key)
             self._controller_server.start()
 
-        self._controller_client = ControllerClient(
-            self._controller_server.ip(),
-            self._controller_server.port(),
-            key=key)
+        self._controller_client = ControllerClient(
+            server_ip, server_port, key=self._key)
 
         self._iter = 0
 
     def _get_host_ip(self):
         return socket.gethostbyname(socket.gethostname())
 
-    def _constrain_func(self, tokens):
-        if (self._max_flops is None) and (self._max_latency is None):
-            return True
-        archs = self._search_space.token2arch(tokens)
-        main_program = fluid.Program()
-        startup_program = fluid.Program()
-        with fluid.program_guard(main_program, startup_program):
-            i = 0
-            for config, arch in zip(self._configs, archs):
-                input_size = config[1]["input_size"]
-                input = fluid.data(
-                    name="data_{}".format(i),
-                    shape=[None, 3, input_size, input_size],
-                    dtype="float32")
-                output = arch(input)
-                i += 1
-        return flops(main_program) < self._max_flops
-
     def next_archs(self):
         """
         Get next network architectures.
@@ -144,4 +118,5 @@ class SANAS(object):
             bool: True means updating successfully while false means failure.
         """
         self._iter += 1
-        return self._controller_client.update(self._current_tokens, score)
+        return self._controller_client.update(self._current_tokens, score,
+                                              self._iter)
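Note: a hedged sketch of the client-side search loop against the updated API (the reward computation is application-specific, and the reward-reporting method name follows the client call shown above):

sa_nas = SANAS(configs, server_addr=("", 8881), search_steps=300, is_server=True)
for step in range(300):
    archs = sa_nas.next_archs()   # functions that build this step's sub-networks
    # ... insert archs into a program, train briefly, then evaluate ...
    score = evaluate(archs)       # placeholder: any scalar to maximize
    sa_nas.reward(score)          # sent to the server along with the iteration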
@@ -19,9 +19,15 @@
 import controller_server
 from controller_server import *
 import controller_client
 from controller_client import *
+import sensitive_pruner
+from sensitive_pruner import *
+import sensitive
+from sensitive import *
 __all__ = []
 __all__ += pruner.__all__
 __all__ += auto_pruner.__all__
 __all__ += controller_server.__all__
 __all__ += controller_client.__all__
+__all__ += sensitive_pruner.__all__
+__all__ += sensitive.__all__
@@ -42,7 +42,7 @@ class AutoPruner(object):
                  server_addr=("", 0),
                  init_temperature=100,
                  reduce_rate=0.85,
-                 max_try_number=300,
+                 max_try_times=300,
                  max_client_num=10,
                  search_steps=300,
                  max_ratios=[0.9],
@@ -66,7 +66,7 @@ class AutoPruner(object):
             server_addr(tuple): A tuple of server ip and server port for controller server.
             init_temperature(float): The init temperature used in simulated annealing search strategy.
             reduce_rate(float): The decay rate used in simulated annealing search strategy.
-            max_try_number(int): The max number of trying to generate legal tokens.
+            max_try_times(int): The max number of attempts to generate legal tokens.
             max_client_num(int): The max number of connections of controller server.
             search_steps(int): The steps of searching.
             max_ratios(float|list<float>): Max ratios used to pruned parameters in `params`. List means max ratios for each parameter in `params`.
@@ -88,7 +88,7 @@ class AutoPruner(object):
         self._pruned_latency = pruned_latency
         self._reduce_rate = reduce_rate
         self._init_temperature = init_temperature
-        self._max_try_number = max_try_number
+        self._max_try_times = max_try_times
         self._is_server = is_server
         self._range_table = self._get_range_table(min_ratios, max_ratios)
@@ -110,7 +110,7 @@ class AutoPruner(object):
         init_tokens = self._ratios2tokens(self._init_ratios)
         _logger.info("range table: {}".format(self._range_table))
         controller = SAController(self._range_table, self._reduce_rate,
-                                  self._init_temperature, self._max_try_number,
+                                  self._init_temperature, self._max_try_times,
                                   init_tokens, self._constrain_func)
         server_ip, server_port = server_addr
@@ -212,7 +212,7 @@ class AutoPruner(object):
         self._restore(self._scope)
         self._param_backup = {}
         tokens = self._ratios2tokens(self._current_ratios)
-        self._controller_client.update(tokens, score)
+        self._controller_client.update(tokens, score, self._iter)
         self._iter += 1
 
     def _restore(self, scope):
......
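Note: the renamed max_try_times flows into the usual AutoPruner loop; a minimal sketch under the assumption that the constructor and the prune/reward methods follow the demo above (the program, parameter names and evaluation step are placeholders):

pruner = AutoPruner(
    val_program,
    fluid.global_scope(),
    place,
    params=["conv2_1_sep_weights", "conv2_2_sep_weights"],
    init_ratios=[0.33] * 2,
    max_try_times=300,  # renamed from max_try_number
    is_server=True)
for _ in range(300):
    pruned_program = pruner.prune(val_program)
    score = eval_acc(pruned_program)  # placeholder evaluation
    pruner.reward(score)              # now also reports this client's iteration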
@@ -17,6 +17,7 @@
 import os
 import logging
 import pickle
 import numpy as np
+import paddle.fluid as fluid
 from ..core import GraphWrapper
 from ..common import get_logger
 from ..prune import Pruner
@@ -27,13 +28,12 @@
 __all__ = ["sensitivity"]
 
 def sensitivity(program,
-                scope,
                 place,
                 param_names,
                 eval_func,
                 sensitivities_file=None,
                 step_size=0.2):
+    scope = fluid.global_scope()
     graph = GraphWrapper(program)
     sensitivities = _load_sensitivities(sensitivities_file)
@@ -55,7 +55,7 @@ def sensitivity(program,
                 ratio += step_size
                 continue
             if baseline is None:
-                baseline = eval_func(graph.program, scope)
+                baseline = eval_func(graph.program)
             param_backup = {}
             pruner = Pruner()
@@ -68,7 +68,7 @@ def sensitivity(program,
                 lazy=True,
                 only_graph=False,
                 param_backup=param_backup)
-            pruned_metric = eval_func(pruned_program, scope)
+            pruned_metric = eval_func(pruned_program)
             loss = (baseline - pruned_metric) / baseline
             _logger.info("pruned param: {}; {}; loss={}".format(name, ratio,
                                                                 loss))
......
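Note: after this change eval_func takes only the program, and the scope is implicitly fluid.global_scope(). A brief sketch of a call site (the metric computation and parameter name are placeholders):

def eval_func(program):
    # Placeholder: run evaluation with `program` and return a scalar metric.
    return evaluate_top1(program)

sensitivities = sensitivity(
    val_program,
    place,
    ["conv2_1_sep_weights"],  # parameters to analyse
    eval_func,
    sensitivities_file="sensitivities.data",
    step_size=0.2)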
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import logging
import copy
from scipy.optimize import leastsq
import numpy as np
import paddle.fluid as fluid
from ..common import get_logger
from .sensitive import sensitivity
from ..analysis import flops
from .pruner import Pruner

__all__ = ["SensitivePruner"]

_logger = get_logger(__name__, level=logging.INFO)


class SensitivePruner(object):
    def __init__(self, place, eval_func, scope=None, checkpoints=None):
        """
        Pruner used to prune parameters iteratively according to sensitivities
        of parameters in each step.
        Args:
            place(fluid.CUDAPlace | fluid.CPUPlace): The device place where program execute.
            eval_func(function): A callback function used to evaluate pruned program. The argument of this function is pruned program. And it return a score of given program.
            scope(fluid.scope): The scope used to execute program.
            checkpoints(str): The path used to save and load checkpoints.
        """
        self._eval_func = eval_func
        self._iter = 0
        self._place = place
        self._scope = fluid.global_scope() if scope is None else scope
        self._pruner = Pruner()
        self._checkpoints = checkpoints

    def save_checkpoint(self, train_program, eval_program):
        checkpoint = os.path.join(self._checkpoints, str(self._iter - 1))
        exe = fluid.Executor(self._place)
        fluid.io.save_persistables(
            exe, checkpoint, main_program=train_program, filename="__params__")
        with open(checkpoint + "/main_program", "wb") as f:
            f.write(train_program.desc.serialize_to_string())
        with open(checkpoint + "/eval_program", "wb") as f:
            f.write(eval_program.desc.serialize_to_string())

    def restore(self, checkpoints=None):
        exe = fluid.Executor(self._place)
        checkpoints = self._checkpoints if checkpoints is None else checkpoints
        print("check points: {}".format(checkpoints))
        main_program = None
        eval_program = None
        if checkpoints is not None:
            cks = [dir for dir in os.listdir(checkpoints)]
            if len(cks) > 0:
                latest = max([int(ck) for ck in cks])
                latest_ck_path = os.path.join(checkpoints, str(latest))
                self._iter += 1
                with open(latest_ck_path + "/main_program", "rb") as f:
                    program_desc_str = f.read()
                main_program = fluid.Program.parse_from_string(
                    program_desc_str)
                print(main_program)
                with open(latest_ck_path + "/eval_program", "rb") as f:
                    program_desc_str = f.read()
                eval_program = fluid.Program.parse_from_string(
                    program_desc_str)
                with fluid.scope_guard(self._scope):
                    fluid.io.load_persistables(exe, latest_ck_path,
                                               main_program, "__params__")
                print("load checkpoint from: {}".format(latest_ck_path))
                print("flops of eval program: {}".format(flops(eval_program)))
        return main_program, eval_program, self._iter

    def prune(self, train_program, eval_program, params, pruned_flops):
        """
        Pruning parameters of training and evaluation network by sensitivities in current step.
        Args:
            train_program(fluid.Program): The training program to be pruned.
            eval_program(fluid.Program): The evaluation program to be pruned. And it is also used to calculate sensitivities of parameters.
            params(list<str>): The parameters to be pruned.
            pruned_flops(float): The ratio of FLOPS to be pruned in current step.
        Return:
            tuple: A tuple of pruned training program and pruned evaluation program.
        """
        _logger.info("Pruning: {}".format(params))
        sensitivities_file = "sensitivities_iter{}.data".format(self._iter)
        with fluid.scope_guard(self._scope):
            sensitivities = sensitivity(
                eval_program,
                self._place,
                params,
                self._eval_func,
                sensitivities_file=sensitivities_file,
                step_size=0.1)
        print(sensitivities)
        _, ratios = self._get_ratios_by_sensitive(sensitivities, pruned_flops,
                                                  eval_program)
        pruned_program = self._pruner.prune(
            train_program,
            self._scope,
            params,
            ratios,
            place=self._place,
            only_graph=False)
        pruned_val_program = None
        if eval_program is not None:
            pruned_val_program = self._pruner.prune(
                eval_program,
                self._scope,
                params,
                ratios,
                place=self._place,
                only_graph=True)
        self._iter += 1
        return pruned_program, pruned_val_program

    def _get_ratios_by_sensitive(self, sensitivities, pruned_flops,
                                 eval_program):
        """
        Search a group of ratios for pruning target flops.
        """

        def func(params, x):
            a, b, c, d = params
            return a * x * x * x + b * x * x + c * x + d

        def error(params, x, y):
            return func(params, x) - y

        def solve_coefficient(x, y):
            init_coefficient = [10, 10, 10, 10]
            coefficient, loss = leastsq(error, init_coefficient, args=(x, y))
            return coefficient

        min_loss = 0.
        max_loss = 0.

        # step 1: fit curve by sensitivities
        coefficients = {}
        for param in sensitivities:
            losses = np.array([0] * 5 + sensitivities[param]['loss'])
            percents = np.array([0] * 5 + sensitivities[param][
                'pruned_percent'])
            coefficients[param] = solve_coefficient(percents, losses)
            loss = np.max(losses)
            max_loss = np.max([max_loss, loss])

        # step 2: Find a group of ratios by binary searching.
        base_flops = flops(eval_program)
        ratios = []
        max_times = 20
        while min_loss < max_loss and max_times > 0:
            loss = (max_loss + min_loss) / 2
            _logger.info(
                '-----------Try pruned ratios while acc loss={}-----------'.
                format(loss))
            ratios = []
            # step 2.1: Get ratios according to current loss
            for param in sensitivities:
                coefficient = copy.deepcopy(coefficients[param])
                coefficient[-1] = coefficient[-1] - loss
                roots = np.roots(coefficient)
                selected_root = 1.
                for root in roots:
                    if np.isreal(root) and root > 0 and root < 1:
                        selected_root = min(root.real, selected_root)
                ratios.append(selected_root)
            _logger.info('Pruned ratios={}'.format(
                [round(ratio, 3) for ratio in ratios]))
            # step 2.2: Pruning by current ratios
            pruned_program = self._pruner.prune(
                eval_program,
                None,  # scope
                sensitivities.keys(),
                ratios,
                None,  # place
                only_graph=True)
            pruned_ratio = 1 - (float(flops(pruned_program)) / base_flops)
            _logger.info('Pruned flops: {:.4f}'.format(pruned_ratio))
            # step 2.3: Check whether current ratios is enough
            if abs(pruned_ratio - pruned_flops) < 0.015:
                break
            if pruned_ratio > pruned_flops:
                max_loss = loss
            else:
                min_loss = loss
            max_times -= 1
        return sensitivities.keys(), ratios
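Note: the ratio search above fits a cubic loss-versus-ratio curve per parameter with leastsq and then binary-searches a tolerable accuracy loss. A self-contained sketch of the fit-and-invert step on synthetic data:

import numpy as np
from scipy.optimize import leastsq

def func(params, x):
    a, b, c, d = params
    return a * x**3 + b * x**2 + c * x + d

def error(params, x, y):
    return func(params, x) - y

# Synthetic sensitivity curve: pruning ratio -> accuracy loss.
ratios = np.array([0.0, 0.1, 0.2, 0.3, 0.4, 0.5])
losses = np.array([0.0, 0.01, 0.03, 0.08, 0.15, 0.30])
coefficient, _ = leastsq(error, [10, 10, 10, 10], args=(ratios, losses))

# Invert the fit: subtract a target loss from the constant term, then take the
# smallest real root in (0, 1), i.e. the ratio predicted to cost that loss.
target = 0.05
coefficient[-1] -= target
roots = [r.real for r in np.roots(coefficient) if np.isreal(r) and 0 < r.real < 1]
print(min(roots))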
@@ -20,6 +20,7 @@
 from paddle.fluid.contrib.slim.quantization import QuantizationTransformPass
 from paddle.fluid.contrib.slim.quantization import QuantizationFreezePass
 from paddle.fluid.contrib.slim.quantization import ConvertToInt8Pass
 from paddle.fluid.contrib.slim.quantization import TransformForMobilePass
+from paddle.fluid.contrib.slim.quantization import PostTrainingQuantization
 from paddle.fluid.contrib.slim.quantization import AddQuantDequantPass
 from paddle.fluid import core
@@ -186,19 +187,68 @@ def quant_aware(program, place, config, scope=None, for_test=False):
     return quant_program
 
-def quant_post(program, place, config, scope=None):
+def quant_post(executor,
+               model_dir,
+               quantize_model_path,
+               sample_generator,
+               model_filename=None,
+               params_filename=None,
+               batch_size=16,
+               batch_nums=None,
+               scope=None,
+               algo='KL',
+               quantizable_op_type=["conv2d", "depthwise_conv2d", "mul"]):
     """
-    add quantization ops in program. the program returned is not trainable.
+    The function utilizes post training quantization method to quantize the
+    fp32 model. It uses calibrate data to calculate the scale factor of
+    quantized variables, and inserts fake quant/dequant op to obtain the
+    quantized model.
     Args:
-        program(fluid.Program): program
-        scope(fluid.Scope): the scope to store var, it should be the value of program's scope, usually it's fluid.global_scope().
-        place(fluid.CPUPlace or fluid.CUDAPlace): place
-        config(dict): configs for quantization, default values are in quant_config_default dict.
-        for_test: is for test program.
-    Return:
-        fluid.Program: the quantization program is not trainable.
+        executor(fluid.Executor): The executor to load, run and save the
+            quantized model.
+        model_dir(str): The path of fp32 model that will be quantized, and
+            the model and params that saved by fluid.io.save_inference_model
+            are under the path.
+        quantize_model_path(str): The path to save quantized model using api
+            fluid.io.save_inference_model.
+        sample_generator(Python Generator): The sample generator provides
+            calibrate data for DataLoader, and it only returns a sample every time.
+        model_filename(str, optional): The name of model file. If parameters
+            are saved in separate files, set it as 'None'. Default is 'None'.
+        params_filename(str, optional): The name of params file.
+            When all parameters are saved in a single file, set it
+            as filename. If parameters are saved in separate files,
+            set it as 'None'. Default is 'None'.
+        batch_size(int, optional): The batch size of DataLoader, default is 16.
+        batch_nums(int, optional): If batch_nums is not None, the number of calibrate
+            data is 'batch_size*batch_nums'. If batch_nums is None, use all data
+            generated by sample_generator as calibrate data.
+        scope(fluid.Scope, optional): The scope to run program, use it to load
+            and save variables. If scope is None, will use fluid.global_scope().
+        algo(str, optional): If algo='KL', use KL-divergence method to
+            get the more precise scale factor. If algo='direct', use
+            abs_max method to get the scale factor. Default is 'KL'.
+        quantizable_op_type(list[str], optional): The list of op types
+            that will be quantized. Default is ["conv2d", "depthwise_conv2d",
+            "mul"].
+    Returns:
+        None
     """
-    pass
+    post_training_quantization = PostTrainingQuantization(
+        executor=executor,
+        sample_generator=sample_generator,
+        model_dir=model_dir,
+        model_filename=model_filename,
+        params_filename=params_filename,
+        batch_size=batch_size,
+        batch_nums=batch_nums,
+        scope=scope,
+        algo=algo,
+        quantizable_op_type=quantizable_op_type,
+        is_full_quantize=False)
+    post_training_quantization.quantize()
+    post_training_quantization.save_quantized_model(quantize_model_path)
 
 def convert(program, place, config, scope=None, save_int8=False):
......
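Note: a minimal invocation sketch for the new quant_post (paths and the calibration generator are placeholders):

import paddle.fluid as fluid
from paddleslim.quant import quant_post

place = fluid.CPUPlace()
exe = fluid.Executor(place)

def sample_generator():
    # Placeholder: yield one calibration sample (e.g. an image array) at a time.
    for sample in calibration_samples:
        yield sample

quant_post(
    executor=exe,
    model_dir='./fp32_model',  # saved with fluid.io.save_inference_model
    quantize_model_path='./quant_model',
    sample_generator=sample_generator,
    batch_size=16,
    batch_nums=10,  # calibrate on 16 * 10 samples
    algo='KL')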
@@ -40,8 +40,7 @@ class TestSANAS(unittest.TestCase):
         base_flops = flops(main_program)
         search_steps = 3
-        sa_nas = SANAS(
-            configs, max_flops=base_flops, search_steps=search_steps)
+        sa_nas = SANAS(configs, search_steps=search_steps, is_server=True)
         for i in range(search_steps):
             archs = sa_nas.next_archs()
......