提交 1c57d554 编写于 作者: M MrChengmo

ps_graph support ps-gpu

上级 4efcb9df
...@@ -259,7 +259,7 @@ class DistributedStrategy(object): ...@@ -259,7 +259,7 @@ class DistributedStrategy(object):
def a_sync(self, flag): def a_sync(self, flag):
if isinstance(flag, bool): if isinstance(flag, bool):
self.strategy.a_sync = flag self.strategy.a_sync = flag
self.a_sync_configs = {"k_steps": 0} self.a_sync_configs = {"k_steps": 0, "worker_device": 'cpu'}
else: else:
raise ValueError( raise ValueError(
"The type of `flag` is invalid, expected type is bool, but received %s". "The type of `flag` is invalid, expected type is bool, but received %s".
......
...@@ -681,8 +681,12 @@ class PaddleCloudRoleMaker(RoleMakerBase): ...@@ -681,8 +681,12 @@ class PaddleCloudRoleMaker(RoleMakerBase):
else: else:
self._worker_endpoints = [] self._worker_endpoints = []
trainers_num = int(os.environ["PADDLE_TRAINERS_NUM"]) trainers_num = os.getenv("PADDLE_TRAINERS_NUM", None)
training_role = os.environ["TRAINING_ROLE"] assert trainers_num != None
trainers_num = int(trainers_num)
training_role = os.getenv("TRAINING_ROLE", None)
assert training_role != None
if training_role not in ["TRAINER", "PSERVER", "HETER_TRAINER"]: if training_role not in ["TRAINER", "PSERVER", "HETER_TRAINER"]:
raise ValueError( raise ValueError(
...@@ -716,19 +720,25 @@ class PaddleCloudRoleMaker(RoleMakerBase): ...@@ -716,19 +720,25 @@ class PaddleCloudRoleMaker(RoleMakerBase):
if training_role == "TRAINER": if training_role == "TRAINER":
role = Role.WORKER role = Role.WORKER
current_id = int(os.environ["PADDLE_TRAINER_ID"]) current_id = os.getenv("PADDLE_TRAINER_ID", None)
assert current_id != None
current_id = int(current_id)
if len(self._worker_endpoints) > 0: if len(self._worker_endpoints) > 0:
self._cur_endpoint = self._worker_endpoints[current_id] self._cur_endpoint = self._worker_endpoints[current_id]
elif training_role == "PSERVER": elif training_role == "PSERVER":
role = Role.SERVER role = Role.SERVER
port = os.environ["PADDLE_PORT"] port = os.getenv("PADDLE_PORT", None)
ip = os.environ["POD_IP"] assert port != None
ip = os.getenv("POD_IP", None)
assert ip != None
self._cur_endpoint = ip + ":" + port self._cur_endpoint = ip + ":" + port
current_id = self._server_endpoints.index(self._cur_endpoint) current_id = self._server_endpoints.index(self._cur_endpoint)
elif training_role == "HETER_TRAINER": elif training_role == "HETER_TRAINER":
role = Role.HETER_WORKER role = Role.HETER_WORKER
cur_ip = os.environ["POD_IP"] cur_port = os.getenv("PADDLE_PORT", None)
cur_port = os.environ["PADDLE_PORT"] assert port != None
cur_ip = os.getenv("POD_IP", None)
assert cur_ip != None
curr_endpoint = ":".join([cur_ip, cur_port]) curr_endpoint = ":".join([cur_ip, cur_port])
current_id = heter_trainer_eplist.index(curr_endpoint) current_id = heter_trainer_eplist.index(curr_endpoint)
else: else:
......
...@@ -31,6 +31,10 @@ class ParameterServerGraphOptimizer(ParameterServerOptimizer): ...@@ -31,6 +31,10 @@ class ParameterServerGraphOptimizer(ParameterServerOptimizer):
if k_steps < 0: if k_steps < 0:
return False return False
device = self.user_defined_strategy.a_sync_configs["worker_device"]
if device.upper() != 'CPU':
return False
if self.role_maker._is_server(): if self.role_maker._is_server():
return False return False
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册