提交 1799c257 编写于 作者: J Jiancheng Li 提交者: whs

Update Light-NAS to support latency-aware search (#19050)

* update light_nas_strategy: add latency constraint

test=develop

* update light_nas_strategy: update get_model_latency

test=develop

* update light_nas_strategy: add more check

test=develop

* update light_nas test

test=develop

* update light_nas test

    test=develop

* minor update light_nas test

    test=develop

* minor update light_nas test

test=develop

* update light_nas test

test=develop

* update _constrain_func of light_nas_strategy

test=develop

* update _constrain_func of light_nas_strategy

test=develop

* remove unused code

test=develop
上级 0fe72469
......@@ -40,6 +40,7 @@ class LightNASStrategy(Strategy):
controller=None,
end_epoch=1000,
target_flops=629145600,
target_latency=0,
retrain_epoch=1,
metric_name='top1_acc',
server_ip=None,
......@@ -53,6 +54,7 @@ class LightNASStrategy(Strategy):
controller(searcher.Controller): The searching controller. Default: None.
end_epoch(int): The 'on_epoch_end' function will be called in end_epoch. Default: 0
target_flops(int): The constraint of FLOPS.
target_latency(float): The constraint of latency.
retrain_epoch(int): The number of training epochs before evaluating structure generated by controller. Default: 1.
metric_name(str): The metric used to evaluate the model.
It should be one of keys in out_nodes of graph wrapper. Default: 'top1_acc'
......@@ -66,6 +68,7 @@ class LightNASStrategy(Strategy):
self.start_epoch = 0
self.end_epoch = end_epoch
self._max_flops = target_flops
self._max_latency = target_latency
self._metric_name = metric_name
self._controller = controller
self._retrain_epoch = 0
......@@ -86,8 +89,6 @@ class LightNASStrategy(Strategy):
def on_compression_begin(self, context):
self._current_tokens = context.search_space.init_tokens()
constrain_func = functools.partial(
self._constrain_func, context=context)
self._controller.reset(context.search_space.range_table(),
self._current_tokens, None)
......@@ -127,15 +128,6 @@ class LightNASStrategy(Strategy):
d[key] = self.__dict__[key]
return d
def _constrain_func(self, tokens, context=None):
"""Check whether the tokens meet constraint."""
_, _, test_prog, _, _, _, _ = context.search_space.create_net(tokens)
flops = GraphWrapper(test_prog).flops()
if flops <= self._max_flops:
return True
else:
return False
def on_epoch_begin(self, context):
if context.epoch_id >= self.start_epoch and context.epoch_id <= self.end_epoch and (
self._retrain_epoch == 0 or
......@@ -144,13 +136,20 @@ class LightNASStrategy(Strategy):
for _ in range(self._max_try_times):
startup_p, train_p, test_p, _, _, train_reader, test_reader = context.search_space.create_net(
self._current_tokens)
_logger.info("try [{}]".format(self._current_tokens))
context.eval_graph.program = test_p
flops = context.eval_graph.flops()
if flops <= self._max_flops:
break
if self._max_latency > 0:
latency = context.search_space.get_model_latency(test_p)
_logger.info("try [{}] with latency {} flops {}".format(
self._current_tokens, latency, flops))
else:
_logger.info("try [{}] with flops {}".format(
self._current_tokens, flops))
if flops > self._max_flops or (self._max_latency > 0 and
latency > self._max_latency):
self._current_tokens = self._search_agent.next_tokens()
else:
break
context.train_reader = train_reader
context.eval_reader = test_reader
......@@ -173,6 +172,16 @@ class LightNASStrategy(Strategy):
flops = context.eval_graph.flops()
if flops > self._max_flops:
self._current_reward = 0.0
if self._max_latency > 0:
test_p = context.search_space.create_net(self._current_tokens)[
2]
latency = context.search_space.get_model_latency(test_p)
if latency > self._max_latency:
self._current_reward = 0.0
_logger.info("reward: {}; latency: {}; flops: {}; tokens: {}".
format(self._current_reward, latency, flops,
self._current_tokens))
else:
_logger.info("reward: {}; flops: {}; tokens: {}".format(
self._current_reward, flops, self._current_tokens))
self._current_tokens = self._search_agent.update(
......
......@@ -41,3 +41,12 @@ class SearchSpace(object):
(tuple): startup_program, train_program, evaluation_program, train_metrics, test_metrics
"""
raise NotImplementedError('Abstract method.')
def get_model_latency(self, program):
"""Get model latency according to program.
Args:
program(Program): The program to get latency.
Return:
(float): model latency.
"""
raise NotImplementedError('Abstract method.')
......@@ -10,6 +10,7 @@ strategies:
class: 'LightNASStrategy'
controller: 'sa_controller'
target_flops: 629145600
target_latency: 1
end_epoch: 2
retrain_epoch: 1
metric_name: 'acc_top1'
......
......@@ -17,6 +17,7 @@ from light_nasnet import LightNASNet
import paddle.fluid as fluid
import paddle
import json
import random
total_images = 1281167
lr = 0.1
......@@ -85,6 +86,16 @@ class LightNASSpace(SearchSpace):
2, 4, 3, 3, 2, 2, 2
]
def get_model_latency(self, program):
"""Get model latency according to program.
Returns a random number since it's only for testing.
Args:
program(Program): The program to get latency.
Return:
(float): model latency.
"""
return random.randint(1, 2)
def create_net(self, tokens=None):
"""Create a network for training by tokens.
"""
......
......@@ -11,24 +11,96 @@
# without warranties or conditions of any kind, either express or implied.
# see the license for the specific language governing permissions and
# limitations under the license.
import paddle
"""
Test LightNAS.
"""
import sys
import unittest
import paddle.fluid as fluid
from mobilenet import MobileNet
from paddle.fluid.contrib.slim.core import Compressor
from paddle.fluid.contrib.slim.graph import GraphWrapper
import sys
sys.path.append("./light_nas")
from light_nas_space import LightNASSpace
class TestLightNAS(unittest.TestCase):
"""
Test LightNAS.
"""
def test_compression(self):
"""
Test LightNAS.
"""
# Update compress.yaml
lines = list()
fid = open('./light_nas/compress.yaml')
for line in fid:
if 'target_latency' in line:
lines.append(' target_latency: 0\n')
else:
lines.append(line)
fid.close()
fid = open('./light_nas/compress.yaml', 'w')
for line in lines:
fid.write(line)
fid.close()
# Begin test
if not fluid.core.is_compiled_with_cuda():
return
space = LightNASSpace()
startup_prog, train_prog, test_prog, train_metrics, test_metrics, train_reader, test_reader = space.create_net(
)
train_cost, train_acc1, train_acc5, global_lr = train_metrics
test_cost, test_acc1, test_acc5 = test_metrics
place = fluid.CUDAPlace(0)
exe = fluid.Executor(place)
exe.run(startup_prog)
val_fetch_list = [('acc_top1', test_acc1.name),
('acc_top5', test_acc5.name)]
train_fetch_list = [('loss', train_cost.name)]
com_pass = Compressor(
place,
fluid.global_scope(),
train_prog,
train_reader=train_reader,
train_feed_list=None,
train_fetch_list=train_fetch_list,
eval_program=test_prog,
eval_reader=test_reader,
eval_feed_list=None,
eval_fetch_list=val_fetch_list,
train_optimizer=None,
search_space=space)
com_pass.config('./light_nas/compress.yaml')
eval_graph = com_pass.run()
def test_compression_with_target_latency(self):
"""
Test LightNAS with target_latency.
"""
# Update compress.yaml
lines = list()
fid = open('./light_nas/compress.yaml')
for line in fid:
if 'target_latency' in line:
lines.append(' target_latency: 1\n')
else:
lines.append(line)
fid.close()
fid = open('./light_nas/compress.yaml', 'w')
for line in lines:
fid.write(line)
fid.close()
# Begin test
if not fluid.core.is_compiled_with_cuda():
return
class_dim = 10
image_shape = [1, 28, 28]
space = LightNASSpace()
......@@ -41,8 +113,8 @@ class TestLightNAS(unittest.TestCase):
exe = fluid.Executor(place)
exe.run(startup_prog)
val_fetch_list = [('acc_top1', test_acc1.name), ('acc_top5',
test_acc5.name)]
val_fetch_list = [('acc_top1', test_acc1.name),
('acc_top5', test_acc5.name)]
train_fetch_list = [('loss', train_cost.name)]
com_pass = Compressor(
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册