From 0ae8e939ed3f2456f55f1df665d5a295cda13f2b Mon Sep 17 00:00:00 2001 From: wanghaoshuang Date: Tue, 19 Nov 2019 12:05:22 +0800 Subject: [PATCH] Fix reward function of sa nas to make it not return next token. --- paddleslim/common/controller_client.py | 8 +++++--- paddleslim/common/controller_server.py | 7 ++++--- paddleslim/nas/sa_nas.py | 4 +++- 3 files changed, 12 insertions(+), 7 deletions(-) diff --git a/paddleslim/common/controller_client.py b/paddleslim/common/controller_client.py index 5dcbd7bb..fd6575e0 100644 --- a/paddleslim/common/controller_client.py +++ b/paddleslim/common/controller_client.py @@ -50,9 +50,11 @@ class ControllerClient(object): tokens = ",".join([str(token) for token in tokens]) socket_client.send("{}\t{}\t{}".format(self._key, tokens, reward) .encode()) - tokens = socket_client.recv(1024).decode() - tokens = [int(token) for token in tokens.strip("\n").split(",")] - return tokens + response = socket_client.recv(1024).decode() + if response.trip('\n').split("\t") == "ok": + return True + else: + return False def next_tokens(self): """ diff --git a/paddleslim/common/controller_server.py b/paddleslim/common/controller_server.py index 74b954db..ac24df86 100644 --- a/paddleslim/common/controller_server.py +++ b/paddleslim/common/controller_server.py @@ -117,9 +117,10 @@ class ControllerServer(object): reward = messages[2] tokens = [int(token) for token in tokens.split(",")] self._controller.update(tokens, float(reward)) - tokens = self._controller.next_tokens() - tokens = ",".join([str(token) for token in tokens]) - conn.send(tokens.encode()) + #tokens = self._controller.next_tokens() + #tokens = ",".join([str(token) for token in tokens]) + response = "ok" + conn.send(response.encode()) _logger.debug("send message to {}: [{}]".format(addr, tokens)) conn.close() diff --git a/paddleslim/nas/sa_nas.py b/paddleslim/nas/sa_nas.py index cfc747b0..bbee0d8d 100644 --- a/paddleslim/nas/sa_nas.py +++ b/paddleslim/nas/sa_nas.py @@ -140,6 +140,8 @@ class SANAS(object): Return reward of current searched network. Args: score(float): The score of current searched network. + Returns: + bool: True means updating successfully while false means failure. """ - self._controller_client.update(self._current_tokens, score) self._iter += 1 + return self._controller_client.update(self._current_tokens, score) -- GitLab