提交 b6a33ceb 作者: wanghaoshuang

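Fix iter: send the client's iteration count with each controller update, so the SA controller syncs its `_iter` to the reported value (only when it is larger) instead of incrementing on every call.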

上级 c5d3eeb4
...@@ -38,7 +38,7 @@ class ControllerClient(object): ...@@ -38,7 +38,7 @@ class ControllerClient(object):
self.socket_client = socket.socket(socket.AF_INET, socket.SOCK_STREAM) self.socket_client = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
self._key = key self._key = key
def update(self, tokens, reward): def update(self, tokens, reward, iter):
""" """
Update the controller according to latest tokens and reward. Update the controller according to latest tokens and reward.
Args: Args:
...@@ -48,8 +48,8 @@ class ControllerClient(object): ...@@ -48,8 +48,8 @@ class ControllerClient(object):
socket_client = socket.socket(socket.AF_INET, socket.SOCK_STREAM) socket_client = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
socket_client.connect((self.server_ip, self.server_port)) socket_client.connect((self.server_ip, self.server_port))
tokens = ",".join([str(token) for token in tokens]) tokens = ",".join([str(token) for token in tokens])
socket_client.send("{}\t{}\t{}".format(self._key, tokens, reward) socket_client.send("{}\t{}\t{}\t{}".format(self._key, tokens, reward,
.encode()) iter).encode())
response = socket_client.recv(1024).decode() response = socket_client.recv(1024).decode()
if response.strip('\n').split("\t") == "ok": if response.strip('\n').split("\t") == "ok":
return True return True
......
...@@ -93,14 +93,15 @@ class ControllerServer(object): ...@@ -93,14 +93,15 @@ class ControllerServer(object):
_logger.debug("recv message from {}: [{}]".format(addr, _logger.debug("recv message from {}: [{}]".format(addr,
message)) message))
messages = message.strip('\n').split("\t") messages = message.strip('\n').split("\t")
if (len(messages) < 3) or (messages[0] != self._key): if (len(messages) < 4) or (messages[0] != self._key):
_logger.debug("recv noise from {}: [{}]".format( _logger.debug("recv noise from {}: [{}]".format(
addr, message)) addr, message))
continue continue
tokens = messages[1] tokens = messages[1]
reward = messages[2] reward = messages[2]
iter = messages[3]
tokens = [int(token) for token in tokens.split(",")] tokens = [int(token) for token in tokens.split(",")]
self._controller.update(tokens, float(reward)) self._controller.update(tokens, float(reward), iter)
response = "ok" response = "ok"
conn.send(response.encode()) conn.send(response.encode())
_logger.debug("send message to {}: [{}]".format(addr, _logger.debug("send message to {}: [{}]".format(addr,
......
...@@ -65,15 +65,16 @@ class SAController(EvolutionaryController): ...@@ -65,15 +65,16 @@ class SAController(EvolutionaryController):
d[key] = self.__dict__[key] d[key] = self.__dict__[key]
return d return d
def update(self, tokens, reward): def update(self, tokens, reward, iter):
""" """
Update the controller according to latest tokens and reward. Update the controller according to latest tokens and reward.
Args: Args:
tokens(list<int>): The tokens generated in last step. tokens(list<int>): The tokens generated in last step.
reward(float): The reward of tokens. reward(float): The reward of tokens.
""" """
self._iter += 1 if iter > self._iter:
temperature = self._init_temperature * self._reduce_rate**self._iter self._iter = iter
temperature = self._init_temperature * self._reduce_rate**self._iter
if (reward > self._reward) or (np.random.random() <= math.exp( if (reward > self._reward) or (np.random.random() <= math.exp(
(reward - self._reward) / temperature)): (reward - self._reward) / temperature)):
self._reward = reward self._reward = reward
......
...@@ -112,4 +112,5 @@ class SANAS(object): ...@@ -112,4 +112,5 @@ class SANAS(object):
bool: True means updating successfully while false means failure. bool: True means updating successfully while false means failure.
""" """
self._iter += 1 self._iter += 1
return self._controller_client.update(self._current_tokens, score) return self._controller_client.update(self._current_tokens, score,
self._iter)
...@@ -212,7 +212,7 @@ class AutoPruner(object): ...@@ -212,7 +212,7 @@ class AutoPruner(object):
self._restore(self._scope) self._restore(self._scope)
self._param_backup = {} self._param_backup = {}
tokens = self._ratios2tokens(self._current_ratios) tokens = self._ratios2tokens(self._current_ratios)
self._controller_client.update(tokens, score) self._controller_client.update(tokens, score, self._iter)
self._iter += 1 self._iter += 1
def _restore(self, scope): def _restore(self, scope):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册