提交 0ae8e939 编写于 作者: W wanghaoshuang

Fix reward function of sa nas to make it not return next token.

上级 72c800e9
......@@ -50,9 +50,11 @@ class ControllerClient(object):
tokens = ",".join([str(token) for token in tokens])
socket_client.send("{}\t{}\t{}".format(self._key, tokens, reward)
.encode())
tokens = socket_client.recv(1024).decode()
tokens = [int(token) for token in tokens.strip("\n").split(",")]
return tokens
response = socket_client.recv(1024).decode()
if response.trip('\n').split("\t") == "ok":
return True
else:
return False
def next_tokens(self):
"""
......
......@@ -117,9 +117,10 @@ class ControllerServer(object):
reward = messages[2]
tokens = [int(token) for token in tokens.split(",")]
self._controller.update(tokens, float(reward))
tokens = self._controller.next_tokens()
tokens = ",".join([str(token) for token in tokens])
conn.send(tokens.encode())
#tokens = self._controller.next_tokens()
#tokens = ",".join([str(token) for token in tokens])
response = "ok"
conn.send(response.encode())
_logger.debug("send message to {}: [{}]".format(addr,
tokens))
conn.close()
......
......@@ -140,6 +140,8 @@ class SANAS(object):
Return reward of current searched network.
Args:
score(float): The score of current searched network.
Returns:
bool: True means updating successfully while false means failure.
"""
self._controller_client.update(self._current_tokens, score)
self._iter += 1
return self._controller_client.update(self._current_tokens, score)
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册