未验证 提交 5df65e50 编写于 作者: W whs 提交者: GitHub

Add Light-NAS for PaddleSlim (#17679)

* Add auto pruning strategy.
1. Fix compressor.
2. Enhance graph executor.
3. Add SAController
4. Add auto pruning strategy.
5. Add unit test for auto pruning strategy.
test=develop

* Init light-nas

* Add light nas.

* Some fix.
test=develop

* Fix sa controller.
test=develop

* Fix unit test of light nas.
test=develop

* Fix setup.py.in and API.spec.
test=develop

* Fix unit tests.
1. Fix unit tests on Windows.
2. Fix package importing in tests directory.

* 1. Remove unused comments.
2. Expose eval_epoch option.
3. Remove unused function in search_agent.
4. Expose max_client_num to yaml file.
5. Move flops constraint to on_epoch_begin function
test=develop

* Fix light nas strategy.
test=develop

* Make controller server stable.
test=develop

* 1. Add try exception to compressor.
2. Remove unit test of light-nas for Windows.
test=develop

* Add comments
Enhance controller
test=develop

* Fix comments.
test=develop
上级 3925bd81
......@@ -403,7 +403,7 @@ paddle.fluid.contrib.Calibrator.sample_data (ArgSpec(args=['self'], varargs=None
paddle.fluid.contrib.Calibrator.save_int8_model (ArgSpec(args=['self'], varargs=None, keywords=None, defaults=None), ('document', '6adf97f83acf6453d4a6a4b1070f3754'))
paddle.fluid.contrib.distributed_sampler (ArgSpec(args=['reader', 'batch_size'], varargs=None, keywords=None, defaults=None), ('document', '9a271cd9700deb6d837ed724ba094315'))
paddle.fluid.contrib.reader.ctr_reader.ctr_reader (ArgSpec(args=['feed_dict', 'file_type', 'file_format', 'dense_slot_index', 'sparse_slot_index', 'capacity', 'thread_num', 'batch_size', 'file_list', 'slots', 'name'], varargs=None, keywords=None, defaults=(None,)), ('document', 'b2ebf3de2a6ef1af2c3b88d2db7591ab'))
paddle.fluid.contrib.Compressor.__init__ (ArgSpec(args=['self', 'place', 'scope', 'train_program', 'train_reader', 'train_feed_list', 'train_fetch_list', 'eval_program', 'eval_reader', 'eval_feed_list', 'eval_fetch_list', 'teacher_programs', 'checkpoint_path', 'train_optimizer', 'distiller_optimizer'], varargs=None, keywords=None, defaults=(None, None, None, None, None, None, None, [], './checkpoints', None, None)), ('document', '31ae143830c9bf6b43547dd546c5ba80'))
paddle.fluid.contrib.Compressor.__init__ (ArgSpec(args=['self', 'place', 'scope', 'train_program', 'train_reader', 'train_feed_list', 'train_fetch_list', 'eval_program', 'eval_reader', 'eval_feed_list', 'eval_fetch_list', 'teacher_programs', 'checkpoint_path', 'train_optimizer', 'distiller_optimizer', 'search_space'], varargs=None, keywords=None, defaults=(None, None, None, None, None, None, None, [], None, None, None, None)), ('document', 'c195b3bba26169cff9439e8c467557c0'))
paddle.fluid.contrib.Compressor.config (ArgSpec(args=['self', 'config_file'], varargs=None, keywords=None, defaults=None), ('document', '780d9c007276ccbb95b292400d7807b0'))
paddle.fluid.contrib.Compressor.run (ArgSpec(args=['self'], varargs=None, keywords=None, defaults=None), ('document', 'c6e43d6a078d307672283c1f36e04fe9'))
paddle.fluid.contrib.load_persistables_for_increment (ArgSpec(args=['dirname', 'executor', 'program', 'lookup_table_var', 'lookup_table_var_path'], varargs=None, keywords=None, defaults=None), ('document', '2ab36d4f7a564f5f65e455807ad06c67'))
......
......@@ -107,7 +107,6 @@ class CompiledProgram(object):
raise ValueError("Wrong program_to_graph type: %s" %
type(program_or_graph))
self._program_desc = self._graph.origin_program_desc()
self._scope = None
self._place = None
self._executor = None
......
......@@ -12,8 +12,9 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from ....core import CPUPlace
from ....core import CPUPlace, EOFException
from .... import compiler
from ....framework import Variable
from .... import io
from .... import profiler
from .... import scope_guard
......@@ -28,6 +29,7 @@ import logging
import sys
import pickle
import functools
import traceback
__all__ = ['Context', 'Compressor']
......@@ -83,7 +85,8 @@ class Context(object):
eval_reader=None,
teacher_graphs=None,
train_optimizer=None,
distiller_optimizer=None):
distiller_optimizer=None,
search_space=None):
"""
Args:
place: The device place where the compression job running.
......@@ -119,6 +122,9 @@ class Context(object):
self.cache_path = './eval_cache'
self.eval_results = {}
self.skip_training = False
self.search_space = search_space
def to_file(self, file_name):
"""
Save the context into file.
......@@ -181,14 +187,30 @@ class Context(object):
if sampled_rate:
reader = cached_reader(reader, sampled_rate, self.cache_path,
cached_id)
for data in reader():
result = executor.run(eval_graph, self.scope, data=data)
result = [np.mean(r) for r in result]
results.append(result)
if batch_id % 20 == 0:
_logger.info("batch-{}; {}={}".format(
batch_id, eval_graph.out_nodes.keys(), result))
batch_id += 1
if isinstance(reader, Variable):
reader.start()
try:
while True:
result = executor.run(eval_graph, self.scope)
result = [np.mean(r) for r in result]
results.append(result)
if batch_id % 20 == 0:
_logger.info("batch-{}; {}={}".format(
batch_id, eval_graph.out_nodes.keys(), result))
batch_id += 1
except EOFException:
reader.reset()
else:
for data in reader():
result = executor.run(eval_graph, self.scope, data=data)
result = [np.mean(r) for r in result]
results.append(result)
if batch_id % 20 == 0:
_logger.info("batch-{}; {}={}".format(
batch_id, eval_graph.out_nodes.keys(), result))
batch_id += 1
result = np.mean(np.array(results), axis=0)
_logger.info("Final eval result: {}={}".format(
eval_graph.out_nodes.keys(), result))
......@@ -221,9 +243,10 @@ class Compressor(object):
eval_feed_list=None,
eval_fetch_list=None,
teacher_programs=[],
checkpoint_path='./checkpoints',
checkpoint_path=None,
train_optimizer=None,
distiller_optimizer=None):
distiller_optimizer=None,
search_space=None):
"""
Args:
place(fluid.Place): The device place where the compression job running.
......@@ -251,12 +274,14 @@ class Compressor(object):
this optimizer is used to minimize the combined loss of student-net and
teacher-net while train_optimizer is used to minimize loss of
student-net in fine-tune stage.
search_space(slim.nas.SearchSpace): The instance that define the searching space. It must inherite
slim.nas.SearchSpace class and overwrite the abstract methods.
"""
assert isinstance(
assert train_feed_list is None or isinstance(
train_feed_list, list
), "train_feed_list should be a list of tuple, such as [('image', image.name), ('label', gt.name)]"
assert isinstance(
assert eval_feed_list is None or isinstance(
eval_feed_list, list
), "eval_feed_list should be a list of tuple, such as [('image', image.name), ('label', gt.name)]"
self.strategies = []
......@@ -281,6 +306,8 @@ class Compressor(object):
self.distiller_optimizer = distiller_optimizer
self.init_model = None
self.search_space = search_space
def _add_strategy(self, strategy):
"""
Add a strategy to current compress pass.
......@@ -306,6 +333,9 @@ class Compressor(object):
if 'init_model' in factory.compressor:
self.init_model = factory.compressor['init_model']
if 'eval_epoch' in factory.compressor:
self.eval_epoch = factory.compressor['eval_epoch']
def _init_model(self, context):
"""
Load model that has been compressed.
......@@ -402,7 +432,8 @@ class Compressor(object):
"""
Train one epoch.
"""
if context.skip_training:
return
executor = SlimGraphExecutor(self.place)
if context.optimize_graph.compiled_graph is None:
......@@ -410,21 +441,44 @@ class Compressor(object):
context.optimize_graph.program).with_data_parallel(
loss_name=context.optimize_graph.out_nodes['loss'])
for data in context.train_reader():
for strategy in self.strategies:
strategy.on_batch_begin(context)
results = executor.run(context.optimize_graph,
context.scope,
data=data)
results = [float(np.mean(result)) for result in results]
if context.batch_id % 20 == 0:
_logger.info("epoch:{}; batch_id:{}; {} = {}".format(
context.epoch_id, context.batch_id,
context.optimize_graph.out_nodes.keys(
), [round(r, 3) for r in results]))
for strategy in self.strategies:
strategy.on_batch_end(context)
context.batch_id += 1
if isinstance(context.train_reader, Variable):
context.train_reader.start()
try:
while True:
for strategy in self.strategies:
strategy.on_batch_begin(context)
results = executor.run(context.optimize_graph,
context.scope)
results = [float(np.mean(result)) for result in results]
if context.batch_id % 20 == 0:
_logger.info("epoch:{}; batch_id:{}; {} = {}".format(
context.epoch_id, context.batch_id,
context.optimize_graph.out_nodes.keys(
), [round(r, 3) for r in results]))
for strategy in self.strategies:
strategy.on_batch_end(context)
context.batch_id += 1
except EOFException:
context.train_reader.reset()
else:
for data in context.train_reader():
for strategy in self.strategies:
strategy.on_batch_begin(context)
results = executor.run(context.optimize_graph,
context.scope,
data=data)
results = [float(np.mean(result)) for result in results]
if context.batch_id % 20 == 0:
_logger.info("epoch:{}; batch_id:{}; {} = {}".format(
context.epoch_id, context.batch_id,
context.optimize_graph.out_nodes.keys(
), [round(r, 3) for r in results]))
for strategy in self.strategies:
strategy.on_batch_end(context)
context.batch_id += 1
context.batch_id = 0
def _eval(self, context):
......@@ -450,7 +504,8 @@ class Compressor(object):
eval_reader=self.eval_reader,
teacher_graphs=self.teacher_graphs,
train_optimizer=self.train_optimizer,
distiller_optimizer=self.distiller_optimizer)
distiller_optimizer=self.distiller_optimizer,
search_space=self.search_space)
self.context = context
if self.teacher_graphs:
context.put('teachers', self.teacher_graphs)
......@@ -472,17 +527,20 @@ class Compressor(object):
]:
return None
start = context.epoch_id
self._eval(context)
for epoch in range(start, self.epoch):
context.epoch_id = epoch
for strategy in self.strategies:
strategy.on_epoch_begin(context)
self._train_one_epoch(context)
for strategy in self.strategies:
strategy.on_epoch_end(context)
if self.eval_epoch and epoch % self.eval_epoch == 0:
self._eval(context)
self._save_checkpoint(context)
try:
for strategy in self.strategies:
strategy.on_epoch_begin(context)
self._train_one_epoch(context)
if self.eval_epoch and epoch % self.eval_epoch == 0:
self._eval(context)
self._save_checkpoint(context)
for strategy in self.strategies:
strategy.on_epoch_end(context)
except Exception:
_logger.error(traceback.print_exc())
continue
for strategy in self.strategies:
strategy.on_compression_end(context)
return context.eval_graph
......@@ -20,11 +20,15 @@ from ..prune import *
from ..quantization import *
from .strategy import *
from ..distillation import *
from ..searcher import *
from ..nas import *
__all__ = ['ConfigFactory']
"""This factory is used to create instances by loading and parsing configure file with yaml format.
"""
# Top-level sections of the yaml config whose entries are instantiated as
# plugin objects by ConfigFactory. 'distillers' was handled by the condition
# this list replaced, so it must stay in the list.
PLUGINS = ['pruners', 'quantizers', 'distillers', 'strategies', 'controllers']
class ConfigFactory(object):
def __init__(self, config):
......@@ -80,7 +84,7 @@ class ConfigFactory(object):
assert self.version == int(key_values['version'])
# parse pruners
if key == 'distillers' or key == 'pruners' or key == 'quantizers' or key == 'strategies':
if key in PLUGINS:
instances = key_values[key]
for name in instances:
self._new_instance(name, instances[name])
......@@ -91,8 +95,12 @@ class ConfigFactory(object):
if 'init_model' in key_values[key]:
self.compressor['init_model'] = key_values[key][
'init_model']
self.compressor['checkpoint_path'] = key_values[key][
'checkpoint_path']
if 'checkpoint_path' in key_values[key]:
self.compressor['checkpoint_path'] = key_values[key][
'checkpoint_path']
if 'eval_epoch' in key_values[key]:
self.compressor['eval_epoch'] = key_values[key][
'eval_epoch']
if 'strategies' in key_values[key]:
for name in key_values[key]['strategies']:
strategy = self.instance(name)
......
......@@ -41,6 +41,7 @@ class SlimGraphExecutor(object):
results(list): A list of result with the same order indicated by graph.out_nodes.
"""
assert isinstance(graph, GraphWrapper)
feed = None
if data is not None:
feeder = DataFeeder(
feed_list=graph.in_nodes.values(),
......
......@@ -209,6 +209,7 @@ class GraphWrapper(object):
if var.persistable:
self.persistables[var.name] = var
self.compiled_graph = None
in_nodes = [] if in_nodes is None else in_nodes
self.in_nodes = OrderedDict(in_nodes)
self.out_nodes = OrderedDict(out_nodes)
self._attrs = OrderedDict()
......@@ -241,7 +242,7 @@ class GraphWrapper(object):
"""
return var._var.persistable
def compile(self, for_parallel=True, for_test=False):
def compile(self, for_parallel=True, for_test=False, mem_opt=False):
"""
Compile the program in this wrapper to framework.CompiledProgram for next running.
This function must be called if the program is modified.
......@@ -257,8 +258,9 @@ class GraphWrapper(object):
if for_parallel:
# disable memory optimize for stable training
build_strategy = compiler.BuildStrategy()
build_strategy.enable_inplace = False
build_strategy.memory_optimize = False
build_strategy.enable_inplace = mem_opt
build_strategy.memory_optimize = mem_opt
# build_strategy.async_mode = False
self.compiled_graph = compiler.CompiledProgram(
target).with_data_parallel(
loss_name=loss, build_strategy=build_strategy)
......
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from . import light_nas_strategy
from .light_nas_strategy import *
from . import controller_server
from .controller_server import *
from . import search_agent
from .search_agent import *
from . import search_space
from .search_space import *
from . import lock
from .lock import *
__all__ = light_nas_strategy.__all__
__all__ += controller_server.__all__
__all__ += search_agent.__all__
__all__ += search_space.__all__
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import socket
from threading import Thread
__all__ = ['ControllerServer']
logging.basicConfig(
    format='ControllerServer-%(asctime)s-%(levelname)s: %(message)s')
_logger = logging.getLogger(__name__)
_logger.setLevel(logging.INFO)


class ControllerServer(object):
    """
    The controller wrapper with a socket server to handle the requests of
    search agents.

    Each accepted connection carries exactly one request and is closed after
    it has been handled:

    * ``next_tokens``: reply with the controller's next tokens.
    * ``<key>\\t<tokens>\\t<reward>``: update the controller with the
      comma-separated integer tokens and the float reward, then reply with
      the controller's next tokens.

    Anything else is logged as noise and gets no reply.
    """

    def __init__(self,
                 controller=None,
                 address=('', 0),
                 max_client_num=100,
                 search_steps=None,
                 key=None):
        """
        Args:
            controller(slim.searcher.Controller): The controller used to generate tokens.
            address(tuple): The address of current server binding with format (ip, port). Default: ('', 0).
                            which means setting ip automatically
            max_client_num(int): The maximum number of clients connecting to current server simultaneously. Default: 100.
            search_steps(int): The total steps of searching. None means never stopping. Default: None
            key(str): The first field an update message must carry to be accepted. Default: None
        """
        self._controller = controller
        self._address = address
        self._max_client_num = max_client_num
        self._search_steps = search_steps
        self._closed = False
        self._port = address[1]
        self._ip = address[0]
        self._key = key

    def start(self):
        """Bind the listening socket and serve requests in a new thread.

        Returns:
            str: The string form of the started thread object.
        """
        self._socket_server = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        self._socket_server.bind(self._address)
        self._socket_server.listen(self._max_client_num)
        # Read the address back: the port is only known after bind() when
        # the caller asked for port 0.
        self._port = self._socket_server.getsockname()[1]
        self._ip = self._socket_server.getsockname()[0]
        _logger.info("listen on: [{}:{}]".format(self._ip, self._port))
        thread = Thread(target=self.run)
        thread.start()
        return str(thread)

    def close(self):
        """Close the server.

        The flag is only checked between two requests, so a server that is
        waiting in accept() stops after it has handled one more connection.
        """
        self._closed = True

    def port(self):
        """Get the port."""
        return self._port

    def ip(self):
        """Get the ip."""
        return self._ip

    def _keep_serving(self):
        """Return True while neither close() nor the step limit stops the loop."""
        if self._closed:
            return False
        return (self._search_steps is None or
                self._controller._iter < self._search_steps)

    def _send_next_tokens(self, conn):
        """Send the controller's next tokens over conn and return the sent text."""
        tokens = ",".join(
            [str(token) for token in self._controller.next_tokens()])
        # sendall: a plain send() may transmit only part of the reply.
        conn.sendall(tokens.encode())
        return tokens

    def _handle(self, conn, addr):
        """Serve the single request carried by one accepted connection."""
        raw = conn.recv(1024)
        try:
            message = raw.decode()
        except UnicodeDecodeError:
            # Undecodable bytes cannot be a valid request.
            _logger.info("recv noise from {}: [{}]".format(addr, raw))
            return
        if message.strip("\n") == "next_tokens":
            self._send_next_tokens(conn)
            return
        _logger.info("recv message from {}: [{}]".format(addr, message))
        messages = message.strip('\n').split("\t")
        if (len(messages) < 3) or (messages[0] != self._key):
            _logger.info("recv noise from {}: [{}]".format(addr, message))
            return
        try:
            tokens = [int(token) for token in messages[1].split(",")]
            reward = float(messages[2])
        except ValueError:
            # A well-keyed but malformed payload must not kill the server.
            _logger.info("recv noise from {}: [{}]".format(addr, message))
            return
        self._controller.update(tokens, reward)
        tokens = self._send_next_tokens(conn)
        _logger.info("send message to {}: [{}]".format(addr, tokens))

    def run(self):
        """Accept and serve requests until closed or search_steps is reached."""
        _logger.info("Controller Server run...")
        try:
            while self._keep_serving():
                conn, addr = self._socket_server.accept()
                try:
                    self._handle(conn, addr)
                finally:
                    # Close on every path, including rejected (noise)
                    # requests, so no client is left waiting on an open
                    # connection.
                    conn.close()
        finally:
            self._socket_server.close()
            _logger.info("server closed!")
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from ..core.strategy import Strategy
from ..graph import GraphWrapper
from .controller_server import ControllerServer
from .search_agent import SearchAgent
from ....executor import Executor
import re
import logging
import functools
import socket
from .lock import lock, unlock
__all__ = ['LightNASStrategy']
logging.basicConfig(
format='LightNASStrategy-%(asctime)s-%(levelname)s: %(message)s')
_logger = logging.getLogger(__name__)
_logger.setLevel(logging.INFO)
class LightNASStrategy(Strategy):
    """
    Light-NAS search strategy.

    At the start of compression the controller is reset with the search
    space's range table and initial tokens. A host configured as server
    starts a ControllerServer, and every host creates a SearchAgent to talk
    to it. During training the strategy rebuilds the networks from the
    current tokens at the beginning of a search step and reports the reward
    of those tokens to the controller server at the end of the step.
    """

    def __init__(self,
                 controller=None,
                 end_epoch=1000,
                 target_flops=629145600,
                 retrain_epoch=1,
                 metric_name='top1_acc',
                 server_ip=None,
                 server_port=0,
                 is_server=False,
                 max_client_num=100,
                 search_steps=None,
                 key="light-nas"):
        """
        Args:
            controller(searcher.Controller): The searching controller. Default: None.
            end_epoch(int): The 'on_epoch_end' function will be called in end_epoch. Default: 1000
            target_flops(int): The constraint of FLOPS. Default: 629145600
            retrain_epoch(int): The number of training epochs before evaluating structure generated by controller. Default: 1.
            metric_name(str): The metric used to evaluate the model.
                         It should be one of keys in out_nodes of graph wrapper. Default: 'top1_acc'
            server_ip(str): The ip that controller server listens on. None means getting the ip automatically. Default: None.
            server_port(int): The port that controller server listens on. 0 means getting usable port automatically. Default: 0.
            is_server(bool): Whether current host is controller server. Default: False.
            max_client_num(int): The maximum number of clients that connect to controller server concurrently. Default: 100.
            search_steps(int): The total steps of searching. Default: None.
            key(str): The key used to identify legal agent for controller server. Default: "light-nas"
        """
        self.start_epoch = 0
        self.end_epoch = end_epoch
        self._max_flops = target_flops
        self._metric_name = metric_name
        self._controller = controller
        self._retrain_epoch = retrain_epoch
        self._server_ip = server_ip
        self._server_port = server_port
        self._is_server = is_server
        self._search_steps = search_steps
        self._max_client_num = max_client_num
        # Upper bound on the candidates tried per search step while looking
        # for tokens that satisfy the FLOPS constraint.
        self._max_try_times = 100
        self._key = key
        if self._server_ip is None:
            self._server_ip = self._get_host_ip()

    def _get_host_ip(self):
        """Return the local address of a UDP socket connected to 8.8.8.8:80."""
        # Create the socket outside the try block so that a failing
        # constructor is not masked by a NameError in the cleanup.
        s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
        try:
            s.connect(('8.8.8.8', 80))
            ip = s.getsockname()[0]
        finally:
            s.close()
        return ip

    def on_compression_begin(self, context):
        """Reset the controller, start the server if needed and create the agent."""
        self._current_tokens = context.search_space.init_tokens()
        # No constraint function is handed to the controller; the FLOPS
        # constraint is enforced in on_epoch_begin instead.
        self._controller.reset(context.search_space.range_table(),
                               self._current_tokens, None)

        # create controller server
        if self._is_server:
            socket_path = "./slim_LightNASStrategy_controller_server.socket"
            # Make sure the file exists before opening it for read/write.
            open(socket_path, 'a').close()
            socket_file = open(socket_path, 'r+')
            try:
                lock(socket_file)
                try:
                    # An empty file means no server has been recorded yet.
                    # NOTE(review): a file left over from an earlier run is
                    # not empty, so no server would be started and
                    # self._server_port keeps its configured value --
                    # confirm the file is cleaned up between runs.
                    tid = socket_file.readline()
                    if tid == '':
                        _logger.info("start controller server...")
                        self._server = ControllerServer(
                            controller=self._controller,
                            address=(self._server_ip, self._server_port),
                            max_client_num=self._max_client_num,
                            search_steps=self._search_steps,
                            key=self._key)
                        tid = self._server.start()
                        self._server_port = self._server.port()
                        socket_file.write(tid)
                        _logger.info("started controller server...")
                finally:
                    # Release the lock even if starting the server failed.
                    unlock(socket_file)
            finally:
                socket_file.close()

        _logger.info("self._server_ip: {}; self._server_port: {}".format(
            self._server_ip, self._server_port))
        # create client
        self._search_agent = SearchAgent(
            self._server_ip, self._server_port, key=self._key)

    def _constrain_func(self, tokens, context=None):
        """Check whether the tokens meet constraint."""
        _, _, test_prog, _, _, _, _ = context.search_space.create_net(tokens)
        flops = GraphWrapper(test_prog).flops()
        return flops <= self._max_flops

    def on_epoch_begin(self, context):
        """Build the networks for the current tokens at the start of a search step.

        Candidates whose evaluation program exceeds the FLOPS constraint are
        replaced by the next tokens from the search agent, at most
        self._max_try_times times.
        """
        if context.epoch_id >= self.start_epoch and context.epoch_id <= self.end_epoch and (
                self._retrain_epoch == 0 or
            (context.epoch_id - self.start_epoch) % self._retrain_epoch == 0):
            _logger.info("light nas strategy on_epoch_begin")
            # NOTE(review): if every try exceeds the constraint, the networks
            # built in the last try are still installed below while
            # self._current_tokens already holds the following candidate.
            for _ in range(self._max_try_times):
                startup_p, train_p, test_p, _, _, train_reader, test_reader = context.search_space.create_net(
                    self._current_tokens)
                _logger.info("try [{}]".format(self._current_tokens))
                context.eval_graph.program = test_p
                flops = context.eval_graph.flops()
                if flops <= self._max_flops:
                    break
                else:
                    self._current_tokens = self._search_agent.next_tokens()

            context.train_reader = train_reader
            context.eval_reader = test_reader

            # Run the startup program of the newly built networks.
            exe = Executor(context.place)
            exe.run(startup_p)

            context.optimize_graph.program = train_p
            context.optimize_graph.compile()

            context.skip_training = (self._retrain_epoch == 0)

    def on_epoch_end(self, context):
        """Report the reward of the current tokens and fetch the next tokens."""
        if context.epoch_id >= self.start_epoch and context.epoch_id < self.end_epoch and (
                self._retrain_epoch == 0 or
            (context.epoch_id - self.start_epoch + 1
             ) % self._retrain_epoch == 0):

            self._current_reward = context.eval_results[self._metric_name][-1]
            flops = context.eval_graph.flops()
            # A structure that violates the constraint earns no reward.
            if flops > self._max_flops:
                self._current_reward = 0.0
            _logger.info("reward: {}; flops: {}; tokens: {}".format(
                self._current_reward, flops, self._current_tokens))
            self._current_tokens = self._search_agent.update(
                self._current_tokens, self._current_reward)
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
# Spelled in lower case so that `from .lock import *` exports only these two
# names instead of every public name of this module.
__all__ = ['lock', 'unlock']

