提交 51898800 编写于 作者: M MrChengmo

fix

上级 291e1594
......@@ -97,6 +97,7 @@ message AsyncConfig {
optional int32 thread_pool_size = 6 [ default = 1 ];
optional int32 send_wait_times = 7 [ default = 1 ];
optional bool runtime_split_send_recv = 8 [ default = false ];
optional string heter_worker_device = 9 [ default = 'cpu' ];
}
message PipelineConfig { optional int32 micro_batch = 1 [ default = 1 ]; }
......
......@@ -268,9 +268,10 @@ def which_distributed_mode(args):
if co_arg in " ".join(sys.argv[1:-1])
]
assert (
len(has_ps_args) > 1 and len(has_collective_args) > 1
), "Only one mode(Collective or Parameter-Server ) can be selected at the same time, but more than one configuration was received."
if len(has_ps_args) > 1 and len(has_collective_args) > 1:
raise ValueError(
"Only one mode(Collective or Parameter-Server ) can be selected at the same time, but more than one configuration was received."
)
if fluid.core.is_compiled_with_cuda():
cuda_device_num = fluid.core.get_cuda_device_count()
......
......@@ -610,7 +610,6 @@ def cloud_ps_heter_env_set(args):
assert trainers_num != 0
environs["PADDLE_TRAINERS_NUM"] = trainers_num
environs["TRAINERS_NUM"] = trainers_num
environs["PADDLE_HETER_TRAINER_DEVICE"] = args.heter_worker_device
# hard code for paddlecloud custom-framework
environs["PADDLE_HETER_TRAINER_IP_PORT_LIST"] = paddle_trainer_endpoints
......@@ -754,7 +753,7 @@ class ParameterServerLauncher(object):
"parsed from args: node_ips:{} current_node_ip:{} node_rank:{}".
format(self.node_ips, self.current_node_ip, self.node_rank))
def start_ps(self, args):
def start_ps(self):
cluster = Cluster(hdfs=None)
server_rank = 0
worker_rank = 0
......@@ -799,13 +798,13 @@ class ParameterServerLauncher(object):
self.cmds = {"worker": [], "server": [], "heter_worker": []}
self.log_fns = {"worker": [], "server": [], "heter_worker": []}
self.start_pod_server(args, pod)
self.start_pod_worker(args, pod)
self.start_pod_heter_worker(args, pod)
self.start_pod_server(self.args, pod)
self.start_pod_worker(self.args, pod)
self.start_pod_heter_worker(self.args, pod)
logger.info(
"Please check servers, workers and heter_worker logs in {}/workerlog.*, {}/serverlog.* and {}/heterlog.*".
format(args.log_dir, args.log_dir, args.log_dir))
format(self.args.log_dir, self.args.log_dir, self.args.log_dir))
# 4. wait for finish training
if len(self.procs["worker"]) > 0:
......@@ -855,7 +854,6 @@ class ParameterServerLauncher(object):
"PADDLE_TRAINER_ENDPOINTS": self.worker_endpoints,
"PADDLE_HETER_TRAINER_IP_PORT_LIST":
self.heter_worker_endpoints,
"PADDLE_HETER_TRAINER_DEVICE": args.heter_worker_device,
"PADDLE_PORT": cur_server.endpoint.split(":")[1],
"TRAINING_ROLE": "PSERVER",
"PADDLE_TRAINERS_NUM": str(self.worker_num),
......@@ -905,10 +903,10 @@ class ParameterServerLauncher(object):
heter_device_num = 0
device_list = []
if args.heter_worker_device == "gpu":
if fluid.core.is_compiled_with_cuda():
device_list = get_gpus(args.gpus)
heter_device_num = len(device_list)
elif args.heter_worker_device == "xpu":
elif fluid.core.is_compiled_with_xpu():
heter_device_num = fluid.core.get_xpu_device_count()
device_list = [str(x) for x in range(0, heter_device_num)]
......@@ -920,7 +918,6 @@ class ParameterServerLauncher(object):
"PADDLE_TRAINERS_NUM": str(self.worker_num),
"PADDLE_HETER_TRAINER_IP_PORT_LIST":
self.heter_worker_endpoints,
"PADDLE_HETER_TRAINER_DEVICE": args.heter_worker_device,
"TRAINING_ROLE": "TRAINER",
"PADDLE_TRAINER_ID": str(cur_worker.rank),
"PADDLE_WITH_GLOO": "1",
......@@ -972,10 +969,10 @@ class ParameterServerLauncher(object):
heter_device_num = 0
device_list = []
if args.heter_worker_device == "gpu":
if fluid.core.is_compiled_with_cuda():
device_list = get_gpus(args.gpus)
heter_device_num = len(device_list)
elif args.heter_worker_device == "xpu":
elif fluid.core.is_compiled_with_xpu():
heter_device_num = fluid.core.get_xpu_device_count()
device_list = [str(x) for x in range(0, heter_device_num)]
assert heter_device_num != 0
......@@ -987,7 +984,6 @@ class ParameterServerLauncher(object):
"PADDLE_TRAINER_ENDPOINTS": self.worker_endpoints,
"PADDLE_HETER_TRAINER_IP_PORT_LIST":
self.heter_worker_endpoints,
"PADDLE_HETER_TRAINER_DEVICE": args.heter_worker_device,
"PADDLE_PORT": cur_heter_worker.endpoint.split(":")[1],
"TRAINING_ROLE": "HETER_TRAINER",
"PADDLE_TRAINERS_NUM": str(self.worker_num),
......
......@@ -94,8 +94,8 @@ class ParameterServerRuntime(RuntimeBase):
return False
if var.desc.type() == core.VarDesc.VarType.FEED_MINIBATCH or \
var.desc.type() == core.VarDesc.VarType.FETCH_LIST or \
var.desc.type() == core.VarDesc.VarType.READER:
var.desc.type() == core.VarDesc.VarType.FETCH_LIST or \
var.desc.type() == core.VarDesc.VarType.READER:
return False
return var.persistable
......@@ -199,15 +199,20 @@ class ParameterServerRuntime(RuntimeBase):
def _get_executor(self):
if self.role_maker._is_heter_worker():
if self.role_maker._get_heter_worker_device() == "GPU":
dist_strategy = self.context["valid_strategy"]
heter_worker_device = dist_strategy.a_sync_configs[
"heter_worker_device"].upper()
if heter_worker_device == "GPU":
gpu_id = int(os.getenv("FLAGS_selected_gpus", "0"))
executor = Executor(fluid.CUDAPlace(gpu_id))
elif self.role_maker._get_heter_worker_device() == "XPU":
elif heter_worker_device == "XPU":
xpu_id = int(os.getenv("FLAGS_selected_xpus", "0"))
executor = Executor(fluid.XPUPlace(xpu_id))
elif heter_worker_device == "CPU":
executor = fluid.Executor(fluid.CPUPlace())
else:
raise ValueError("Not Support Device {}".format(
self.role_maker._get_heter_worker_device()))
raise ValueError("Heter Worker Not Support Device {}".format(
heter_worker_device))
else:
executor = fluid.Executor(fluid.CPUPlace())
return executor
......@@ -312,7 +317,7 @@ class ParameterServerRuntime(RuntimeBase):
opts = _get_optimize_ops(self.origin_main_program)
for op in opts:
if "Param" in op.input_names and \
"LearningRate" in op.input_names and op.input("Param")[0] == param_name:
"LearningRate" in op.input_names and op.input("Param")[0] == param_name:
return op
def _save_dense_params(self, executor, dirname, context, main_program):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册