提交 0ae8e939 编写于 作者: W wanghaoshuang

Fix reward function of sa nas to make it not return next token.

上级 72c800e9
...@@ -50,9 +50,11 @@ class ControllerClient(object): ...@@ -50,9 +50,11 @@ class ControllerClient(object):
tokens = ",".join([str(token) for token in tokens]) tokens = ",".join([str(token) for token in tokens])
socket_client.send("{}\t{}\t{}".format(self._key, tokens, reward) socket_client.send("{}\t{}\t{}".format(self._key, tokens, reward)
.encode()) .encode())
tokens = socket_client.recv(1024).decode() response = socket_client.recv(1024).decode()
tokens = [int(token) for token in tokens.strip("\n").split(",")] if response.trip('\n').split("\t") == "ok":
return tokens return True
else:
return False
def next_tokens(self): def next_tokens(self):
""" """
......
...@@ -117,9 +117,10 @@ class ControllerServer(object): ...@@ -117,9 +117,10 @@ class ControllerServer(object):
reward = messages[2] reward = messages[2]
tokens = [int(token) for token in tokens.split(",")] tokens = [int(token) for token in tokens.split(",")]
self._controller.update(tokens, float(reward)) self._controller.update(tokens, float(reward))
tokens = self._controller.next_tokens() #tokens = self._controller.next_tokens()
tokens = ",".join([str(token) for token in tokens]) #tokens = ",".join([str(token) for token in tokens])
conn.send(tokens.encode()) response = "ok"
conn.send(response.encode())
_logger.debug("send message to {}: [{}]".format(addr, _logger.debug("send message to {}: [{}]".format(addr,
tokens)) tokens))
conn.close() conn.close()
......
...@@ -140,6 +140,8 @@ class SANAS(object): ...@@ -140,6 +140,8 @@ class SANAS(object):
Return reward of current searched network. Return reward of current searched network.
Args: Args:
score(float): The score of current searched network. score(float): The score of current searched network.
Returns:
bool: True means updating successfully while false means failure.
""" """
self._controller_client.update(self._current_tokens, score)
self._iter += 1 self._iter += 1
return self._controller_client.update(self._current_tokens, score)
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册