diff --git a/paddleslim/common/controller_client.py b/paddleslim/common/controller_client.py index ad989dd16014fa8e6fa1495516e81048324fb826..8a8ebbde3d738438d3cca484ca9c824d853837b2 100644 --- a/paddleslim/common/controller_client.py +++ b/paddleslim/common/controller_client.py @@ -38,7 +38,7 @@ class ControllerClient(object): self.socket_client = socket.socket(socket.AF_INET, socket.SOCK_STREAM) self._key = key - def update(self, tokens, reward): + def update(self, tokens, reward, iter): """ Update the controller according to latest tokens and reward. Args: @@ -48,8 +48,8 @@ class ControllerClient(object): socket_client = socket.socket(socket.AF_INET, socket.SOCK_STREAM) socket_client.connect((self.server_ip, self.server_port)) tokens = ",".join([str(token) for token in tokens]) - socket_client.send("{}\t{}\t{}".format(self._key, tokens, reward) - .encode()) + socket_client.send("{}\t{}\t{}\t{}".format(self._key, tokens, reward, + iter).encode()) response = socket_client.recv(1024).decode() if response.strip('\n').split("\t") == "ok": return True diff --git a/paddleslim/common/controller_server.py b/paddleslim/common/controller_server.py index 212eb58e2e63c793539a56974208fd0fe47c3501..71f6500d12591622c4a42d81d04fb7c4d5124bb1 100644 --- a/paddleslim/common/controller_server.py +++ b/paddleslim/common/controller_server.py @@ -93,14 +93,15 @@ class ControllerServer(object): _logger.debug("recv message from {}: [{}]".format(addr, message)) messages = message.strip('\n').split("\t") - if (len(messages) < 3) or (messages[0] != self._key): + if (len(messages) < 4) or (messages[0] != self._key): _logger.debug("recv noise from {}: [{}]".format( addr, message)) continue tokens = messages[1] reward = messages[2] + iter = messages[3] tokens = [int(token) for token in tokens.split(",")] - self._controller.update(tokens, float(reward)) + self._controller.update(tokens, float(reward), iter) response = "ok" conn.send(response.encode()) _logger.debug("send message to {}: [{}]".format(addr, diff --git a/paddleslim/common/sa_controller.py b/paddleslim/common/sa_controller.py index b619b818a3208d740c1ddb6753cf5931f3d058f5..9c4b884d01ec8126a0d0a944ba7d2ca62f7db429 100644 --- a/paddleslim/common/sa_controller.py +++ b/paddleslim/common/sa_controller.py @@ -65,15 +65,16 @@ class SAController(EvolutionaryController): d[key] = self.__dict__[key] return d - def update(self, tokens, reward): + def update(self, tokens, reward, iter): """ Update the controller according to latest tokens and reward. Args: tokens(list): The tokens generated in last step. reward(float): The reward of tokens. """ - self._iter += 1 - temperature = self._init_temperature * self._reduce_rate**self._iter + if iter > self._iter: + self._iter = iter + temperature = self._init_temperature * self._reduce_rate**self._iter if (reward > self._reward) or (np.random.random() <= math.exp( (reward - self._reward) / temperature)): self._reward = reward diff --git a/paddleslim/nas/sa_nas.py b/paddleslim/nas/sa_nas.py index 593ab40172acf129c03810489e7a516e5e8f6340..83659f94ef1a05b96d723cd305b8a279815ce68d 100644 --- a/paddleslim/nas/sa_nas.py +++ b/paddleslim/nas/sa_nas.py @@ -112,4 +112,5 @@ class SANAS(object): bool: True means updating successfully while false means failure. """ self._iter += 1 - return self._controller_client.update(self._current_tokens, score) + return self._controller_client.update(self._current_tokens, score, + self._iter) diff --git a/paddleslim/prune/auto_pruner.py b/paddleslim/prune/auto_pruner.py index fba8c11170f3fbf2eddbe15942dc642ad448658b..d09e726dc046003ed103124eeeb911f9aed67572 100644 --- a/paddleslim/prune/auto_pruner.py +++ b/paddleslim/prune/auto_pruner.py @@ -212,7 +212,7 @@ class AutoPruner(object): self._restore(self._scope) self._param_backup = {} tokens = self._ratios2tokens(self._current_ratios) - self._controller_client.update(tokens, score) + self._controller_client.update(tokens, score, self._iter) self._iter += 1 def _restore(self, scope):