From 3785f08f68c7875434f25c0bcd8805e54c66b0c1 Mon Sep 17 00:00:00 2001 From: xiteng1988 Date: Mon, 30 Sep 2019 11:16:09 +0800 Subject: [PATCH] fix next_tokens of controller (#20060) * fix next_tokens of controller --- .../fluid/contrib/slim/nas/light_nas_strategy.py | 10 +++++++++- .../paddle/fluid/contrib/slim/searcher/controller.py | 7 +++++-- 2 files changed, 14 insertions(+), 3 deletions(-) diff --git a/python/paddle/fluid/contrib/slim/nas/light_nas_strategy.py b/python/paddle/fluid/contrib/slim/nas/light_nas_strategy.py index 2723ed5f16f..2ce1a3d0600 100644 --- a/python/paddle/fluid/contrib/slim/nas/light_nas_strategy.py +++ b/python/paddle/fluid/contrib/slim/nas/light_nas_strategy.py @@ -133,11 +133,18 @@ class LightNASStrategy(Strategy): self._retrain_epoch == 0 or (context.epoch_id - self.start_epoch) % self._retrain_epoch == 0): _logger.info("light nas strategy on_epoch_begin") + min_flops = -1 for _ in range(self._max_try_times): startup_p, train_p, test_p, _, _, train_reader, test_reader = context.search_space.create_net( self._current_tokens) context.eval_graph.program = test_p flops = context.eval_graph.flops() + if min_flops == -1: + min_flops = flops + min_tokens = self._current_tokens[:] + else: + if flops < min_flops: + min_tokens = self._current_tokens[:] if self._max_latency > 0: latency = context.search_space.get_model_latency(test_p) _logger.info("try [{}] with latency {} flops {}".format( @@ -147,7 +154,8 @@ class LightNASStrategy(Strategy): self._current_tokens, flops)) if flops > self._max_flops or (self._max_latency > 0 and latency > self._max_latency): - self._current_tokens = self._search_agent.next_tokens() + self._current_tokens = self._controller.next_tokens( + min_tokens) else: break diff --git a/python/paddle/fluid/contrib/slim/searcher/controller.py b/python/paddle/fluid/contrib/slim/searcher/controller.py index 7072dc73746..c4a2555b6d1 100644 --- a/python/paddle/fluid/contrib/slim/searcher/controller.py +++ b/python/paddle/fluid/contrib/slim/searcher/controller.py @@ -123,11 +123,14 @@ class SAController(EvolutionaryController): _logger.info("current_reward: {}; current tokens: {}".format( self._reward, self._tokens)) - def next_tokens(self): + def next_tokens(self, control_token=None): """ Get next tokens. """ - tokens = self._tokens + if control_token: + tokens = control_token[:] + else: + tokens = self._tokens new_tokens = tokens[:] index = int(len(self._range_table) * np.random.random()) new_tokens[index] = ( -- GitLab