Commit d3cda7f7, authored by: chengmo

fix

Parent: dbbcc43c
@@ -97,7 +97,7 @@ message AsyncConfig {
   optional int32 thread_pool_size = 6 [ default = 1 ];
   optional int32 send_wait_times = 7 [ default = 1 ];
   optional bool runtime_split_send_recv = 8 [ default = false ];
-  optional string worker_device = 9 [ default = 'cpu' ];
+  optional string heter_worker_device = 9 [ default = 'cpu' ];
 }
 
 message PipelineConfig { optional int32 micro_batch = 1 [ default = 1 ]; }
......
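Note on the renamed field: after this change the heter worker's device travels in the strategy's AsyncConfig rather than in a dedicated environment variable. A minimal sketch of how a training script could set it, assuming the fleet `DistributedStrategy` wrapper exposes the AsyncConfig fields through its `a_sync_configs` dict:

```python
# Sketch only: assumes paddle.distributed.fleet.DistributedStrategy maps the
# AsyncConfig proto fields onto the a_sync_configs dict.
import paddle.distributed.fleet as fleet

strategy = fleet.DistributedStrategy()
strategy.a_sync = True  # asynchronous parameter-server training
strategy.a_sync_configs = {"heter_worker_device": "gpu"}  # field renamed above

print(strategy.a_sync_configs["heter_worker_device"])  # 'gpu'
```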
@@ -511,13 +511,6 @@ class RoleMakerBase(object):
         return self._heter_trainer_endpoints[(self._current_id) %
                                              self._heter_worker_num()]
 
-    def _get_heter_worker_device(self):
-        """
-        Returns:
-            string: heter_trainer's device of current node, e.g: CPU/GPU/XPU
-        """
-        return self._heter_trainer_device.upper()
-
 
 class PaddleCloudRoleMaker(RoleMakerBase):
     def __init__(self, is_collective=False, **kwargs):
@@ -696,8 +689,7 @@ class PaddleCloudRoleMaker(RoleMakerBase):
             # For heter parameter server env setting
             heter_trainer_eplist = os.getenv("PADDLE_HETER_TRAINER_IP_PORT_LIST",
                                              "")
-            heter_trainer_device = os.getenv("PADDLE_HETER_TRAINER_DEVICE", "")
-            if heter_trainer_eplist != "" and heter_trainer_device != "":
+            if heter_trainer_eplist != "":
                 try:
                     heter_trainer_eplist = os.environ[
                         "PADDLE_HETER_TRAINER_IP_PORT_LIST"].split(",")
@@ -708,12 +700,6 @@ class PaddleCloudRoleMaker(RoleMakerBase):
                 self._is_heter_parameter_server_mode = True
                 heter_trainers_num = len(heter_trainer_eplist)
-                current_node_device = heter_trainer_device.upper()
-                if current_node_device not in ["CPU", "GPU", "XPU"]:
-                    raise ValueError(
-                        "Heter Trainer doesn't support {} device now, please use CPU / GPU / XPU(KunLun)".
-                        format(heter_trainer_device))
-                self._heter_trainer_device = current_node_device
             else:
                 self._is_heter_parameter_server_mode = False
                 heter_trainers_num = 0
......
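With the device check gone, heter parameter-server mode is inferred from `PADDLE_HETER_TRAINER_IP_PORT_LIST` alone; `PADDLE_HETER_TRAINER_DEVICE` is no longer read. A standalone sketch of the new detection logic (the helper name is hypothetical, not part of the role maker API):

```python
# Hypothetical helper mirroring the new branch in PaddleCloudRoleMaker:
# heter PS mode is on iff the heter endpoint list env var is non-empty.
import os

def detect_heter_ps_mode():
    eplist = os.getenv("PADDLE_HETER_TRAINER_IP_PORT_LIST", "")
    if eplist == "":
        return False, 0
    endpoints = [ep.strip() for ep in eplist.split(",") if ep.strip()]
    return True, len(endpoints)

os.environ["PADDLE_HETER_TRAINER_IP_PORT_LIST"] = "10.0.0.1:6170,10.0.0.2:6170"
print(detect_heter_ps_mode())  # (True, 2)
```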
@@ -91,14 +91,6 @@ see: http://www.paddlepaddle.org/documentation/docs/zh/1.6/user_guides/howto/tra
 ''')
 
     base_group = parser.add_argument_group("Base Parameters")
 
-    base_group.add_argument(
-        "-d",
-        "--distributed_mode",
-        type=str,
-        choices=["collective", "ps", "ps_heter", "ps_gpu", ""],
-        default="",
-        help="Distributed running mode: collective/ps/ps_gpu/ps_heter")
-
     base_group.add_argument(
         "--log_dir",
         type=str,
@@ -150,13 +142,6 @@ see: http://www.paddlepaddle.org/documentation/docs/zh/1.6/user_guides/howto/tra
     ps_group.add_argument(
         "--heter_worker_num", type=int, help="number of heter_workers")
-    ps_group.add_argument(
-        "--heter_worker_device",
-        type=str,
-        default="gpu",
-        choices=["gpu", "xpu"],
-        help="heter worker device")
-
     return parser.parse_args()
@@ -244,34 +229,37 @@ def launch_collective(args):
         shutil.rmtree(gloo_rendezvous_dir)
 
 
-def launch_ps(args):
+def launch_ps(args, distribute_mode):
     cloud_flag = cloud_utils.use_paddlecloud()
 
     # for ps-cpu on paddlecloud
-    direct_start_mode = ["ps", ""]
-    if cloud_flag and (args.distributed_mode in direct_start_mode):
+    if cloud_flag and distribute_mode == DistributeMode.PS:
         direct_start(args)
         return
-    elif cloud_flag and args.distributed_mode == "ps_heter":
+    elif cloud_flag and distribute_mode == DistributeMode.PS_HETER:
         cloud_ps_heter_env_set(args)
         args.trainers = os.getenv("PADDLE_TRAINER_ENDPOINTS")
         args.workers = os.getenv("PADDLE_PSERVERS_IP_PORT_LIST")
        args.heter_workers = os.getenv("PADDLE_HETER_TRAINER_IP_PORT_LIST")
 
-    ps_launcher = ParameterServerLauncher(args)
-    ps_launcher.start_ps(args)
+    ps_launcher = ParameterServerLauncher(args, distribute_mode)
+    ps_launcher.start_ps()
     return
 
 
-def launch():
-    args = _parse_args()
-    logger = get_logger()
-    _print_arguments(args)
-
+def which_distributed_mode(args):
     ps_args = [
-        '--worker_num', '--server_num', '--heter_worker_num', '--servers',
-        '--workers', '--heter_worrkers', 'heter_worker_device'
+        '--worker_num',
+        '--server_num',
+        '--heter_worker_num',
+        '--servers',
+        '--workers',
+        '--heter_workers',
     ]
-    collective_args = ['--ips', '--gpus']
+    collective_args = ['--ips']
+    ps_heter_args = ["--heter_worker_num", "--heter_workers"]
+
     has_ps_args = [
         ps_arg for ps_arg in ps_args if ps_arg in " ".join(sys.argv[1:-1])
     ]
@@ -279,25 +267,45 @@ def launch():
         co_arg for co_arg in collective_args
         if co_arg in " ".join(sys.argv[1:-1])
     ]
+
+    assert (
+        len(has_ps_args) > 1 and len(has_collective_args) > 1
+    ), "Only one mode(Collective or Parameter-Server ) can be selected at the same time, but more than one configuration was received."
+
     if fluid.core.is_compiled_with_cuda():
         cuda_device_num = fluid.core.get_cuda_device_count()
     else:
         cuda_device_num = 0
 
-    ps_mode = ['ps', 'ps_gpu', 'ps_heter']
-    if len(has_ps_args) > 0 or args.distributed_mode in ps_mode:
+    if len(has_ps_args) > 0:
         logger.info(
             "Run parameter-sever mode. pserver arguments:{}, cuda count:{}".
             format(has_ps_args, cuda_device_num))
-        launch_ps(args)
+        has_ps_heter_args = list(set(has_ps_args) & set(ps_heter_args))
+        if len(has_ps_heter_args) > 0:
+            return DistributeMode.PS_HETER
+        else:
+            return DistributeMode.PS
     elif len(has_collective_args) > 0:
         logger.info("Run collective gpu mode. gpu arguments:{}, cuda count:{}".
                     format(has_collective_args, cuda_device_num))
-        launch_collective(args)
+        return DistributeMode.COLLECTIVE
     else:
         logger.warning(
             "Not found distinct arguments. Default use gpu collective mode")
+        return DistributeMode.COLLECTIVE
+
+
+def launch():
+    args = _parse_args()
+    logger = get_logger()
+    _print_arguments(args)
+
+    distribute_mode = which_distributed_mode(args)
+    if distribute_mode == DistributeMode.COLLECTIVE:
         launch_collective(args)
+    else:
+        launch_ps(args, distribute_mode)
 
 
 if __name__ == "__main__":
......
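With `--distributed_mode` and `--heter_worker_device` removed, the launcher now classifies the job from the flags it actually sees. A condensed sketch of that flag-to-mode mapping on a plain argv list (it mirrors `which_distributed_mode` above but is not the launcher itself):

```python
# Condensed sketch of the flag-to-mode mapping; DistributeMode values mirror
# the class added in launch_utils.py.
class DistributeMode:
    COLLECTIVE = 0
    PS = 1
    PS_HETER = 2

PS_ARGS = ['--worker_num', '--server_num', '--heter_worker_num',
           '--servers', '--workers', '--heter_workers']
PS_HETER_ARGS = ['--heter_worker_num', '--heter_workers']
COLLECTIVE_ARGS = ['--ips']

def guess_mode(argv):
    joined = " ".join(argv)
    has_ps = [a for a in PS_ARGS if a in joined]
    if has_ps:
        if set(has_ps) & set(PS_HETER_ARGS):
            return DistributeMode.PS_HETER
        return DistributeMode.PS
    # collective flags, or nothing recognized, both fall back to collective
    return DistributeMode.COLLECTIVE

print(guess_mode(["--servers", "a:6170", "--workers", "b:6171"]))        # 1
print(guess_mode(["--workers", "b:6171", "--heter_workers", "c:6172"]))  # 2
print(guess_mode(["--ips", "10.0.0.1,10.0.0.2"]))                        # 0
```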
...@@ -32,6 +32,15 @@ logger = logging.getLogger("root") ...@@ -32,6 +32,15 @@ logger = logging.getLogger("root")
logger.propagate = False logger.propagate = False
class DistributeMode:
"""
There are various mode for fleetrun, each of them is designed for different model.
"""
COLLECTIVE = 0
PS = 1
PS_HETER = 2
class Cluster(object): class Cluster(object):
def __init__(self, hdfs): def __init__(self, hdfs):
self.job_server = None self.job_server = None
...@@ -616,7 +625,9 @@ def cloud_ps_heter_env_set(args): ...@@ -616,7 +625,9 @@ def cloud_ps_heter_env_set(args):
class ParameterServerLauncher(object): class ParameterServerLauncher(object):
def __init__(self, args): def __init__(self, args, distribute_mode):
self.args = args
self.distribute_mode = distribute_mode
self.server_num = 0 self.server_num = 0
self.worker_num = 0 self.worker_num = 0
self.heter_worker_num = 0 self.heter_worker_num = 0
...@@ -677,7 +688,7 @@ class ParameterServerLauncher(object): ...@@ -677,7 +688,7 @@ class ParameterServerLauncher(object):
self.worker_num = len(self.worker_endpoints.split(",")) self.worker_num = len(self.worker_endpoints.split(","))
# get heter worker envs # get heter worker envs
if args.distributed_mode == "ps_heter": if self.distribute_mode == DistributeMode.PS_HETER:
if args.heter_worker_num: if args.heter_worker_num:
self.heter_worker_num = args.heter_worker_num self.heter_worker_num = args.heter_worker_num
if args.heter_workers: if args.heter_workers:
...@@ -713,7 +724,7 @@ class ParameterServerLauncher(object): ...@@ -713,7 +724,7 @@ class ParameterServerLauncher(object):
] ]
self.node_ips = list( self.node_ips = list(
set(self.server_endpoints_ips + self.worker_endpoints_ips)) set(self.server_endpoints_ips + self.worker_endpoints_ips))
if args.distributed_mode == "ps_heter": if self.distribute_mode == DistributeMode.PS_HETER:
self.heter_worker_endpoints_ips = [ self.heter_worker_endpoints_ips = [
x.strip().split(":")[0] x.strip().split(":")[0]
for x in self.heter_worker_endpoints.split(",") for x in self.heter_worker_endpoints.split(",")
......
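For reference, the endpoint bookkeeping touched above boils down to collecting the unique node IPs from the `ip:port` lists, with heter-worker IPs only folded in under `DistributeMode.PS_HETER`. A small sketch with hypothetical addresses:

```python
# Hypothetical endpoint lists; node IPs are the de-duplicated union of server
# and worker IPs, as in ParameterServerLauncher above.
server_endpoints = "10.0.0.1:6170,10.0.0.2:6170"
worker_endpoints = "10.0.0.2:6171,10.0.0.3:6171"

server_ips = [ep.strip().split(":")[0] for ep in server_endpoints.split(",")]
worker_ips = [ep.strip().split(":")[0] for ep in worker_endpoints.split(",")]
node_ips = list(set(server_ips + worker_ips))

print(sorted(node_ips))  # ['10.0.0.1', '10.0.0.2', '10.0.0.3']
```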
@@ -198,16 +198,21 @@ class ParameterServerRuntime(RuntimeBase):
             warnings.warn("communicator has been initialized, skip")
 
     def _get_executor(self):
         if self.role_maker._is_heter_worker():
-            if self.role_maker._get_heter_worker_device() == "GPU":
+            heter_worker_device = self.context["valid_strategy"].a_sync_configs[
+                "heter_worker_device"].upper()
+            if heter_worker_device == "GPU":
                 gpu_id = int(os.getenv("FLAGS_selected_gpus", "0"))
                 executor = Executor(fluid.CUDAPlace(gpu_id))
-            elif self.role_maker._get_heter_worker_device() == "XPU":
+            elif heter_worker_device == "XPU":
                 xpu_id = int(os.getenv("FLAGS_selected_xpus", "0"))
                 executor = Executor(fluid.XPUPlace(xpu_id))
+            elif heter_worker_device == "CPU":
+                fluid.Executor(fluid.CPUPlace())
             else:
-                raise ValueError("Not Support Device {}".format(
-                    self.role_maker._get_heter_worker_device()))
+                raise ValueError("Heter Worker Not Support Device {}".format(
+                    heter_worker_device))
         else:
             executor = fluid.Executor(fluid.CPUPlace())
         return executor
......
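The executor choice now hinges on the `heter_worker_device` value read from the validated strategy. A minimal, self-contained sketch of the same device-to-place mapping, assuming the `fluid.CUDAPlace` / `fluid.XPUPlace` / `fluid.CPUPlace` and `fluid.Executor` APIs (GPU/XPU places require a matching Paddle build):

```python
# Sketch of the device-to-place mapping used in _get_executor above.
import os
import paddle.fluid as fluid

def make_heter_executor(heter_worker_device):
    device = heter_worker_device.upper()
    if device == "GPU":
        place = fluid.CUDAPlace(int(os.getenv("FLAGS_selected_gpus", "0")))
    elif device == "XPU":
        place = fluid.XPUPlace(int(os.getenv("FLAGS_selected_xpus", "0")))
    elif device == "CPU":
        place = fluid.CPUPlace()
    else:
        raise ValueError(
            "Heter Worker Not Support Device {}".format(heter_worker_device))
    return fluid.Executor(place)

exe = make_heter_executor("cpu")  # CPU works on any build
```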