diff --git a/paddleslim/common/controller_client.py b/paddleslim/common/controller_client.py index 5dcbd7bb64bf4460371d523a0f745e2490a7b3a0..fd6575e04b16d6491ad02dc7a6e1725c68505460 100644 --- a/paddleslim/common/controller_client.py +++ b/paddleslim/common/controller_client.py @@ -50,9 +50,11 @@ class ControllerClient(object): tokens = ",".join([str(token) for token in tokens]) socket_client.send("{}\t{}\t{}".format(self._key, tokens, reward) .encode()) - tokens = socket_client.recv(1024).decode() - tokens = [int(token) for token in tokens.strip("\n").split(",")] - return tokens + response = socket_client.recv(1024).decode() + if response.trip('\n').split("\t") == "ok": + return True + else: + return False def next_tokens(self): """ diff --git a/paddleslim/common/controller_server.py b/paddleslim/common/controller_server.py index 74b954db3bb1c4520551e82b5e8ba3b9514c549c..ac24df86030aae8cb286452b6bd6eeb7b5c80741 100644 --- a/paddleslim/common/controller_server.py +++ b/paddleslim/common/controller_server.py @@ -117,9 +117,10 @@ class ControllerServer(object): reward = messages[2] tokens = [int(token) for token in tokens.split(",")] self._controller.update(tokens, float(reward)) - tokens = self._controller.next_tokens() - tokens = ",".join([str(token) for token in tokens]) - conn.send(tokens.encode()) + #tokens = self._controller.next_tokens() + #tokens = ",".join([str(token) for token in tokens]) + response = "ok" + conn.send(response.encode()) _logger.debug("send message to {}: [{}]".format(addr, tokens)) conn.close() diff --git a/paddleslim/nas/sa_nas.py b/paddleslim/nas/sa_nas.py index cfc747b0ec9f977dc4e41d2fb128b29823cfd3a3..bbee0d8db641c5b61d520e5a8043721893e86ef5 100644 --- a/paddleslim/nas/sa_nas.py +++ b/paddleslim/nas/sa_nas.py @@ -140,6 +140,8 @@ class SANAS(object): Return reward of current searched network. Args: score(float): The score of current searched network. + Returns: + bool: True means updating successfully while false means failure. """ - self._controller_client.update(self._current_tokens, score) self._iter += 1 + return self._controller_client.update(self._current_tokens, score)