diff --git a/python/paddle/fluid/contrib/slim/nas/light_nas_strategy.py b/python/paddle/fluid/contrib/slim/nas/light_nas_strategy.py index 2723ed5f16f90505eea505eb451c7968cb406a4a..2ce1a3d06007e5ee500474111cba3d9447a53324 100644 --- a/python/paddle/fluid/contrib/slim/nas/light_nas_strategy.py +++ b/python/paddle/fluid/contrib/slim/nas/light_nas_strategy.py @@ -133,11 +133,18 @@ class LightNASStrategy(Strategy): self._retrain_epoch == 0 or (context.epoch_id - self.start_epoch) % self._retrain_epoch == 0): _logger.info("light nas strategy on_epoch_begin") + min_flops = -1 for _ in range(self._max_try_times): startup_p, train_p, test_p, _, _, train_reader, test_reader = context.search_space.create_net( self._current_tokens) context.eval_graph.program = test_p flops = context.eval_graph.flops() + if min_flops == -1: + min_flops = flops + min_tokens = self._current_tokens[:] + else: + if flops < min_flops: + min_tokens = self._current_tokens[:] if self._max_latency > 0: latency = context.search_space.get_model_latency(test_p) _logger.info("try [{}] with latency {} flops {}".format( @@ -147,7 +154,8 @@ class LightNASStrategy(Strategy): self._current_tokens, flops)) if flops > self._max_flops or (self._max_latency > 0 and latency > self._max_latency): - self._current_tokens = self._search_agent.next_tokens() + self._current_tokens = self._controller.next_tokens( + min_tokens) else: break diff --git a/python/paddle/fluid/contrib/slim/searcher/controller.py b/python/paddle/fluid/contrib/slim/searcher/controller.py index 7072dc73746d1172a9626c60ff50adfe8c9e51b9..c4a2555b6d1351c3e8bfeaacda67160815919cc3 100644 --- a/python/paddle/fluid/contrib/slim/searcher/controller.py +++ b/python/paddle/fluid/contrib/slim/searcher/controller.py @@ -123,11 +123,14 @@ class SAController(EvolutionaryController): _logger.info("current_reward: {}; current tokens: {}".format( self._reward, self._tokens)) - def next_tokens(self): + def next_tokens(self, control_token=None): """ Get next tokens. """ - tokens = self._tokens + if control_token: + tokens = control_token[:] + else: + tokens = self._tokens new_tokens = tokens[:] index = int(len(self._range_table) * np.random.random()) new_tokens[index] = (