提交 2b84f18e 编写于 作者: X xiteng1988 提交者: whs

fix next_token of controller test=release/1.6 (#20128)

上级 411f7b42
......@@ -133,11 +133,18 @@ class LightNASStrategy(Strategy):
self._retrain_epoch == 0 or
(context.epoch_id - self.start_epoch) % self._retrain_epoch == 0):
_logger.info("light nas strategy on_epoch_begin")
min_flops = -1
for _ in range(self._max_try_times):
startup_p, train_p, test_p, _, _, train_reader, test_reader = context.search_space.create_net(
self._current_tokens)
context.eval_graph.program = test_p
flops = context.eval_graph.flops()
if min_flops == -1:
min_flops = flops
min_tokens = self._current_tokens[:]
else:
if flops < min_flops:
min_tokens = self._current_tokens[:]
if self._max_latency > 0:
latency = context.search_space.get_model_latency(test_p)
_logger.info("try [{}] with latency {} flops {}".format(
......@@ -147,7 +154,8 @@ class LightNASStrategy(Strategy):
self._current_tokens, flops))
if flops > self._max_flops or (self._max_latency > 0 and
latency > self._max_latency):
self._current_tokens = self._search_agent.next_tokens()
self._current_tokens = self._controller.next_tokens(
min_tokens)
else:
break
......
......@@ -123,10 +123,13 @@ class SAController(EvolutionaryController):
_logger.info("current_reward: {}; current tokens: {}".format(
self._reward, self._tokens))
def next_tokens(self):
def next_tokens(self, control_token=None):
"""
Get next tokens.
"""
if control_token:
tokens = control_token[:]
else:
tokens = self._tokens
new_tokens = tokens[:]
index = int(len(self._range_table) * np.random.random())
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册