提交 1799c257 编写于 作者: J Jiancheng Li 提交者: whs

Update Light-NAS to support latency-aware search (#19050)

* update light_nas_strategy: add latency constraint

test=develop

* update light_nas_strategy: update get_model_latency

test=develop

* update light_nas_strategy: add more check

test=develop

* update light_nas test

test=develop

* update light_nas test

    test=develop

* minor update light_nas test

    test=develop

* minor update light_nas test

test=develop

* update light_nas test

test=develop

* update _constrain_func of light_nas_strategy

test=develop

* update _constrain_func of light_nas_strategy

test=develop

* remove unused code

test=develop
上级 0fe72469
...@@ -40,6 +40,7 @@ class LightNASStrategy(Strategy): ...@@ -40,6 +40,7 @@ class LightNASStrategy(Strategy):
controller=None, controller=None,
end_epoch=1000, end_epoch=1000,
target_flops=629145600, target_flops=629145600,
target_latency=0,
retrain_epoch=1, retrain_epoch=1,
metric_name='top1_acc', metric_name='top1_acc',
server_ip=None, server_ip=None,
...@@ -53,6 +54,7 @@ class LightNASStrategy(Strategy): ...@@ -53,6 +54,7 @@ class LightNASStrategy(Strategy):
controller(searcher.Controller): The searching controller. Default: None. controller(searcher.Controller): The searching controller. Default: None.
end_epoch(int): The 'on_epoch_end' function will be called in end_epoch. Default: 0 end_epoch(int): The 'on_epoch_end' function will be called in end_epoch. Default: 0
target_flops(int): The constraint of FLOPS. target_flops(int): The constraint of FLOPS.
target_latency(float): The constraint of latency.
retrain_epoch(int): The number of training epochs before evaluating structure generated by controller. Default: 1. retrain_epoch(int): The number of training epochs before evaluating structure generated by controller. Default: 1.
metric_name(str): The metric used to evaluate the model. metric_name(str): The metric used to evaluate the model.
It should be one of keys in out_nodes of graph wrapper. Default: 'top1_acc' It should be one of keys in out_nodes of graph wrapper. Default: 'top1_acc'
...@@ -66,6 +68,7 @@ class LightNASStrategy(Strategy): ...@@ -66,6 +68,7 @@ class LightNASStrategy(Strategy):
self.start_epoch = 0 self.start_epoch = 0
self.end_epoch = end_epoch self.end_epoch = end_epoch
self._max_flops = target_flops self._max_flops = target_flops
self._max_latency = target_latency
self._metric_name = metric_name self._metric_name = metric_name
self._controller = controller self._controller = controller
self._retrain_epoch = 0 self._retrain_epoch = 0
...@@ -86,8 +89,6 @@ class LightNASStrategy(Strategy): ...@@ -86,8 +89,6 @@ class LightNASStrategy(Strategy):
def on_compression_begin(self, context): def on_compression_begin(self, context):
self._current_tokens = context.search_space.init_tokens() self._current_tokens = context.search_space.init_tokens()
constrain_func = functools.partial(
self._constrain_func, context=context)
self._controller.reset(context.search_space.range_table(), self._controller.reset(context.search_space.range_table(),
self._current_tokens, None) self._current_tokens, None)
...@@ -127,15 +128,6 @@ class LightNASStrategy(Strategy): ...@@ -127,15 +128,6 @@ class LightNASStrategy(Strategy):
d[key] = self.__dict__[key] d[key] = self.__dict__[key]
return d return d
def _constrain_func(self, tokens, context=None):
    """Check whether the tokens meet the FLOPs constraint.

    Args:
        tokens(list): Tokens describing a candidate architecture.
        context: Compression context; its ``search_space`` is used to
            build the candidate network. Default: None.

    Returns:
        bool: True if the candidate network's FLOPs do not exceed
            ``self._max_flops``, otherwise False.
    """
    _, _, test_prog, _, _, _, _ = context.search_space.create_net(tokens)
    # The comparison already yields a bool; no explicit if/else needed.
    return GraphWrapper(test_prog).flops() <= self._max_flops
def on_epoch_begin(self, context): def on_epoch_begin(self, context):
if context.epoch_id >= self.start_epoch and context.epoch_id <= self.end_epoch and ( if context.epoch_id >= self.start_epoch and context.epoch_id <= self.end_epoch and (
self._retrain_epoch == 0 or self._retrain_epoch == 0 or
...@@ -144,13 +136,20 @@ class LightNASStrategy(Strategy): ...@@ -144,13 +136,20 @@ class LightNASStrategy(Strategy):
for _ in range(self._max_try_times): for _ in range(self._max_try_times):
startup_p, train_p, test_p, _, _, train_reader, test_reader = context.search_space.create_net( startup_p, train_p, test_p, _, _, train_reader, test_reader = context.search_space.create_net(
self._current_tokens) self._current_tokens)
_logger.info("try [{}]".format(self._current_tokens))
context.eval_graph.program = test_p context.eval_graph.program = test_p
flops = context.eval_graph.flops() flops = context.eval_graph.flops()
if flops <= self._max_flops: if self._max_latency > 0:
break latency = context.search_space.get_model_latency(test_p)
_logger.info("try [{}] with latency {} flops {}".format(
self._current_tokens, latency, flops))
else: else:
_logger.info("try [{}] with flops {}".format(
self._current_tokens, flops))
if flops > self._max_flops or (self._max_latency > 0 and
latency > self._max_latency):
self._current_tokens = self._search_agent.next_tokens() self._current_tokens = self._search_agent.next_tokens()
else:
break
context.train_reader = train_reader context.train_reader = train_reader
context.eval_reader = test_reader context.eval_reader = test_reader
...@@ -173,7 +172,17 @@ class LightNASStrategy(Strategy): ...@@ -173,7 +172,17 @@ class LightNASStrategy(Strategy):
flops = context.eval_graph.flops() flops = context.eval_graph.flops()
if flops > self._max_flops: if flops > self._max_flops:
self._current_reward = 0.0 self._current_reward = 0.0
_logger.info("reward: {}; flops: {}; tokens: {}".format( if self._max_latency > 0:
self._current_reward, flops, self._current_tokens)) test_p = context.search_space.create_net(self._current_tokens)[
2]
latency = context.search_space.get_model_latency(test_p)
if latency > self._max_latency:
self._current_reward = 0.0
_logger.info("reward: {}; latency: {}; flops: {}; tokens: {}".
format(self._current_reward, latency, flops,
self._current_tokens))
else:
_logger.info("reward: {}; flops: {}; tokens: {}".format(
self._current_reward, flops, self._current_tokens))
self._current_tokens = self._search_agent.update( self._current_tokens = self._search_agent.update(
self._current_tokens, self._current_reward) self._current_tokens, self._current_reward)
...@@ -41,3 +41,12 @@ class SearchSpace(object): ...@@ -41,3 +41,12 @@ class SearchSpace(object):
(tuple): startup_program, train_program, evaluation_program, train_metrics, test_metrics (tuple): startup_program, train_program, evaluation_program, train_metrics, test_metrics
""" """
raise NotImplementedError('Abstract method.') raise NotImplementedError('Abstract method.')
def get_model_latency(self, program):
    """Return the inference latency of the model described by ``program``.

    This is an abstract hook; concrete search spaces must override it.

    Args:
        program(Program): The program whose latency is to be measured.

    Return:
        (float): model latency.

    Raises:
        NotImplementedError: Always, until a subclass provides an
            implementation.
    """
    raise NotImplementedError('Abstract method.')
...@@ -10,6 +10,7 @@ strategies: ...@@ -10,6 +10,7 @@ strategies:
class: 'LightNASStrategy' class: 'LightNASStrategy'
controller: 'sa_controller' controller: 'sa_controller'
target_flops: 629145600 target_flops: 629145600
target_latency: 1
end_epoch: 2 end_epoch: 2
retrain_epoch: 1 retrain_epoch: 1
metric_name: 'acc_top1' metric_name: 'acc_top1'
......
...@@ -17,6 +17,7 @@ from light_nasnet import LightNASNet ...@@ -17,6 +17,7 @@ from light_nasnet import LightNASNet
import paddle.fluid as fluid import paddle.fluid as fluid
import paddle import paddle
import json import json
import random
total_images = 1281167 total_images = 1281167
lr = 0.1 lr = 0.1
...@@ -85,6 +86,16 @@ class LightNASSpace(SearchSpace): ...@@ -85,6 +86,16 @@ class LightNASSpace(SearchSpace):
2, 4, 3, 3, 2, 2, 2 2, 4, 3, 3, 2, 2, 2
] ]
def get_model_latency(self, program):
    """Get model latency according to program.

    Returns a random value since this search space is only used for
    testing; a real search space would benchmark ``program``.

    Args:
        program(Program): The program to get latency for (unused by
            this stub).

    Return:
        (float): model latency.
    """
    # Cast to float so the return value matches the documented
    # ``(float)`` contract (randint alone returns an int).
    return float(random.randint(1, 2))
def create_net(self, tokens=None): def create_net(self, tokens=None):
"""Create a network for training by tokens. """Create a network for training by tokens.
""" """
......
...@@ -11,24 +11,96 @@ ...@@ -11,24 +11,96 @@
# without warranties or conditions of any kind, either express or implied. # without warranties or conditions of any kind, either express or implied.
# see the license for the specific language governing permissions and # see the license for the specific language governing permissions and
# limitations under the license. # limitations under the license.
"""
import paddle Test LightNAS.
"""
import sys
import unittest import unittest
import paddle.fluid as fluid import paddle.fluid as fluid
from mobilenet import MobileNet
from paddle.fluid.contrib.slim.core import Compressor from paddle.fluid.contrib.slim.core import Compressor
from paddle.fluid.contrib.slim.graph import GraphWrapper
import sys
sys.path.append("./light_nas") sys.path.append("./light_nas")
from light_nas_space import LightNASSpace from light_nas_space import LightNASSpace
class TestLightNAS(unittest.TestCase): class TestLightNAS(unittest.TestCase):
"""
Test LightNAS.
"""
def test_compression(self):
    """Run LightNAS end-to-end with the latency constraint disabled.

    Rewrites compress.yaml so that ``target_latency`` is 0 (only the
    FLOPs constraint is active), then runs a short compression pass on
    the LightNAS search space. Skips silently when CUDA is unavailable.
    """
    # Update compress.yaml: force target_latency to 0 for this test.
    # Use context managers so the file handles are closed even if an
    # exception is raised mid-rewrite.
    with open('./light_nas/compress.yaml') as fid:
        lines = [
            ' target_latency: 0\n' if 'target_latency' in line else line
            for line in fid
        ]
    with open('./light_nas/compress.yaml', 'w') as fid:
        fid.writelines(lines)

    # Begin test: LightNAS requires CUDA; bail out otherwise.
    if not fluid.core.is_compiled_with_cuda():
        return
    space = LightNASSpace()

    startup_prog, train_prog, test_prog, train_metrics, test_metrics, train_reader, test_reader = space.create_net(
    )
    train_cost, train_acc1, train_acc5, global_lr = train_metrics
    test_cost, test_acc1, test_acc5 = test_metrics

    place = fluid.CUDAPlace(0)
    exe = fluid.Executor(place)
    exe.run(startup_prog)

    val_fetch_list = [('acc_top1', test_acc1.name),
                      ('acc_top5', test_acc5.name)]
    train_fetch_list = [('loss', train_cost.name)]

    com_pass = Compressor(
        place,
        fluid.global_scope(),
        train_prog,
        train_reader=train_reader,
        train_feed_list=None,
        train_fetch_list=train_fetch_list,
        eval_program=test_prog,
        eval_reader=test_reader,
        eval_feed_list=None,
        eval_fetch_list=val_fetch_list,
        train_optimizer=None,
        search_space=space)
    com_pass.config('./light_nas/compress.yaml')
    # Return value intentionally discarded; the test only checks that
    # the compression pass completes without raising.
    com_pass.run()
def test_compression_with_target_latency(self):
"""
Test LightNAS with target_latency.
"""
# Update compress.yaml
lines = list()
fid = open('./light_nas/compress.yaml')
for line in fid:
if 'target_latency' in line:
lines.append(' target_latency: 1\n')
else:
lines.append(line)
fid.close()
fid = open('./light_nas/compress.yaml', 'w')
for line in lines:
fid.write(line)
fid.close()
# Begin test
if not fluid.core.is_compiled_with_cuda(): if not fluid.core.is_compiled_with_cuda():
return return
class_dim = 10
image_shape = [1, 28, 28]
space = LightNASSpace() space = LightNASSpace()
...@@ -41,8 +113,8 @@ class TestLightNAS(unittest.TestCase): ...@@ -41,8 +113,8 @@ class TestLightNAS(unittest.TestCase):
exe = fluid.Executor(place) exe = fluid.Executor(place)
exe.run(startup_prog) exe.run(startup_prog)
val_fetch_list = [('acc_top1', test_acc1.name), ('acc_top5', val_fetch_list = [('acc_top1', test_acc1.name),
test_acc5.name)] ('acc_top5', test_acc5.name)]
train_fetch_list = [('loss', train_cost.name)] train_fetch_list = [('loss', train_cost.name)]
com_pass = Compressor( com_pass = Compressor(
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册