提交 3785f08f 编写于 作者: X xiteng1988 提交者: whs

fix next_tokens of controller (#20060)

* fix next_tokens of controller
上级 5f2290ab
...@@ -133,11 +133,18 @@ class LightNASStrategy(Strategy): ...@@ -133,11 +133,18 @@ class LightNASStrategy(Strategy):
self._retrain_epoch == 0 or self._retrain_epoch == 0 or
(context.epoch_id - self.start_epoch) % self._retrain_epoch == 0): (context.epoch_id - self.start_epoch) % self._retrain_epoch == 0):
_logger.info("light nas strategy on_epoch_begin") _logger.info("light nas strategy on_epoch_begin")
min_flops = -1
for _ in range(self._max_try_times): for _ in range(self._max_try_times):
startup_p, train_p, test_p, _, _, train_reader, test_reader = context.search_space.create_net( startup_p, train_p, test_p, _, _, train_reader, test_reader = context.search_space.create_net(
self._current_tokens) self._current_tokens)
context.eval_graph.program = test_p context.eval_graph.program = test_p
flops = context.eval_graph.flops() flops = context.eval_graph.flops()
if min_flops == -1:
min_flops = flops
min_tokens = self._current_tokens[:]
else:
if flops < min_flops:
min_tokens = self._current_tokens[:]
if self._max_latency > 0: if self._max_latency > 0:
latency = context.search_space.get_model_latency(test_p) latency = context.search_space.get_model_latency(test_p)
_logger.info("try [{}] with latency {} flops {}".format( _logger.info("try [{}] with latency {} flops {}".format(
...@@ -147,7 +154,8 @@ class LightNASStrategy(Strategy): ...@@ -147,7 +154,8 @@ class LightNASStrategy(Strategy):
self._current_tokens, flops)) self._current_tokens, flops))
if flops > self._max_flops or (self._max_latency > 0 and if flops > self._max_flops or (self._max_latency > 0 and
latency > self._max_latency): latency > self._max_latency):
self._current_tokens = self._search_agent.next_tokens() self._current_tokens = self._controller.next_tokens(
min_tokens)
else: else:
break break
......
...@@ -123,10 +123,13 @@ class SAController(EvolutionaryController): ...@@ -123,10 +123,13 @@ class SAController(EvolutionaryController):
_logger.info("current_reward: {}; current tokens: {}".format( _logger.info("current_reward: {}; current tokens: {}".format(
self._reward, self._tokens)) self._reward, self._tokens))
def next_tokens(self): def next_tokens(self, control_token=None):
""" """
Get next tokens. Get next tokens.
""" """
if control_token:
tokens = control_token[:]
else:
tokens = self._tokens tokens = self._tokens
new_tokens = tokens[:] new_tokens = tokens[:]
index = int(len(self._range_table) * np.random.random()) index = int(len(self._range_table) * np.random.random())
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册