提交 1c57d554 编写于 作者: M MrChengmo

ps_graph support ps-gpu

上级 4efcb9df
......@@ -259,7 +259,7 @@ class DistributedStrategy(object):
def a_sync(self, flag):
if isinstance(flag, bool):
self.strategy.a_sync = flag
self.a_sync_configs = {"k_steps": 0}
self.a_sync_configs = {"k_steps": 0, "worker_device": 'cpu'}
else:
raise ValueError(
"The type of `flag` is invalid, expected type is bool, but received %s".
......
......@@ -681,8 +681,12 @@ class PaddleCloudRoleMaker(RoleMakerBase):
else:
self._worker_endpoints = []
trainers_num = int(os.environ["PADDLE_TRAINERS_NUM"])
training_role = os.environ["TRAINING_ROLE"]
trainers_num = os.getenv("PADDLE_TRAINERS_NUM", None)
assert trainers_num != None
trainers_num = int(trainers_num)
training_role = os.getenv("TRAINING_ROLE", None)
assert training_role != None
if training_role not in ["TRAINER", "PSERVER", "HETER_TRAINER"]:
raise ValueError(
......@@ -716,19 +720,25 @@ class PaddleCloudRoleMaker(RoleMakerBase):
if training_role == "TRAINER":
role = Role.WORKER
current_id = int(os.environ["PADDLE_TRAINER_ID"])
current_id = os.getenv("PADDLE_TRAINER_ID", None)
assert current_id != None
current_id = int(current_id)
if len(self._worker_endpoints) > 0:
self._cur_endpoint = self._worker_endpoints[current_id]
elif training_role == "PSERVER":
role = Role.SERVER
port = os.environ["PADDLE_PORT"]
ip = os.environ["POD_IP"]
port = os.getenv("PADDLE_PORT", None)
assert port != None
ip = os.getenv("POD_IP", None)
assert ip != None
self._cur_endpoint = ip + ":" + port
current_id = self._server_endpoints.index(self._cur_endpoint)
elif training_role == "HETER_TRAINER":
role = Role.HETER_WORKER
cur_ip = os.environ["POD_IP"]
cur_port = os.environ["PADDLE_PORT"]
cur_port = os.getenv("PADDLE_PORT", None)
assert port != None
cur_ip = os.getenv("POD_IP", None)
assert cur_ip != None
curr_endpoint = ":".join([cur_ip, cur_port])
current_id = heter_trainer_eplist.index(curr_endpoint)
else:
......
......@@ -31,6 +31,10 @@ class ParameterServerGraphOptimizer(ParameterServerOptimizer):
if k_steps < 0:
return False
device = self.user_defined_strategy.a_sync_configs["worker_device"]
if device.upper() != 'CPU':
return False
if self.role_maker._is_server():
return False
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册