提交 c338b6c2 编写于 作者: Q qjing666

fix zmq bug in python3

上级 c3bee06d
...@@ -4,7 +4,7 @@ import random ...@@ -4,7 +4,7 @@ import random
def recv_and_parse_kv(socket): def recv_and_parse_kv(socket):
message = socket.recv() message = socket.recv()
group = message.split("\t") group = message.decode().split("\t")
if group[0] == "alive": if group[0] == "alive":
return group[0], "0" return group[0], "0"
else: else:
...@@ -23,9 +23,9 @@ class FLServerAgent(object): ...@@ -23,9 +23,9 @@ class FLServerAgent(object):
def connect_scheduler(self): def connect_scheduler(self):
while True: while True:
self.socket.send("SERVER_EP\t{}".format(self.current_ep)) self.socket.send_string("SERVER_EP\t{}".format(self.current_ep))
message = self.socket.recv() message = self.socket.recv()
group = message.split("\t") group = message.decode().split("\t")
if group[0] == 'INIT': if group[0] == 'INIT':
break break
...@@ -39,14 +39,14 @@ class FLWorkerAgent(object): ...@@ -39,14 +39,14 @@ class FLWorkerAgent(object):
def connect_scheduler(self): def connect_scheduler(self):
while True: while True:
self.socket.send("WORKER_EP\t{}".format(self.current_ep)) self.socket.send_string("WORKER_EP\t{}".format(self.current_ep))
message = self.socket.recv() message = self.socket.recv()
group = message.split("\t") group = message.decode().split("\t")
if group[0] == 'INIT': if group[0] == 'INIT':
break break
def finish_training(self): def finish_training(self):
self.socket.send("FINISH\t{}".format(self.current_ep)) self.socket.send_string("FINISH\t{}".format(self.current_ep))
key, value = recv_and_parse_kv(self.socket) key, value = recv_and_parse_kv(self.socket)
if key == "WAIT": if key == "WAIT":
time.sleep(3) time.sleep(3)
...@@ -54,7 +54,7 @@ class FLWorkerAgent(object): ...@@ -54,7 +54,7 @@ class FLWorkerAgent(object):
return False return False
def can_join_training(self): def can_join_training(self):
self.socket.send("JOIN\t{}".format(self.current_ep)) self.socket.send_string("JOIN\t{}".format(self.current_ep))
key, value = recv_and_parse_kv(self.socket) key, value = recv_and_parse_kv(self.socket)
if key == "ACCEPT": if key == "ACCEPT":
...@@ -91,13 +91,13 @@ class FLScheduler(object): ...@@ -91,13 +91,13 @@ class FLScheduler(object):
key, value = recv_and_parse_kv(self.socket) key, value = recv_and_parse_kv(self.socket)
if key == WORKER_EP: if key == WORKER_EP:
self.fl_workers.append(value) self.fl_workers.append(value)
self.socket.send("INIT\t{}".format(value)) self.socket.send_string("INIT\t{}".format(value))
elif key == SERVER_EP: elif key == SERVER_EP:
self.fl_servers.append(value) self.fl_servers.append(value)
self.socket.send("INIT\t{}".format(value)) self.socket.send_string("INIT\t{}".format(value))
else: else:
time.sleep(3) time.sleep(3)
self.socket.send("REJECT\t0") self.socket.send_string("REJECT\t0")
if len(self.fl_workers) == self.worker_num and \ if len(self.fl_workers) == self.worker_num and \
len(self.fl_servers) == self.server_num: len(self.fl_servers) == self.server_num:
ready = True ready = True
...@@ -122,12 +122,12 @@ class FLScheduler(object): ...@@ -122,12 +122,12 @@ class FLScheduler(object):
if worker_dict[value] == 0: if worker_dict[value] == 0:
ready_workers.append(value) ready_workers.append(value)
worker_dict[value] = 1 worker_dict[value] = 1
self.socket.send("ACCEPT\t0") self.socket.send_string("ACCEPT\t0")
continue continue
else: else:
if value not in ready_workers: if value not in ready_workers:
ready_workers.append(value) ready_workers.append(value)
self.socket.send("REJECT\t0") self.socket.send_string("REJECT\t0")
if len(ready_workers) == len(self.fl_workers): if len(ready_workers) == len(self.fl_workers):
all_ready_to_train = True all_ready_to_train = True
...@@ -137,9 +137,9 @@ class FLScheduler(object): ...@@ -137,9 +137,9 @@ class FLScheduler(object):
key, value = recv_and_parse_kv(self.socket) key, value = recv_and_parse_kv(self.socket)
if key == "FINISH": if key == "FINISH":
finish_training_dict[value] = 1 finish_training_dict[value] = 1
self.socket.send("WAIT\t0") self.socket.send_string("WAIT\t0")
else: else:
self.socket.send("REJECT\t0") self.socket.send_string("REJECT\t0")
if len(finish_training_dict) == len(worker_dict): if len(finish_training_dict) == len(worker_dict):
all_finish_training = True all_finish_training = True
time.sleep(5) time.sleep(5)
......
...@@ -12,5 +12,5 @@ ...@@ -12,5 +12,5 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
""" PaddleFL version string """ """ PaddleFL version string """
fl_version = "0.1.3" fl_version = "0.1.4"
module_proto_version = "0.1.3" module_proto_version = "0.1.4"
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册