提交 b6a33ceb 编写于 作者: W wanghaoshuang

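Fix iter: pass the caller's iteration count through ControllerClient.update and ControllerServer to SAController.update instead of incrementing it inside the controller.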

上级 c5d3eeb4
......@@ -38,7 +38,7 @@ class ControllerClient(object):
self.socket_client = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
self._key = key
def update(self, tokens, reward):
def update(self, tokens, reward, iter):
"""
Update the controller according to latest tokens and reward.
Args:
......@@ -48,8 +48,8 @@ class ControllerClient(object):
socket_client = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
socket_client.connect((self.server_ip, self.server_port))
tokens = ",".join([str(token) for token in tokens])
socket_client.send("{}\t{}\t{}".format(self._key, tokens, reward)
.encode())
socket_client.send("{}\t{}\t{}\t{}".format(self._key, tokens, reward,
iter).encode())
response = socket_client.recv(1024).decode()
if response.strip('\n').split("\t") == "ok":
return True
......
......@@ -93,14 +93,15 @@ class ControllerServer(object):
_logger.debug("recv message from {}: [{}]".format(addr,
message))
messages = message.strip('\n').split("\t")
if (len(messages) < 3) or (messages[0] != self._key):
if (len(messages) < 4) or (messages[0] != self._key):
_logger.debug("recv noise from {}: [{}]".format(
addr, message))
continue
tokens = messages[1]
reward = messages[2]
iter = messages[3]
tokens = [int(token) for token in tokens.split(",")]
self._controller.update(tokens, float(reward))
self._controller.update(tokens, float(reward), iter)
response = "ok"
conn.send(response.encode())
_logger.debug("send message to {}: [{}]".format(addr,
......
......@@ -65,15 +65,16 @@ class SAController(EvolutionaryController):
d[key] = self.__dict__[key]
return d
def update(self, tokens, reward):
def update(self, tokens, reward, iter):
"""
Update the controller according to latest tokens and reward.
Args:
tokens(list<int>): The tokens generated in last step.
reward(float): The reward of tokens.
"""
self._iter += 1
temperature = self._init_temperature * self._reduce_rate**self._iter
if iter > self._iter:
self._iter = iter
temperature = self._init_temperature * self._reduce_rate**self._iter
if (reward > self._reward) or (np.random.random() <= math.exp(
(reward - self._reward) / temperature)):
self._reward = reward
......
......@@ -112,4 +112,5 @@ class SANAS(object):
bool: True means updating successfully while false means failure.
"""
self._iter += 1
return self._controller_client.update(self._current_tokens, score)
return self._controller_client.update(self._current_tokens, score,
self._iter)
......@@ -212,7 +212,7 @@ class AutoPruner(object):
self._restore(self._scope)
self._param_backup = {}
tokens = self._ratios2tokens(self._current_ratios)
self._controller_client.update(tokens, score)
self._controller_client.update(tokens, score, self._iter)
self._iter += 1
def _restore(self, scope):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册