if os.name == 'nt':

    def lock(file):
        """File locking is not implemented on Windows."""
        raise NotImplementedError('Windows is not supported.')

    def unlock(file):
        """File unlocking is not implemented on Windows."""
        raise NotImplementedError('Windows is not supported.')

elif os.name == 'posix':
    from fcntl import flock, LOCK_EX, LOCK_UN

    def lock(file):
        """Lock the file in local file system."""
        flock(file.fileno(), LOCK_EX)

    def unlock(file):
        """Unlock the file in local file system."""
        flock(file.fileno(), LOCK_UN)

else:
    raise RuntimeError("File Locker only support NT and Posix platforms!")
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import socket
__all__ = ['SearchAgent']
logging.basicConfig(format='%(asctime)s-%(levelname)s: %(message)s')
_logger = logging.getLogger(__name__)
_logger.setLevel(logging.INFO)
class SearchAgent(object):
    """
    Search agent.

    Client side of the controller server protocol: every request opens a new
    TCP connection, sends one message and parses the reply into tokens.
    """

    def __init__(self, server_ip=None, server_port=None, key=None):
        """
        Args:
            server_ip(str): The ip that controller server listens on. Default: None.
            server_port(int): The port that controller server listens on. Default: None.
            key(str): The key sent as first field of update messages to identify this agent to the controller server. Default: None
        """
        self.server_ip = server_ip
        self.server_port = server_port
        # NOTE(review): not used by any method of this class; requests open
        # their own sockets. Kept because it is a public attribute.
        self.socket_client = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        self._key = key

    def _request(self, message):
        """Send message to the server and return the reply parsed as tokens.

        The reply is read until the server closes the connection, so replies
        longer than one 1024-byte chunk are not truncated.
        """
        socket_client = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        try:
            socket_client.connect((self.server_ip, self.server_port))
            # sendall: a plain send() may transmit only part of the message.
            socket_client.sendall(message.encode())
            chunks = []
            while True:
                chunk = socket_client.recv(1024)
                if not chunk:
                    break
                chunks.append(chunk)
        finally:
            # Always release the connection, also when the exchange fails.
            socket_client.close()
        reply = b"".join(chunks).decode()
        return [int(token) for token in reply.strip("\n").split(",")]

    def update(self, tokens, reward):
        """
        Update the controller according to latest tokens and reward.
        Args:
            tokens(list<int>): The tokens generated in last step.
            reward(float): The reward of tokens.
        Returns:
            list<int>: The tokens replied by the server.
        """
        tokens = ",".join([str(token) for token in tokens])
        return self._request("{}\t{}\t{}".format(self._key, tokens, reward))

    def next_tokens(self):
        """
        Get next tokens.
        Returns:
            list<int>: The tokens replied by the server.
        """
        return self._request("next_tokens")
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""The search space used to search neural architecture"""
__all__ = ['SearchSpace']
class SearchSpace(object):
    """Abstract search space for Neural Architecture Search.

    Subclasses must override every method below; the base implementations
    only raise NotImplementedError.
    """

    def __init__(self, *args, **kwargs):
        """Accept and ignore any positional and keyword arguments."""

    def init_tokens(self):
        """Get init tokens in search space.
        """
        raise NotImplementedError('Abstract method.')

    def range_table(self):
        """Get range table of current search space.
        """
        raise NotImplementedError('Abstract method.')

    def create_net(self, tokens):
        """Create networks for training and evaluation according to tokens.
        Args:
            tokens(list<int>): The tokens which represent a network.
        Return:
            (tuple): startup_program, train_program, evaluation_program, train_metrics, test_metrics
            NOTE(review): LightNASStrategy unpacks seven values from this call
            and binds the last two to train_reader and test_reader -- confirm
            the intended tuple layout.
        """
        raise NotImplementedError('Abstract method.')
......@@ -16,6 +16,9 @@ from . import pruner
from .pruner import *
from . import prune_strategy
from .prune_strategy import *
from . import auto_prune_strategy
from .auto_prune_strategy import *
__all__ = pruner.__all__
__all__ += prune_strategy.__all__
__all__ += auto_prune_strategy.__all__
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .prune_strategy import PruneStrategy
import re
import logging
import functools
import copy
__all__ = ['AutoPruneStrategy']
logging.basicConfig(format='%(asctime)s-%(levelname)s: %(message)s')
_logger = logging.getLogger(__name__)
_logger.setLevel(logging.INFO)
class AutoPruneStrategy(PruneStrategy):
"""
Automatic pruning strategy.
"""
def __init__(self,
             pruner=None,
             controller=None,
             start_epoch=0,
             end_epoch=10,
             min_ratio=0.5,
             max_ratio=0.7,
             metric_name='top1_acc',
             pruned_params='conv.*_weights',
             retrain_epoch=0):
    """
    Args:
        pruner(slim.Pruner): The pruner used to prune the parameters. Default: None.
        controller(searcher.Controller): The searching controller. Default: None.
        start_epoch(int): The 'on_epoch_begin' function will be called in start_epoch. Default: 0
        end_epoch(int): The 'on_epoch_end' function will be called in end_epoch. Default: 10
        min_ratio(float): The minimum pruned ratio. Default: 0.5
        max_ratio(float): The maximum pruned ratio. Default: 0.7
        metric_name(str): The metric used to evaluate the model.
                     It should be one of keys in out_nodes of graph wrapper. Default: 'top1_acc'
        pruned_params(str): The pattern str to match the parameter names to be pruned. Default: 'conv.*_weights'
        retrain_epoch(int): The training epochs in each searching step. Default: 0
    """
    super(AutoPruneStrategy, self).__init__(pruner, start_epoch, end_epoch,
                                            0.0, metric_name, pruned_params)
    self._max_ratio = max_ratio
    self._min_ratio = min_ratio
    self._controller = controller
    self._metric_name = metric_name
    self._pruned_param_names = []
    # Honour the argument: it used to be discarded in favour of a
    # hard-coded 0, which made the option a no-op.
    self._retrain_epoch = retrain_epoch

    self._current_tokens = None
def on_compression_begin(self, context):
    """
    Prepare the searching strategy before compression starts.
    step 1: Collect the names of all parameters of the evaluation graph
            that match the `pruned_params` pattern.
    step 2: Compute the initial tokens, use a deep copy of them as range
            table and reset the controller with both and the constraint
            function bound to this context.
    """
    self._pruned_param_names.extend(
        param.name() for param in context.eval_graph.all_parameters()
        if re.match(self.pruned_params, param.name()))

    init_tokens = self._get_init_tokens(context)
    self._current_tokens = init_tokens
    self._range_table = copy.deepcopy(init_tokens)

    self._controller.reset(
        self._range_table, self._current_tokens,
        functools.partial(
            self._constrain_func, context=context))
def _constrain_func(self, tokens, context=None):
    """Check whether the tokens meet constraint.

    The evaluation graph is pruned on graph level only, its FLOPs are
    measured and the backed-up parameter shapes are restored. The tokens
    are accepted when the share of removed FLOPs lies within
    [min_ratio, max_ratio].
    """
    graph = context.eval_graph
    ori_flops = graph.flops()
    ratios = self._tokens_to_ratios(tokens)
    shape_backup = {}
    self._prune_parameters(
        graph,
        context.scope,
        self._pruned_param_names,
        ratios,
        context.place,
        only_graph=True,
        param_shape_backup=shape_backup)
    graph.update_groups_of_conv()
    pruned_flops = graph.flops()
    # Undo the shape changes so the graph is left as it was found.
    for name, shape in shape_backup.items():
        graph.var(name).set_shape(shape)
    flops_ratio = (1 - float(pruned_flops) / ori_flops)
    return self._min_ratio <= flops_ratio <= self._max_ratio
def _get_init_tokens(self, context):
    """Return the tokens the search starts from.

    They are the uniform ratios found by ``_get_uniform_ratios`` converted
    with ``_ratios_to_tokens``.
    """
    return self._ratios_to_tokens(self._get_uniform_ratios(context))
def _ratios_to_tokens(self, ratios):
"""Convert pruned ratios to tokens.
"""
return [int(ratio / 0.01) for ratio in ratios]
def _tokens_to_ratios(self, tokens):
"""Convert tokens to pruned ratios.
"""
return [token * 0.01 for token in tokens]
def _get_uniform_ratios(self, context):
    """
    Search a group of uniform ratios.

    Bisects one ratio in [0, 1], applied to every name in
    ``self._pruned_param_names``, until the fraction of FLOPs removed from
    ``context.eval_graph`` is within 0.01 of the midpoint of
    ``[self._min_ratio, self._max_ratio]``. Every trial prunes shapes only
    (``only_graph=True``) and restores them before the next trial.

    Args:
        context: The compression context providing ``eval_graph``,
            ``scope`` and ``place``.

    Returns:
        list<float>: The same ratio repeated once per parameter to prune.
        If the tolerance is never reached, the last ratio tried is returned
        and a warning is logged.
    """
    # Halving [0, 1] in double precision stops making progress after a few
    # dozen steps. The previous `while min_ratio < max_ratio` loop never
    # terminated in that situation, so the number of trials is bounded.
    max_search_steps = 64

    low = 0.
    high = 1.
    target = (self._min_ratio + self._max_ratio) / 2
    flops = context.eval_graph.flops()
    ratios = None
    converged = False
    for _ in range(max_search_steps):
        ratio = (high + low) / 2
        ratios = [ratio] * len(self._pruned_param_names)
        param_shape_backup = {}
        self._prune_parameters(
            context.eval_graph,
            context.scope,
            self._pruned_param_names,
            ratios,
            context.place,
            only_graph=True,
            param_shape_backup=param_shape_backup)

        pruned_flops = 1 - (float(context.eval_graph.flops()) / flops)

        # Put the original shapes back before trying the next ratio.
        for param, shape in param_shape_backup.items():
            context.eval_graph.var(param).set_shape(shape)

        if abs(pruned_flops - target) < 1e-2:
            converged = True
            break
        if pruned_flops > target:
            high = ratio
        else:
            low = ratio

    if not converged:
        _logger.warning(
            'Uniform ratio search stopped after {} steps without reaching '
            'the target pruned flops {}.'.format(max_search_steps, target))
    _logger.info('Get ratios: {}'.format([round(r, 2) for r in ratios]))
    return ratios
def on_epoch_begin(self, context):
    """
    step 1: Get a new tokens from controller.
    step 2: Pruning eval_graph and optimize_program by tokens

    Nothing happens outside ``[start_epoch, end_epoch]`` or, when
    ``_retrain_epoch`` is non-zero, on epochs that are not a multiple of it
    counted from ``start_epoch``.
    """
    epoch = context.epoch_id
    if epoch < self.start_epoch or epoch > self.end_epoch:
        return
    if (self._retrain_epoch != 0 and
            (epoch - self.start_epoch) % self._retrain_epoch != 0):
        return

    self._current_tokens = self._controller.next_tokens()

    # Fresh backups for this step; on_epoch_end uses them to undo the pruning.
    self._param_shape_backup = {}
    self._param_backup = {}
    self._prune_parameters(
        context.optimize_graph,
        context.scope,
        self._pruned_param_names,
        self._tokens_to_ratios(self._current_tokens),
        context.place,
        param_backup=self._param_backup,
        param_shape_backup=self._param_shape_backup)
    self._prune_graph(context.eval_graph, context.optimize_graph)
    context.optimize_graph.update_groups_of_conv()
    context.eval_graph.update_groups_of_conv()
    # to update the compiled program
    context.optimize_graph.compile(mem_opt=True)
    # Without retraining epochs the candidate is only evaluated.
    context.skip_training = (self._retrain_epoch == 0)
def on_epoch_end(self, context):
    """
    step 1: Get reward of current tokens and update controller.
    step 2: Restore eval_graph and optimize_graph

    On the epoch equal to ``end_epoch`` the backed-up parameters and shapes
    are restored, the graphs are pruned once more with the controller's
    best tokens, and ``context.skip_training`` is set to False.
    """
    epoch_id = context.epoch_id
    is_search_epoch = (
        self.start_epoch <= epoch_id < self.end_epoch and
        (self._retrain_epoch == 0 or
         (epoch_id - self.start_epoch) % self._retrain_epoch == 0))

    if is_search_epoch:
        reward = context.eval_results[self._metric_name][-1]
        self._controller.update(self._current_tokens, reward)

        # restore pruned parameters
        for param_name, value in self._param_backup.items():
            param_t = context.scope.find_var(param_name).get_tensor()
            param_t.set(value, context.place)
        self._param_backup = {}
        # restore shape of parameters
        for param_name, shape in self._param_shape_backup.items():
            context.optimize_graph.var(param_name).set_shape(shape)
        self._param_shape_backup = {}
        self._prune_graph(context.eval_graph, context.optimize_graph)
        context.optimize_graph.update_groups_of_conv()
        context.eval_graph.update_groups_of_conv()
        context.optimize_graph.compile(
            mem_opt=True)  # to update the compiled program

    elif epoch_id == self.end_epoch:  # restore graph for final training
        # restore pruned parameters
        # (this previously read `self.param_backup`, an attribute that is
        # never assigned, and raised AttributeError)
        for param_name, value in self._param_backup.items():
            param_t = context.scope.find_var(param_name).get_tensor()
            param_t.set(value, context.place)
        # restore shape of parameters
        for param_name, shape in self._param_shape_backup.items():
            context.eval_graph.var(param_name).set_shape(shape)
            context.optimize_graph.var(param_name).set_shape(shape)

        context.optimize_graph.update_groups_of_conv()
        context.eval_graph.update_groups_of_conv()

        # NOTE(review): `_get_prune_ratios` is not defined in the visible
        # part of this class -- confirm it exists and accepts tokens.
        params, ratios = self._get_prune_ratios(
            self._controller._best_tokens)
        self._prune_parameters(context.optimize_graph, context.scope,
                               params, ratios, context.place)

        self._prune_graph(context.eval_graph, context.optimize_graph)
        context.optimize_graph.update_groups_of_conv()
        context.eval_graph.update_groups_of_conv()
        context.optimize_graph.compile(
            mem_opt=True)  # to update the compiled program
        context.skip_training = False
......@@ -26,7 +26,7 @@ import pickle
import logging
import sys
__all__ = ['SensitivePruneStrategy', 'UniformPruneStrategy']
__all__ = ['SensitivePruneStrategy', 'UniformPruneStrategy', 'PruneStrategy']
logging.basicConfig(format='%(asctime)s-%(levelname)s: %(message)s')
_logger = logging.getLogger(__name__)
......@@ -61,8 +61,6 @@ class PruneStrategy(Strategy):
self.metric_name = metric_name
self.pruned_params = pruned_params
self.pruned_list = []
self.backup = {}
self.param_shape_backup = {}
def _eval_graph(self, context, sampled_rate=None, cached_id=0):
"""
......@@ -82,7 +80,9 @@ class PruneStrategy(Strategy):
ratio,
place,
lazy=False,
only_graph=False):
only_graph=False,
param_shape_backup=None,
param_backup=None):
"""
Pruning filters by given ratio.
Args:
......@@ -103,16 +103,16 @@ class PruneStrategy(Strategy):
for param in params:
assert isinstance(param, VarWrapper)
param_t = scope.find_var(param.name()).get_tensor()
if lazy:
self.backup[param.name()] = copy.deepcopy(np.array(param_t))
if param_backup is not None and (param.name() not in param_backup):
param_backup[param.name()] = copy.deepcopy(np.array(param_t))
pruned_param = self.pruner.prune_tensor(
np.array(param_t), pruned_idx, pruned_axis=0, lazy=lazy)
if not only_graph:
param_t.set(pruned_param, place)
ori_shape = param.shape()
if param.name() not in self.param_shape_backup:
self.param_shape_backup[param.name()] = copy.deepcopy(
param.shape())
if param_shape_backup is not None and (
param.name() not in param_shape_backup):
param_shape_backup[param.name()] = copy.deepcopy(param.shape())
new_shape = list(param.shape())
new_shape[0] = pruned_param.shape[0]
param.set_shape(new_shape)
......@@ -120,7 +120,8 @@ class PruneStrategy(Strategy):
'|----------------------------------------+----+------------------------------+------------------------------|'
)
_logger.debug('|{:^40}|{:^4}|{:^30}|{:^30}|'.format(
str(param.name()), str(0), str(ori_shape), str(param.shape())))
str(param.name()),
str(ratio), str(ori_shape), str(param.shape())))
self.pruned_list[0].append(param.name())
return pruned_idx
......@@ -131,7 +132,9 @@ class PruneStrategy(Strategy):
pruned_axis,
place,
lazy=False,
only_graph=False):
only_graph=False,
param_shape_backup=None,
param_backup=None):
"""
Pruning parameters in given axis.
Args:
......@@ -150,16 +153,17 @@ class PruneStrategy(Strategy):
for param in params:
assert isinstance(param, VarWrapper)
param_t = scope.find_var(param.name()).get_tensor()
if lazy:
self.backup[param.name()] = copy.deepcopy(np.array(param_t))
if param_backup is not None and (param.name() not in param_backup):
param_backup[param.name()] = copy.deepcopy(np.array(param_t))
pruned_param = self.pruner.prune_tensor(
np.array(param_t), pruned_idx, pruned_axis, lazy=lazy)
if not only_graph:
param_t.set(pruned_param, place)
ori_shape = param.shape()
if param.name() not in self.param_shape_backup:
self.param_shape_backup[param.name()] = copy.deepcopy(
param.shape())
if param_shape_backup is not None and (
param.name() not in param_shape_backup):
param_shape_backup[param.name()] = copy.deepcopy(param.shape())
new_shape = list(param.shape())
new_shape[pruned_axis] = pruned_param.shape[pruned_axis]
param.set_shape(new_shape)
......@@ -251,7 +255,9 @@ class PruneStrategy(Strategy):
ratio=None,
pruned_idxs=None,
lazy=False,
only_graph=False):
only_graph=False,
param_backup=None,
param_shape_backup=None):
"""
Pruning all the parameters affected by the pruning of given parameter.
Args:
......@@ -284,7 +290,9 @@ class PruneStrategy(Strategy):
pruned_axis=0,
place=place,
lazy=lazy,
only_graph=only_graph)
only_graph=only_graph,
param_backup=param_backup,
param_shape_backup=param_shape_backup)
else:
pruned_idxs = self._prune_filters_by_ratio(
......@@ -292,7 +300,9 @@ class PruneStrategy(Strategy):
ratio,
place,
lazy=lazy,
only_graph=only_graph)
only_graph=only_graph,
param_backup=param_backup,
param_shape_backup=param_shape_backup)
corrected_idxs = pruned_idxs[:]
for idx, op in enumerate(related_ops):
......@@ -307,7 +317,9 @@ class PruneStrategy(Strategy):
pruned_axis=1,
place=place,
lazy=lazy,
only_graph=only_graph)
only_graph=only_graph,
param_backup=param_backup,
param_shape_backup=param_shape_backup)
if op.type() == "depthwise_conv2d":
for in_var in op.all_inputs():
if graph.is_parameter(in_var):
......@@ -319,7 +331,9 @@ class PruneStrategy(Strategy):
pruned_axis=0,
place=place,
lazy=lazy,
only_graph=only_graph)
only_graph=only_graph,
param_backup=param_backup,
param_shape_backup=param_shape_backup)
elif op.type() == "elementwise_add":
# pruning bias
for in_var in op.all_inputs():
......@@ -332,7 +346,9 @@ class PruneStrategy(Strategy):
pruned_axis=0,
place=place,
lazy=lazy,
only_graph=only_graph)
only_graph=only_graph,
param_backup=param_backup,
param_shape_backup=param_shape_backup)
elif op.type() == "mul": # pruning fc layer
fc_input = None
fc_param = None
......@@ -354,7 +370,9 @@ class PruneStrategy(Strategy):
pruned_axis=0,
place=place,
lazy=lazy,
only_graph=only_graph)
only_graph=only_graph,
param_backup=param_backup,
param_shape_backup=param_shape_backup)
elif op.type() == "concat":
concat_inputs = op.all_inputs()
......@@ -378,28 +396,36 @@ class PruneStrategy(Strategy):
pruned_axis=0,
place=place,
lazy=lazy,
only_graph=only_graph)
only_graph=only_graph,
param_backup=param_backup,
param_shape_backup=param_shape_backup)
self._prune_parameter_by_idx(
scope, [variance] + self._get_accumulator(graph, variance),
corrected_idxs,
pruned_axis=0,
place=place,
lazy=lazy,
only_graph=only_graph)
only_graph=only_graph,
param_backup=param_backup,
param_shape_backup=param_shape_backup)
self._prune_parameter_by_idx(
scope, [alpha] + self._get_accumulator(graph, alpha),
corrected_idxs,
pruned_axis=0,
place=place,
lazy=lazy,
only_graph=only_graph)
only_graph=only_graph,
param_backup=param_backup,
param_shape_backup=param_shape_backup)
self._prune_parameter_by_idx(
scope, [beta] + self._get_accumulator(graph, beta),
corrected_idxs,
pruned_axis=0,
place=place,
lazy=lazy,
only_graph=only_graph)
only_graph=only_graph,
param_backup=param_backup,
param_shape_backup=param_shape_backup)
def _prune_parameters(self,
graph,
......@@ -408,7 +434,9 @@ class PruneStrategy(Strategy):
ratios,
place,
lazy=False,
only_graph=False):
only_graph=False,
param_backup=None,
param_shape_backup=None):
"""
Pruning the given parameters.
Args:
......@@ -444,7 +472,9 @@ class PruneStrategy(Strategy):
place,
ratio=ratio,
lazy=lazy,
only_graph=only_graph)
only_graph=only_graph,
param_backup=param_backup,
param_shape_backup=param_shape_backup)
ops = param.outputs()
for op in ops:
if op.type() == 'conv2d':
......@@ -458,7 +488,9 @@ class PruneStrategy(Strategy):
place,
ratio=ratio,
lazy=lazy,
only_graph=only_graph)
only_graph=only_graph,
param_backup=param_backup,
param_shape_backup=param_shape_backup)
_logger.debug(
'|----------------------------------------+----+------------------------------+------------------------------|'
)
......@@ -575,23 +607,24 @@ class UniformPruneStrategy(PruneStrategy):
_logger.debug(
'-----------Try pruning ratio: {:.2f}-----------'.format(ratio))
ratios = [ratio] * len(pruned_params)
param_shape_backup = {}
self._prune_parameters(
context.eval_graph,
context.scope,
pruned_params,
ratios,
context.place,
only_graph=True)
only_graph=True,
param_shape_backup=param_shape_backup)
pruned_flops = 1 - (float(context.eval_graph.flops()) / flops)
pruned_size = 1 - (float(context.eval_graph.numel_params()) /
model_size)
_logger.debug('Pruned flops: {:.2f}'.format(pruned_flops))
_logger.debug('Pruned model size: {:.2f}'.format(pruned_size))
for param in self.param_shape_backup.keys():
context.eval_graph.var(param).set_shape(self.param_shape_backup[
for param in param_shape_backup.keys():
context.eval_graph.var(param).set_shape(param_shape_backup[
param])
self.param_shape_backup = {}
if abs(pruned_flops - self.target_ratio) < 1e-2:
break
......@@ -672,8 +705,6 @@ class SensitivePruneStrategy(PruneStrategy):
self.pruned_list = []
self.sensitivities = sensitivities
self.sensitivities_file = sensitivities_file
self.backup = {}
self.param_shape_backup = {}
self.num_steps = num_steps
self.eval_rate = eval_rate
self.pruning_step = 1 - pow((1 - target_ratio), 1.0 / self.num_steps)
......@@ -728,8 +759,6 @@ class SensitivePruneStrategy(PruneStrategy):
Computing the sensitivities of all parameters.
"""
_logger.info("calling _compute_sensitivities.")
self.param_shape_backup = {}
self.backup = {}
cached_id = np.random.randint(1000)
if self.start_epoch == context.epoch_id:
sensitivities_file = self.sensitivities_file
......@@ -761,12 +790,15 @@ class SensitivePruneStrategy(PruneStrategy):
if metric is None:
metric = self._eval_graph(context, self.eval_rate,
cached_id)
param_backup = {}
# prune parameter by ratio
self._prune_parameters(
context.eval_graph,
context.scope, [param], [ratio],
context.place,
lazy=True)
lazy=True,
param_backup=param_backup)
self.pruned_list[0]
# get accuracy after pruning and update self.sensitivities
pruned_metric = self._eval_graph(context, self.eval_rate,
......@@ -787,12 +819,11 @@ class SensitivePruneStrategy(PruneStrategy):
self._save_sensitivities(sensitivities, sensitivities_file)
# restore pruned parameters
for param_name in self.backup.keys():
for param_name in param_backup.keys():
param_t = context.scope.find_var(param_name).get_tensor()
param_t.set(self.backup[param_name], context.place)
param_t.set(self.param_backup[param_name], context.place)
# pruned_metric = self._eval_graph(context)
self.backup = {}
ratio += self.delta_rate
return sensitivities
......@@ -803,8 +834,6 @@ class SensitivePruneStrategy(PruneStrategy):
"""
_logger.info('_get_best_ratios for pruning ratie: {}'.format(
target_ratio))
self.param_shape_backup = {}
self.backup = {}
def func(params, x):
a, b, c, d = params
......@@ -854,23 +883,24 @@ class SensitivePruneStrategy(PruneStrategy):
_logger.info('Pruned ratios={}'.format(
[round(ratio, 3) for ratio in ratios]))
# step 2.2: Pruning by current ratios
param_shape_backup = {}
self._prune_parameters(
context.eval_graph,
context.scope,
sensitivities.keys(),
ratios,
context.place,
only_graph=True)
only_graph=True,
param_shape_backup=param_shape_backup)
pruned_flops = 1 - (float(context.eval_graph.flops()) / flops)
pruned_size = 1 - (float(context.eval_graph.numel_params()) /
model_size)
_logger.info('Pruned flops: {:.4f}'.format(pruned_flops))
_logger.info('Pruned model size: {:.4f}'.format(pruned_size))
for param in self.param_shape_backup.keys():
context.eval_graph.var(param).set_shape(self.param_shape_backup[
for param in param_shape_backup.keys():
context.eval_graph.var(param).set_shape(param_shape_backup[
param])
self.param_shape_backup = {}
# step 2.3: Check whether current ratios is enough
if abs(pruned_flops - target_ratio) < 0.015:
......@@ -902,9 +932,6 @@ class SensitivePruneStrategy(PruneStrategy):
self._prune_parameters(context.optimize_graph, context.scope,
params, ratios, context.place)
self.param_shape_backup = {}
self.backup = {}
model_size = context.eval_graph.numel_params()
flops = context.eval_graph.flops()
_logger.debug('################################')
......
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from . import controller
from .controller import *
__all__ = controller.__all__
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""The controller used to search hyperparameters or neural architecture"""
import numpy as np
import copy
import math
import logging
# Public names exported by `from <this module> import *`.
__all__ = ['EvolutionaryController', 'SAController']
# NOTE: basicConfig() runs at import time and configures the root logger;
# it does nothing if the root logger already has handlers.
logging.basicConfig(format='%(asctime)s-%(levelname)s: %(message)s')
# Module-level logger; INFO is enabled so the controller's search progress
# messages are emitted.
_logger = logging.getLogger(__name__)
_logger.setLevel(logging.INFO)
class EvolutionaryController(object):
    """Base class of the controllers used by evolutionary search.

    It only fixes the interface: ``update``, ``reset`` and ``next_tokens``
    all raise ``NotImplementedError`` and must be overridden.
    """

    def __init__(self, *args, **kwargs):
        """Accept and ignore any arguments; the base class keeps no state."""

    def update(self, tokens, reward):
        """Feed back the reward obtained by ``tokens``.

        Args:
            tokens(list<int>): A solution of searching task.
            reward(float): The reward of tokens.
        """
        raise NotImplementedError('Abstract method.')

    def reset(self, range_table, constrain_func=None):
        """Restart the controller on a new search space.

        Args:
            range_table(list<int>): Defines the searching space; the
                tokens[i] generated by controller should be in
                [0, range_table[i]).
            constrain_func(function): Checks whether tokens meet the
                constraint. None means there is no constraint.
                Default: None.
        """
        raise NotImplementedError('Abstract method.')

    def next_tokens(self):
        """Produce the next tokens to evaluate."""
        raise NotImplementedError('Abstract method.')
class SAController(EvolutionaryController):
    """Simulated annealing controller."""

    def __init__(self,
                 range_table=None,
                 reduce_rate=0.85,
                 init_temperature=1024,
                 max_iter_number=300):
        """Initialize.
        Args:
            range_table(list<int>): Range table.
            reduce_rate(float): The decay rate of temperature.
            init_temperature(float): Init temperature.
            max_iter_number(int): max iteration number.
        """
        super(SAController, self).__init__()
        self._range_table = range_table
        self._reduce_rate = reduce_rate
        self._init_temperature = init_temperature
        self._max_iter_number = max_iter_number
        self._reward = -1
        self._tokens = None
        self._max_reward = -1
        self._best_tokens = None
        self._iter = 0
        # Defined here so the attribute exists before reset() is called.
        self._constrain_func = None

    def __getstate__(self):
        """Return the instance state without ``_constrain_func``.

        The constraint callback is left out of the pickled state; call
        ``reset`` after unpickling to install one again.
        """
        return {
            key: value
            for key, value in self.__dict__.items()
            if key != "_constrain_func"
        }

    def reset(self, range_table, init_tokens, constrain_func=None):
        """
        Reset the status of current controller.
        Args:
            range_table(list<int>): The range of value in each position of tokens generated by current controller. The range of tokens[i] is [0, range_table[i]).
            init_tokens(list<int>): The initial tokens.
            constrain_func(function): The callback function used to check whether the tokens meet constraint. None means there is no constraint. Default: None.
        """
        self._range_table = range_table
        self._constrain_func = constrain_func
        self._tokens = init_tokens
        self._iter = 0

    def update(self, tokens, reward):
        """
        Update the controller according to latest tokens and reward.

        A better reward is always accepted. A worse or equal one is accepted
        with probability ``exp((reward - current_reward) / temperature)``,
        where the temperature is
        ``init_temperature * reduce_rate ** iteration``.

        Args:
            tokens(list<int>): The tokens generated in last step.
            reward(float): The reward of tokens.
        """
        self._iter += 1
        temperature = self._init_temperature * self._reduce_rate**self._iter
        if reward > self._reward:
            accepted = True
        elif temperature == 0:
            # The temperature underflows to 0.0 after enough iterations;
            # dividing by it would raise ZeroDivisionError. With no
            # temperature left, only improvements are accepted.
            accepted = False
        else:
            accepted = np.random.random() <= math.exp(
                (reward - self._reward) / temperature)
        if accepted:
            self._reward = reward
            self._tokens = tokens
        if reward > self._max_reward:
            self._max_reward = reward
            self._best_tokens = tokens
        _logger.info("iter: {}; max_reward: {}; best_tokens: {}".format(
            self._iter, self._max_reward, self._best_tokens))
        _logger.info("current_reward: {}; current tokens: {}".format(
            self._reward, self._tokens))

    def next_tokens(self):
        """
        Get next tokens.

        One position of the current tokens is changed to a different value
        inside its range. If a constraint callback is set, up to
        ``max_iter_number`` further candidates are drawn until one passes.

        Returns:
            list<int>: The new tokens. When every position has a range of at
            most 1 nothing can change and a copy of the current tokens is
            returned.
        """
        tokens = self._tokens
        # Positions with a range <= 1 have no alternative value, and
        # np.random.randint() raises ValueError for them, so they are not
        # offered for mutation. When every range is > 1 this picks exactly
        # the same index as drawing over the whole table.
        mutable = [
            pos for pos, size in enumerate(self._range_table) if size > 1
        ]
        if not mutable:
            _logger.info("no position can be changed; tokens are kept")
            return tokens[:]

        new_tokens = tokens[:]
        index = mutable[int(len(mutable) * np.random.random())]
        # Adding an offset in [1, range - 1] modulo the range guarantees a
        # value different from the current one.
        new_tokens[index] = (
            new_tokens[index] + np.random.randint(self._range_table[index] - 1)
            + 1) % self._range_table[index]
        _logger.info("change index[{}] from {} to {}".format(index, tokens[
            index], new_tokens[index]))
        if self._constrain_func is None:
            return new_tokens
        for _ in range(self._max_iter_number):
            if self._constrain_func(new_tokens):
                break
            index = mutable[int(len(mutable) * np.random.random())]
            new_tokens = tokens[:]
            new_tokens[index] = np.random.randint(self._range_table[index])
        else:
            # The last candidate drawn was never checked.
            _logger.warning(
                "no tokens passed the constraint check within {} tries; "
                "the returned tokens may violate the constraint".format(
                    self._max_iter_number))
        return new_tokens
......@@ -16,6 +16,10 @@ endfunction()
# Need to figure out the root cause and then add it back
list(REMOVE_ITEM TEST_OPS test_distillation_strategy)
if(WIN32)
list(REMOVE_ITEM TEST_OPS test_light_nas)
endif()
# int8 image classification python api test
if(LINUX AND WITH_MKLDNN)
set(INT8_DATA_DIR "${INFERENCE_DEMO_INSTALL_DIR}/int8v2")
......
version: 1.0
pruners:
pruner_1:
class: 'StructurePruner'
pruning_axis:
'*': 0
criterions:
'*': 'l1_norm'
controllers:
sa_controller:
class: 'SAController'
reduce_rate: 0.9
init_temperature: 1024
max_iter_number: 300
strategies:
auto_pruning_strategy:
class: 'AutoPruneStrategy'
pruner: 'pruner_1'
controller: 'sa_controller'
start_epoch: 0
end_epoch: 2
max_ratio: 0.7
min_ratio: 0.5
pruned_params: '.*_sep_weights'
metric_name: 'acc_top5'
compressor:
epoch: 2
checkpoint_path: './checkpoints_auto_pruning/'
strategies:
- auto_pruning_strategy
version: 1.0
controllers:
sa_controller:
class: 'SAController'
reduce_rate: 0.9
init_temperature: 1024
max_iter_number: 300
strategies:
light_nas_strategy:
class: 'LightNASStrategy'
controller: 'sa_controller'
target_flops: 629145600
end_epoch: 2
retrain_epoch: 1
metric_name: 'acc_top1'
is_server: 1
max_client_num: 100
search_steps: 2
compressor:
epoch: 2
strategies:
- light_nas_strategy
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from paddle.fluid.contrib.slim.nas import SearchSpace
from light_nasnet import LightNASNet
import paddle.fluid as fluid
import paddle
import json
# Training hyper-parameters. build_program() copies them into the model's
# params dict before calling optimizer_setting().
total_images = 1281167
lr = 0.1
num_epochs = 1
batch_size = 256
lr_strategy = "cosine_decay"
l2_decay = 4e-5
momentum_rate = 0.9
# Per-sample image shape passed to py_reader in build_program().
image_shape = [1, 28, 28]
__all__ = ['LightNASSpace']
# Candidate values selected by tokens in get_bottleneck_params_list().
# One row per searched stage: candidate values for that stage's output
# channel entry.
NAS_FILTER_SIZE = [[18, 24, 30], [24, 32, 40], [48, 64, 80], [72, 96, 120],
[120, 160, 192]]
# One row per searched stage: candidate values for the number of layers.
NAS_LAYERS_NUMBER = [[1, 2, 3], [2, 3, 4], [3, 4, 5], [2, 3, 4], [2, 3, 4]]
# Candidate kernel sizes.
NAS_KERNEL_SIZE = [3, 5]
# Candidate values for the first entry of each stage (the expansion factor
# `t` in LightNASNet.net).
NAS_FILTERS_MULTIPLIER = [3, 4, 5, 6]
# Candidate values for the shortcut flag.
NAS_SHORTCUT = [0, 1]
# Candidate values for the SE flag.
NAS_SE = [0, 1]
def get_bottleneck_params_list(var):
    """Translate tokens into a flat list of bottleneck settings.

    The result holds 7 stages of 7 values each. The first and the last stage
    are fixed; for each of the 5 stages in between, 6 consecutive tokens
    choose the expansion factor, output channels, number of layers, kernel
    size, shortcut flag and SE flag. The stride entry of every stage keeps
    its default value.

    Args:
        var: list, the tokens (6 per searched stage, 30 in total).
    Returns:
        list, bottleneck_params_list.
    """
    params_list = [
        1, 16, 1, 1, 3, 1, 0,
        6, 24, 2, 2, 3, 1, 0,
        6, 32, 3, 2, 3, 1, 0,
        6, 64, 4, 2, 3, 1, 0,
        6, 96, 3, 1, 3, 1, 0,
        6, 160, 3, 2, 3, 1, 0,
        6, 320, 1, 1, 3, 1, 0,
    ]
    for stage in range(5):
        dst = (stage + 1) * 7  # first slot of the stage being overwritten
        src = stage * 6  # first token belonging to this stage
        params_list[dst] = NAS_FILTERS_MULTIPLIER[var[src]]
        params_list[dst + 1] = NAS_FILTER_SIZE[stage][var[src + 1]]
        params_list[dst + 2] = NAS_LAYERS_NUMBER[stage][var[src + 2]]
        params_list[dst + 4] = NAS_KERNEL_SIZE[var[src + 3]]
        params_list[dst + 5] = NAS_SHORTCUT[var[src + 4]]
        params_list[dst + 6] = NAS_SE[var[src + 5]]
    return params_list
class LightNASSpace(SearchSpace):
    """Search space that builds LightNASNet programs from tokens."""

    def __init__(self):
        super(LightNASSpace, self).__init__()

    def init_tokens(self):
        """Return the tokens the search starts from (6 per stage, 5 stages).
        """
        return [
            0, 1, 2, 0, 1, 0,
            0, 2, 1, 1, 1, 0,
            3, 2, 0, 1, 1, 0,
            3, 1, 0, 0, 1, 0,
            3, 2, 2, 1, 1, 0,
        ]

    def range_table(self):
        """Return the number of choices for every token position.
        """
        # Per stage, in the order the tokens are read by
        # get_bottleneck_params_list():
        # [NAS_FILTERS_MULTIPLIER, NAS_FILTER_SIZE, NAS_LAYERS_NUMBER,
        #  NAS_KERNEL_SIZE, NAS_SHORTCUT, NAS_SE]
        return [4, 3, 3, 2, 2, 2] * 5

    def create_net(self, tokens=None):
        """Build the startup, train and test programs for ``tokens``.

        Args:
            tokens: list, the tokens describing the network. None means the
                result of ``init_tokens()``.
        Returns:
            tuple: (startup program, train program, test program,
            (train cost, train acc1, train acc5, global lr),
            (test cost, test acc1, test acc5), train py_reader,
            test py_reader).
        """
        if tokens is None:
            tokens = self.init_tokens()
        bottleneck_params_list = get_bottleneck_params_list(tokens)

        startup_prog = fluid.Program()
        train_prog = fluid.Program()
        test_prog = fluid.Program()

        (train_py_reader, train_cost, train_acc1, train_acc5,
         global_lr) = build_program(
             is_train=True,
             main_prog=train_prog,
             startup_prog=startup_prog,
             bottleneck_params_list=bottleneck_params_list)
        test_py_reader, test_cost, test_acc1, test_acc5 = build_program(
            is_train=False,
            main_prog=test_prog,
            startup_prog=startup_prog,
            bottleneck_params_list=bottleneck_params_list)
        test_prog = test_prog.clone(for_test=True)

        train_reader = paddle.batch(
            paddle.dataset.mnist.train(),
            batch_size=batch_size / 1,
            drop_last=True)
        test_reader = paddle.batch(
            paddle.dataset.mnist.test(), batch_size=batch_size)

        with fluid.program_guard(train_prog, startup_prog):
            train_py_reader.decorate_paddle_reader(train_reader)
        with fluid.program_guard(test_prog, startup_prog):
            test_py_reader.decorate_paddle_reader(test_reader)

        train_fetches = (train_cost, train_acc1, train_acc5, global_lr)
        test_fetches = (test_cost, test_acc1, test_acc5)
        return (startup_prog, train_prog, test_prog, train_fetches,
                test_fetches, train_py_reader, test_py_reader)
def build_program(is_train,
                  main_prog,
                  startup_prog,
                  bottleneck_params_list=None):
    """Build the LightNASNet graph into ``main_prog``.

    Args:
        is_train: bool, whether to also add the optimizer.
        main_prog: Program, receives the network.
        startup_prog: Program, receives the initializers.
        bottleneck_params_list: list, flat bottleneck settings.
    Returns:
        tuple: (py_reader, avg_cost, acc_top1, acc_top5), followed by the
        global learning rate variable when ``is_train`` is True.
    """
    with fluid.program_guard(main_prog, startup_prog):
        py_reader = fluid.layers.py_reader(
            capacity=16,
            shapes=[[-1] + image_shape, [-1, 1]],
            lod_levels=[0, 0],
            dtypes=["float32", "int64"],
            use_double_buffer=False)
        with fluid.unique_name.guard():
            image, label = fluid.layers.read_file(py_reader)
            model = LightNASNet()
            avg_cost, acc_top1, acc_top5 = net_config(
                image,
                label,
                model,
                class_dim=10,
                bottleneck_params_list=bottleneck_params_list,
                scale_loss=1.0)
            for fetched in (avg_cost, acc_top1, acc_top5):
                fetched.persistable = True
            if is_train:
                # model.params is updated in place with this module's
                # training hyper-parameters.
                params = model.params
                params["total_images"] = total_images
                params["lr"] = lr
                params["num_epochs"] = num_epochs
                strategy = params["learning_strategy"]
                strategy["batch_size"] = batch_size
                strategy["name"] = lr_strategy
                params["l2_decay"] = l2_decay
                params["momentum_rate"] = momentum_rate
                optimizer = optimizer_setting(params)
                optimizer.minimize(avg_cost)
                global_lr = optimizer._global_learning_rate()

    outputs = (py_reader, avg_cost, acc_top1, acc_top5)
    if is_train:
        return outputs + (global_lr, )
    return outputs
def net_config(image,
               label,
               model,
               class_dim=1000,
               bottleneck_params_list=None,
               scale_loss=1.0):
    """Attach the loss and accuracy outputs to the model.

    Args:
        image: Variable, network input.
        label: Variable, ground-truth labels.
        model: object whose ``net`` method builds the network.
        class_dim: int, number of classes.
        bottleneck_params_list: list, flat settings; regrouped into rows of
            7 values before being passed to ``model.net``.
        scale_loss: float, the mean loss is multiplied by it only when it is
            greater than 1.
    Returns:
        tuple: (avg_cost, acc_top1, acc_top5).
    """
    stage_settings = [
        bottleneck_params_list[start:start + 7]
        for start in range(0, len(bottleneck_params_list), 7)
    ]
    logits = model.net(input=image,
                       bottleneck_params_list=stage_settings,
                       class_dim=class_dim)
    cost, pred = fluid.layers.softmax_with_cross_entropy(
        logits, label, return_softmax=True)
    avg_cost = fluid.layers.mean(x=cost)
    if scale_loss > 1:
        avg_cost = avg_cost * float(scale_loss)
    acc_top1 = fluid.layers.accuracy(input=pred, label=label, k=1)
    acc_top5 = fluid.layers.accuracy(input=pred, label=label, k=5)
    return avg_cost, acc_top1, acc_top5
def optimizer_setting(params):
    """optimizer setting.

    Builds the optimizer selected by ``params["learning_strategy"]["name"]``:
    "piecewise_decay", "cosine_decay", "cosine_warmup_decay" and
    "linear_decay" give Momentum with the matching learning-rate schedule,
    "adam" gives Adam, and any other name gives Momentum with the constant
    learning rate ``params["lr"]``.

    Args:
        params: dict, params. Reads "learning_strategy", "l2_decay",
            "momentum_rate", "lr" and, depending on the strategy,
            "total_images" and "num_epochs".
    Returns:
        The fluid optimizer.
    """
    # This file has no file-level `import math`; it is only needed by the
    # "cosine_warmup_decay" branch.
    import math

    # The fallback used to be the name IMAGENET1000, which is not defined in
    # this file and raised NameError. 1281167 is the value this module uses
    # for its own `total_images` constant -- presumably what was intended.
    default_total_images = 1281167

    ls = params["learning_strategy"]
    l2_decay = params["l2_decay"]
    momentum_rate = params["momentum_rate"]
    strategy = ls["name"]

    def _momentum(learning_rate):
        """Momentum with L2 decay, shared by every non-Adam strategy."""
        return fluid.optimizer.Momentum(
            learning_rate=learning_rate,
            momentum=momentum_rate,
            regularization=fluid.regularizer.L2Decay(l2_decay))

    if strategy == "piecewise_decay":
        total_images = params.get("total_images", default_total_images)
        batch_size = ls["batch_size"]
        step = int(total_images / batch_size + 1)
        bd = [step * e for e in ls["epochs"]]
        base_lr = params["lr"]
        # One value per interval: the rate drops by 10x at every boundary.
        lr = [base_lr * (0.1**i) for i in range(len(bd) + 1)]
        optimizer = _momentum(
            fluid.layers.piecewise_decay(boundaries=bd, values=lr))
    elif strategy == "cosine_decay":
        total_images = params.get("total_images", default_total_images)
        batch_size = ls["batch_size"]
        step = int(total_images / batch_size + 1)
        lr = params["lr"]
        num_epochs = params["num_epochs"]
        optimizer = _momentum(
            fluid.layers.cosine_decay(
                learning_rate=lr, step_each_epoch=step, epochs=num_epochs))
    elif strategy == "cosine_warmup_decay":
        total_images = params.get("total_images", default_total_images)
        batch_size = ls["batch_size"]
        step = int(math.ceil(float(total_images) / batch_size))
        lr = params["lr"]
        num_epochs = params["num_epochs"]
        # NOTE(review): cosine_decay_with_warmup is neither defined nor
        # imported in this file, so this branch raises NameError as written.
        optimizer = _momentum(
            cosine_decay_with_warmup(
                learning_rate=lr, step_each_epoch=step, epochs=num_epochs))
    elif strategy == "linear_decay":
        total_images = params.get("total_images", default_total_images)
        batch_size = ls["batch_size"]
        num_epochs = params["num_epochs"]
        start_lr = params["lr"]
        end_lr = 0
        total_step = int((total_images / batch_size) * num_epochs)
        lr = fluid.layers.polynomial_decay(
            start_lr, total_step, end_lr, power=1)
        optimizer = _momentum(lr)
    elif strategy == "adam":
        optimizer = fluid.optimizer.Adam(learning_rate=params["lr"])
    else:
        optimizer = _momentum(params["lr"])
    return optimizer
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""LightNASNet."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import math
import paddle.fluid as fluid
from paddle.fluid.param_attr import ParamAttr
# Public names exported by `from <this module> import *`.
__all__ = ['LightNASNet']
# Default training configuration. LightNASNet.__init__ assigns this very dict
# object to `self.params`, so in-place changes made through one instance are
# visible to every other instance and to this module-level name.
train_parameters = {
"input_size": [3, 224, 224],
"input_mean": [0.485, 0.456, 0.406],
"input_std": [0.229, 0.224, 0.225],
"learning_strategy": {
"name": "piecewise_decay",
"batch_size": 256,
"epochs": [30, 60, 90],
"steps": [0.1, 0.01, 0.001, 0.0001]
}
}
class LightNASNet(object):
"""LightNASNet."""
def __init__(self):
    """Expose the module-level training configuration as ``self.params``."""
    # NOTE: this is a reference to the module-level dict, not a copy, so
    # in-place updates through one instance affect every instance.
    self.params = train_parameters
def net(self, input, bottleneck_params_list=None, class_dim=1000,
        scale=1.0):
    """Assemble the whole network.

    A stem convolution is followed by one group of inverted residual blocks
    per entry of ``bottleneck_params_list``, a 1x1 convolution, global
    average pooling and a fully connected layer.

    Args:
        input: Variable, input.
        bottleneck_params_list: list, one
            (expansion, out_channel, num_layers, stride, filter_size,
            shortcut, squeeze) setting per group. None selects the built-in
            default.
        class_dim: int, class dim.
        scale: float, multiplier applied to the channel numbers.
    Returns:
        Variable, network output.
    """
    if bottleneck_params_list is None:
        # Default group settings used when the caller passes none.
        bottleneck_params_list = [
            (1, 16, 1, 1, 3, 1, 0),
            (3, 24, 3, 2, 3, 1, 0),
            (3, 40, 3, 2, 5, 1, 0),
            (6, 80, 3, 2, 5, 1, 0),
            (6, 96, 2, 1, 3, 1, 0),
            (6, 192, 4, 2, 5, 1, 0),
            (6, 320, 1, 1, 3, 1, 0),
        ]

    # conv1
    stem_channels = int(32 * scale)
    feature = self.conv_bn_layer(
        input,
        num_filters=stem_channels,
        filter_size=3,
        stride=2,
        padding=1,
        if_act=True,
        name='conv1_1')

    # bottleneck sequences; groups are named conv2, conv3, ...
    in_c = stem_channels
    for group_id, (t, c, n, s, k, ifshortcut,
                   ifse) in enumerate(bottleneck_params_list, 2):
        out_c = int(c * scale)
        feature = self.invresi_blocks(
            input=feature,
            in_channel=in_c,
            expansion=t,
            out_channel=out_c,
            num_layers=n,
            stride=s,
            filter_size=k,
            shortcut=ifshortcut,
            squeeze=ifse,
            name='conv' + str(group_id))
        in_c = out_c

    # last_conv: the width is only scaled up, never below 1280.
    if scale > 1.0:
        last_channels = int(1280 * scale)
    else:
        last_channels = 1280
    feature = self.conv_bn_layer(
        input=feature,
        num_filters=last_channels,
        filter_size=1,
        stride=1,
        padding=0,
        if_act=True,
        name='conv9')

    feature = fluid.layers.pool2d(
        input=feature,
        pool_size=7,
        pool_stride=1,
        pool_type='avg',
        global_pooling=True)

    return fluid.layers.fc(input=feature,
                           size=class_dim,
                           param_attr=ParamAttr(name='fc10_weights'),
                           bias_attr=ParamAttr(name='fc10_offset'))
def conv_bn_layer(self,
input,
filter_size,
num_filters,
stride,
padding,
num_groups=1,
if_act=True,
name=None,
use_cudnn=True):
"""Build convolution and batch normalization layers.
Args:
input: Variable, input.
filter_size: int, filter size.
num_filters: int, number of filters.
stride: int, stride.
padding: int, padding.
num_groups: int, number of groups.
if_act: bool, whether using activation.
name: str, name.
use_cudnn: bool, whether use cudnn.
Returns:
Variable, layers output.
"""
conv = fluid.layers.conv2d(
input=input,
num_filters=num_filters,
filter_size=filter_size,
stride=stride,
padding=padding,
groups=num_groups,
act=None,
use_cudnn=use_cudnn,
param_attr=ParamAttr(name=name + '_weights'),
bias_attr=False)
bn_name = name + '_bn'
bn = fluid.layers.batch_norm(
input=conv,
param_attr=ParamAttr(name=bn_name + "_scale"),
bias_attr=ParamAttr(name=bn_name + "_offset"),
moving_mean_name=bn_name + '_mean',
moving_variance_name=bn_name + '_variance')
if if_act:
return fluid.layers.relu6(bn)
else:
return bn
def shortcut(self, input, data_residual):
"""Build shortcut layer.
Args:
input: Variable, input.
data_residual: Variable, residual layer.
Returns:
Variable, layer output.
"""
return fluid.layers.elementwise_add(input, data_residual)
def squeeze_excitation(self,
input,
num_channels,
reduction_ratio,
name=None):
"""Build squeeze excitation layers.
Args:
input: Variable, input.
num_channels: int, number of channels.
reduction_ratio: float, reduction ratio.
name: str, name.
Returns:
Variable, layers output.
"""
pool = fluid.layers.pool2d(
input=input, pool_size=0, pool_type='avg', global_pooling=True)
stdv = 1.0 / math.sqrt(pool.shape[1] * 1.0)
squeeze = fluid.layers.fc(
input=pool,
size=num_channels // reduction_ratio,
act='relu',
param_attr=fluid.param_attr.ParamAttr(
initializer=fluid.initializer.Uniform(-stdv, stdv),
name=name + '_sqz_weights'),
bias_attr=ParamAttr(name=name + '_sqz_offset'))
stdv = 1.0 / math.sqrt(squeeze.shape[1] * 1.0)
excitation = fluid.layers.fc(
input=squeeze,
size=num_channels,
act='sigmoid',
param_attr=fluid.param_attr.ParamAttr(
initializer=fluid.initializer.Uniform(-stdv, stdv),
name=name + '_exc_weights'),
bias_attr=ParamAttr(name=name + '_exc_offset'))
scale = fluid.layers.elementwise_mul(x=input, y=excitation, axis=0)
return scale
def inverted_residual_unit(self,
input,
num_in_filter,
num_filters,
ifshortcut,
ifse,
stride,
filter_size,
expansion_factor,
reduction_ratio=4,
name=None):
"""Build inverted residual unit.
Args:
input(Variable): Theinput.
num_in_filter(int): The number of input filters.
num_filters(int): The number of filters.
ifshortcut(bool): Whether to use shortcut.
stride(int): The stride.
filter_size(int): The filter size.
padding(int): The padding.
expansion_factor(float): Expansion factor.
name(str): The name.
Returns:
Variable, layers output.
"""
num_expfilter = int(round(num_in_filter * expansion_factor))
channel_expand = self.conv_bn_layer(
input=input,
num_filters=num_expfilter,
filter_size=1,
stride=1,
padding=0,
num_groups=1,
if_act=True,
name=name + '_expand')
bottleneck_conv = self.conv_bn_layer(
input=channel_expand,
num_filters=num_expfilter,
filter_size=filter_size,
stride=stride,
padding=int((filter_size - 1) / 2),
num_groups=num_expfilter,
if_act=True,
name=name + '_dwise',
use_cudnn=False)
linear_out = self.conv_bn_layer(
input=bottleneck_conv,
num_filters=num_filters,
filter_size=1,
stride=1,
padding=0,
num_groups=1,
if_act=False,
name=name + '_linear')
out = linear_out
if ifshortcut:
out = self.shortcut(input=input, data_residual=out)
if ifse:
scale = self.squeeze_excitation(
input=linear_out,
num_channels=num_filters,
reduction_ratio=reduction_ratio,
name=name + '_fc')
out = fluid.layers.elementwise_add(x=out, y=scale, act='relu')
return out
def invresi_blocks(self,
input,
in_channel,
expansion,
out_channel,
num_layers,
stride,
filter_size,
shortcut,
squeeze,
name=None):
"""Build inverted residual blocks.
Args:
input(Variable): The input feture map.
in_channel(int): The number of input channel.
expansion(float): Expansion factor.
out_channel(int): The number of output channel.
num_layers(int): The number of layers.
stride(int): The stride.
filter_size(int): The size of filter.
shortcut(bool): Whether to add shortcut layers.
squeeze(bool): Whether to add squeeze excitation layers.
name(str): The name.
Returns:
Variable, layers output.
"""
first_block = self.inverted_residual_unit(
input=input,
num_in_filter=in_channel,
num_filters=out_channel,
ifshortcut=False,
ifse=squeeze,
stride=stride,
filter_size=filter_size,
expansion_factor=expansion,
name=name + '_1')
last_residual_block = first_block
last_c = out_channel
for i in range(1, num_layers):
last_residual_block = self.inverted_residual_unit(
input=last_residual_block,
num_in_filter=last_c,
num_filters=out_channel,
ifshortcut=shortcut,
ifse=squeeze,
stride=1,
filter_size=filter_size,
expansion_factor=expansion,
name=name + '_' + str(i + 1))
return last_residual_block
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import paddle
import unittest
import paddle.fluid as fluid
from mobilenet import MobileNet
from paddle.fluid.contrib.slim.core import Compressor
from paddle.fluid.contrib.slim.graph import GraphWrapper
class TestFilterPruning(unittest.TestCase):
    """Smoke test for the Compressor set up by the auto-pruning config."""

    def test_compression(self):
        """Run the Compressor on MobileNet/MNIST with the auto-pruning config.

        Model: mobilenet_v1
        data: mnist

        The compression steps are driven by ``./auto_pruning/compress.yaml``.
        This test only verifies that the run completes without raising; it
        makes no assertion on the resulting accuracy.
        """
        if not fluid.core.is_compiled_with_cuda():
            # Report the test as skipped instead of silently passing.
            self.skipTest("Paddle is not compiled with CUDA.")

        image = fluid.layers.data(
            name='image', shape=[1, 28, 28], dtype='float32')
        # NOTE(review): presumably required by the pruning strategy --
        # confirm why gradients w.r.t. the input are enabled.
        image.stop_gradient = False
        label = fluid.layers.data(name='label', shape=[1], dtype='int64')
        out = MobileNet("auto_pruning").net(input=image, class_dim=10)
        acc_top1 = fluid.layers.accuracy(input=out, label=label, k=1)
        acc_top5 = fluid.layers.accuracy(input=out, label=label, k=5)

        # Clone before the loss is added so the eval program ends at the
        # accuracy ops.
        val_program = fluid.default_main_program().clone(for_test=False)

        cost = fluid.layers.cross_entropy(input=out, label=label)
        avg_cost = fluid.layers.mean(x=cost)

        optimizer = fluid.optimizer.Momentum(
            momentum=0.9,
            learning_rate=0.01,
            regularization=fluid.regularizer.L2Decay(4e-5))

        place = fluid.CUDAPlace(0)
        exe = fluid.Executor(place)
        exe.run(fluid.default_startup_program())

        val_reader = paddle.batch(paddle.dataset.mnist.test(), batch_size=128)
        val_feed_list = [('img', image.name), ('label', label.name)]
        val_fetch_list = [('acc_top1', acc_top1.name),
                          ('acc_top5', acc_top5.name)]

        train_reader = paddle.batch(
            paddle.dataset.mnist.train(), batch_size=128)
        train_feed_list = [('img', image.name), ('label', label.name)]
        train_fetch_list = [('loss', avg_cost.name)]

        com_pass = Compressor(
            place,
            fluid.global_scope(),
            fluid.default_main_program(),
            train_reader=train_reader,
            train_feed_list=train_feed_list,
            train_fetch_list=train_fetch_list,
            eval_program=val_program,
            eval_reader=val_reader,
            eval_feed_list=val_feed_list,
            eval_fetch_list=val_fetch_list,
            train_optimizer=optimizer)
        com_pass.config('./auto_pruning/compress.yaml')
        com_pass.run()
# Run the tests in this module when it is executed as a script.
if __name__ == '__main__':
    unittest.main()
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import paddle
import unittest
import paddle.fluid as fluid
from mobilenet import MobileNet
from paddle.fluid.contrib.slim.core import Compressor
from paddle.fluid.contrib.slim.graph import GraphWrapper
import sys
sys.path.append("./light_nas")
from light_nas_space import LightNASSpace
class TestLightNAS(unittest.TestCase):
    """Smoke test for the Compressor set up by the Light-NAS config."""

    def test_compression(self):
        """Run the Compressor with a LightNASSpace search space.

        Programs, metrics and readers all come from
        ``LightNASSpace.create_net()``; the run is configured by
        ``./light_nas/compress.yaml``. This test only verifies that the run
        completes without raising.
        """
        if not fluid.core.is_compiled_with_cuda():
            # Report the test as skipped instead of silently passing.
            self.skipTest("Paddle is not compiled with CUDA.")

        space = LightNASSpace()
        (startup_prog, train_prog, test_prog, train_metrics, test_metrics,
         train_reader, test_reader) = space.create_net()
        # Only the training loss and the two test accuracies are fetched; the
        # unpacking also checks the number of metrics returned.
        train_cost, _, _, _ = train_metrics
        _, test_acc1, test_acc5 = test_metrics

        place = fluid.CUDAPlace(0)
        exe = fluid.Executor(place)
        exe.run(startup_prog)

        val_fetch_list = [('acc_top1', test_acc1.name),
                          ('acc_top5', test_acc5.name)]
        train_fetch_list = [('loss', train_cost.name)]

        com_pass = Compressor(
            place,
            fluid.global_scope(),
            train_prog,
            train_reader=train_reader,
            train_feed_list=None,
            train_fetch_list=train_fetch_list,
            eval_program=test_prog,
            eval_reader=test_reader,
            eval_feed_list=None,
            eval_fetch_list=val_fetch_list,
            train_optimizer=None,
            search_space=space)
        com_pass.config('./light_nas/compress.yaml')
        com_pass.run()
# Run the tests in this module when it is executed as a script.
if __name__ == '__main__':
    unittest.main()
......@@ -118,6 +118,8 @@ packages=['paddle',
'paddle.fluid.contrib.slim.prune',
'paddle.fluid.contrib.slim.quantization',
'paddle.fluid.contrib.slim.distillation',
'paddle.fluid.contrib.slim.nas',
'paddle.fluid.contrib.slim.searcher',
'paddle.fluid.contrib.utils',
'paddle.fluid.contrib.extend_optimizer',
'paddle.fluid.contrib.mixed_precision',
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册