Unverified commit c98d175d, authored by zmx, committed by GitHub

[heterps]change default executor for heter trainer (#37314)

* fix pslib. test=develop

* add device to train_from_dataset. test=develop

* refine fleet.stop_worker. test=develop

* fix ut. test=develop

* fix ut. test=develop

* fix executor & ut. test=develop

* fix executor & ut. test=develop

* fix executor & ut. test=develop
Parent 8fd8780e
@@ -527,7 +527,8 @@ class PaddleCloudRoleMaker(RoleMakerBase):
         self._next_heter_trainer_endpoints = []
         self._previous_heter_trainer_endpoints = []
         self._heter_trainer_endpoints = []
-        self._heter_trainer_device = "CPU"
+        self._heter_trainer_device = "cpu"
+        self._heter_trainer_device_type = "cpu"
         self._is_heter_parameter_server_mode = False
         self._stage_trainers = []
@@ -545,13 +546,21 @@ class PaddleCloudRoleMaker(RoleMakerBase):
     def _all_reduce(self, input, mode="sum", comm_world="worker"):
         return self._gloo.all_reduce(input, mode, comm_world)
 
+    def _heter_device(self):
+        """
+        return the heter device that current heter worker is using
+        """
+        if not self._role_is_generated:
+            self._generate_role()
+        return self._heter_trainer_device
+
     def _heter_device_type(self):
         """
         return the heter device type that current heter worker is using
         """
         if not self._role_is_generated:
             self._generate_role()
-        return self._heter_trainer_device
+        return self._heter_trainer_device_type
 
     def _get_stage_id(self):
         """
@@ -935,14 +944,24 @@ class PaddleCloudRoleMaker(RoleMakerBase):
                 )
                 self._stage_trainers = eval(self._stage_trainers)
-            self._heter_trainer_device = os.getenv("HETER_DEVICE_TYPE", None)
-            if self._heter_trainer_device == None:
+            self._heter_trainer_device_type = os.getenv("HETER_DEVICE_TYPE",
+                                                        None)
+            if self._heter_trainer_device_type == None:
                 raise ValueError(
                     "Can not find HETER_DEVICE_TYPE, please check your environment."
                 )
-            assert self._heter_trainer_device in (
+            assert self._heter_trainer_device_type in (
                 "cpu", "gpu", "xpu"
             ), "HETER_DEVICE_TYPE should be cpu,gpu or xpu"
+            if self._heter_trainer_device_type == "gpu":
+                heter_device_id = os.getenv("FLAGS_selected_gpus", "0")
+                self._heter_trainer_device = ":".join(
+                    (self._heter_trainer_device_type, heter_device_id))
+            if self._heter_trainer_device == "xpu":
+                heter_device_id = os.getenv("FLAGS_selected_xpus", "0")
+                self._heter_trainer_device = ":".join(
+                    (self._heter_trainer_device_type, heter_device_id))
         cur_port = os.getenv("PADDLE_PORT", None)
         if cur_port == None:
             raise ValueError(
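Editor's note: the device-string composition added above can be exercised outside a cluster. A standalone sketch with illustrative env values follows; note that the xpu branch in the hunk tests self._heter_trainer_device rather than self._heter_trainer_device_type, which looks like a typo, so the sketch checks the type instead.

import os

# Illustrative values; in a real job the launcher exports these.
os.environ["HETER_DEVICE_TYPE"] = "gpu"
os.environ["FLAGS_selected_gpus"] = "0"

device_type = os.getenv("HETER_DEVICE_TYPE", None)
assert device_type in ("cpu", "gpu", "xpu")

device = device_type
if device_type == "gpu":
    device = ":".join((device_type, os.getenv("FLAGS_selected_gpus", "0")))
elif device_type == "xpu":
    device = ":".join((device_type, os.getenv("FLAGS_selected_xpus", "0")))

print(device)  # -> "gpu:0"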
......
@@ -331,6 +331,8 @@ class ParameterServerOptimizer(MetaOptimizerBase):
         if self.role_maker._is_heter_parameter_server_mode:
             _origin_startup_program._heter_pipeline_opt = {
                 "startup_program": startup_program,
+                "pipeline_stage": int(self.role_maker._get_stage_id()) - 1,
+                "heter_place": self.role_maker._heter_device(),
             }
 
             loss.block.program._heter_pipeline_opt = {
@@ -344,6 +346,7 @@ class ParameterServerOptimizer(MetaOptimizerBase):
                 int(self.role_maker._get_num_stage()),
                 "section_program": main_program,
                 "num_microbatches": self.num_microbatches,
+                "heter_place": self.role_maker._heter_device(),
             }
         else:
             loss.block.program = main_program
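Editor's note: together with the role-maker change, the optimizer now threads a concrete place string through _heter_pipeline_opt, which Executor reads later. A sketch of the resulting startup-side metadata, with illustrative values assuming a GPU heter worker on the second pipeline stage:

# values illustrative, not part of the commit
_origin_startup_program._heter_pipeline_opt = {
    "startup_program": startup_program,
    "pipeline_stage": 1,       # int(role_maker._get_stage_id()) - 1
    "heter_place": "gpu:0",    # role_maker._heter_device()
}
# the main-program side additionally carries "section_program",
# "num_microbatches", and the same "heter_place"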
......
@@ -913,11 +913,11 @@ class TheOnePSRuntime(RuntimeBase):
     def _stop_worker(self):
         self._communicator.stop()
-        if self.role_maker._is_heter_parameter_server_mode and self.role_maker._is_worker(
-        ):
+        if self.role_maker._is_heter_parameter_server_mode:
+            assert self._heter_client != None, "heter client should not be None in heterps mode"
             self._heter_client.stop()
-        executor = self._get_executor()
-        executor.close()
+        #executor = self._get_executor()
+        #executor.close()
 
     @staticmethod
     def __exclude_vars(exclude_var_names=[]):
......
...@@ -49,7 +49,6 @@ class DownpourServer(Server): ...@@ -49,7 +49,6 @@ class DownpourServer(Server):
self.server_.downpour_server_param.service_param.server_class = "DownpourBrpcPsServer" self.server_.downpour_server_param.service_param.server_class = "DownpourBrpcPsServer"
self.server_.downpour_server_param.service_param.client_class = "DownpourBrpcPsClient" self.server_.downpour_server_param.service_param.client_class = "DownpourBrpcPsClient"
self.server_.downpour_server_param.service_param.service_class = "DownpourPsService" self.server_.downpour_server_param.service_param.service_class = "DownpourPsService"
self.server_.downpour_server_param.service_param.start_server_port = 0
self.server_.downpour_server_param.service_param.server_thread_num = 12 self.server_.downpour_server_param.service_param.server_thread_num = 12
def add_sparse_table(self, table_id, learning_rate, slot_key_vars, def add_sparse_table(self, table_id, learning_rate, slot_key_vars,
......
@@ -1296,9 +1296,15 @@ class Executor(object):
                 use_program_cache=use_program_cache)
 
         if isinstance(program, Program) and program._heter_pipeline_opt:
+            ## change default executor
+            heter_place = program._heter_pipeline_opt["heter_place"]
+            heter_place = framework._get_paddle_place(heter_place)
+            p = core.Place()
+            p.set_place(heter_place)
+            self._default_executor = core.Executor(p)
+            # TODO(zhangminxu): support heterps pipeline training using exe.run
             if "startup_program" in program._heter_pipeline_opt:
                 program = program._heter_pipeline_opt["startup_program"]
-                # TODO(zhangminxu): support heterps pipeline training using exe.run
 
         if isinstance(program, Program) and \
                         len(program.global_block().ops) == 0:
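Editor's note: framework._get_paddle_place converts the stored string into a place object before the default executor is rebuilt on it. A small sketch of the mapping this relies on (the GPU/XPU variants require the corresponding Paddle build):

from paddle.fluid import framework

framework._get_paddle_place("cpu")    # -> CPUPlace()
framework._get_paddle_place("gpu:0")  # -> CUDAPlace(0), CUDA build only
framework._get_paddle_place("xpu:0")  # -> XPUPlace(0), XPU build only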
@@ -1704,6 +1710,7 @@ class Executor(object):
                 dataset.set_use_var(data_vars)
         elif program._heter_pipeline_opt is not None:
             stage_id = program._heter_pipeline_opt["pipeline_stage"]
+            heter_place = program._heter_pipeline_opt["heter_place"]
             if stage_id != 0:
                 import paddle
                 if dataset is not None:
@@ -1729,6 +1736,11 @@ class Executor(object):
                 if dataset is None:
                     raise RuntimeError(
                         "dataset is need and should be initialized")
+            ## change default executor
+            heter_place = framework._get_paddle_place(heter_place)
+            p = core.Place()
+            p.set_place(heter_place)
+            self._default_executor = core.Executor(p)
         else:
             if dataset is None:
                 raise RuntimeError("dataset is need and should be initialized")
......
@@ -148,9 +148,7 @@ class TestHeterPipelinePsCTR2x2(FleetDistHeterRunnerBase):
             "section_program"]
         print(real_program)
 
-        real_startup = fluid.default_startup_program()._heter_pipeline_opt[
-            "startup_program"]
-        exe.run(real_startup)
+        exe.run(fluid.default_startup_program())
 
         fleet.init_worker()
         thread_num = int(os.getenv("CPU_NUM", 2))
@@ -185,7 +183,9 @@ class TestHeterPipelinePsCTR2x2(FleetDistHeterRunnerBase):
     def do_dataset_heter_training(self, fleet):
 
-        fleet.init_heter_worker()
+        exe = fluid.Executor()
+        exe.run(fluid.default_startup_program())
+        fleet.init_worker()
 
         real_program = fluid.default_main_program()._heter_pipeline_opt[
             "section_program"]
         print(real_program)
@@ -194,7 +194,13 @@ class TestHeterPipelinePsCTR2x2(FleetDistHeterRunnerBase):
         batch_size = 128
         pass_start = time.time()
-        fleet.run_heter_worker(dataset=None)
+        exe.train_from_dataset(
+            program=fluid.default_main_program(),
+            fetch_list=[self.avg_cost],
+            fetch_info=["cost"],
+            print_period=2,
+            debug=int(os.getenv("Debug", "0")))
+        exe.close()
         pass_time = time.time() - pass_start
         print("do_dataset_heter_training done. using time {}".format(pass_time))
......
@@ -39,7 +39,7 @@ class TestDistHeterDatasetAsync2x2(TestFleetHeterBase):
             "LD_LIBRARY_PATH": os.getenv("LD_LIBRARY_PATH", ""),
             "FLAGS_rpc_deadline": "5000",  # 5sec to fail fast
             "http_proxy": "",
-            "CPU_NUM": "3"
+            "CPU_NUM": "2"
         }
         required_envs.update(need_envs)
......