From 1ae04a73dfeebe59794491bf23336f45d61fe8d0 Mon Sep 17 00:00:00 2001
From: Chengmo
Date: Fri, 4 Sep 2020 17:37:37 +0800
Subject: [PATCH] fix Heter Ps multi thread (#26876) (#27016)

* fix heter-ps multi thread
---
 .../fleet/parameter_server/ir/trainer_pass.py | 115 +++++++++++-------
 .../tests/unittests/ctr_dataset_reader.py     |   2 +-
 .../tests/unittests/dist_fleet_heter_ctr.py   |   2 +-
 .../unittests/test_dist_fleet_heter_ctr.py    |  36 +++++-
 4 files changed, 109 insertions(+), 46 deletions(-)

diff --git a/python/paddle/fluid/incubate/fleet/parameter_server/ir/trainer_pass.py b/python/paddle/fluid/incubate/fleet/parameter_server/ir/trainer_pass.py
index 5e6b8ca6399..4543af9820e 100644
--- a/python/paddle/fluid/incubate/fleet/parameter_server/ir/trainer_pass.py
+++ b/python/paddle/fluid/incubate/fleet/parameter_server/ir/trainer_pass.py
@@ -1,3 +1,4 @@
+# -*- coding: UTF-8 -*-
 # Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
@@ -441,7 +442,23 @@ def find_heter_ops(program, default_device="cpu"):
 
 def create_heter_program(program, config, heter_program, heter_ops,
                          block_var_detail, current_device):
-    # add heter op
+
+    # This function mainly includes the following contents:
+    # 1. For every heter block:
+    #     a) copy heter device op from origin program
+    #     b) create variables which belong to heter op:
+    #        -> if variable is persistable, clone it in global_scope
+    #        -> if variable is temp, create it in heter block
+    #     c) create communicate related op as follows:
+    #        joint_var.0_1 -> slice -> reshape -> origin_var
+    #        origin_var -> origin_program
+    #        reshape -> concat -> joint_var.1_2
+    #     d) copy send op from origin program for var@grad which is located in current heter block
+    #     e) re-check every op in current block if its device is not current heter device
+    # 2. Create send op for step counter in last heter-block
+    # 3. Create Listen&Serv OP for distributed training
+    # 4. update CompileTimeStrategy for heter_program
+
     optimizer_block = []
     grad_to_block_id = []
     send_grad_var_list = []
@@ -453,17 +470,10 @@ def create_heter_program(program, config, heter_program, heter_ops,
         for _, op in enumerate(heter_block_ops):
             block_append_op(heter_program, program, heter_block, op)
 
-            # add relate variables
-            inputs = _get_input_map_from_op(program.global_block().vars, op)
-            add_vars_by_op_map(inputs, heter_program)
-
-            outputs = _get_output_map_from_op(program.global_block().vars, op)
-            add_vars_by_op_map(outputs, heter_program)
-
         entrance_vars = block_var_detail[index]["entrance"]
-        add_vars_by_var_list(entrance_vars, program, heter_program)
+        add_vars_by_var_list(entrance_vars, program, heter_program, heter_block)
         exit_vars = block_var_detail[index]["exit"]
-        add_vars_by_var_list(exit_vars, program, heter_program)
+        add_vars_by_var_list(exit_vars, program, heter_program, heter_block)
 
         comm_info = get_communicate_var_info(program, index, entrance_vars,
                                              exit_vars)
@@ -471,13 +481,13 @@ def create_heter_program(program, config, heter_program, heter_ops,
         grad_to_block_id.append(comm_info["block_input_var_name"] + ":" + str(
            heter_block.idx))
 
-        # create slice op
         first_op_index = 0
 
         get_type_var_name = comm_info["input_var_reshape_name"][0].split(
            ".input_reshape@Heter")[0]
-        get_type_var = heter_program.global_block().vars[get_type_var_name]
+        get_type_var = heter_block.vars[get_type_var_name]
 
+        # create slice op
         insert_recv_slice_op(
            heter_program, heter_block, first_op_index,
            comm_info["block_input_var_name"],
@@ -487,6 +497,13 @@
            for i in range(len(comm_info["input_var_reshape_dim"]))
        ])
         first_op_index += len(comm_info["input_var_reshape_dim"])
+
+        heter_program.global_block().create_var(
+            name=comm_info["block_input_var_name"],
+            shape=(-1, sum(comm_info["input_var_reshape_dim"])),
+            dtype=get_type_var.dtype,
+            type=get_type_var.type)
+
         # create reshape op
         for i in range(len(comm_info["input_var_reshape_name"])):
             var_name = entrance_vars[i]
@@ -514,13 +531,14 @@
            comm_info["block_output_var_name"],
            [-1, sum(comm_info["output_var_reshape_dim"])])
         check_op_device(heter_block, current_device)
+
+        # add send op
         send_grad_var_list = send_grad_var_list + add_heter_send_op(
            program, heter_program, heter_block, block_var_detail[index])
 
     # add step conter
     send_input_vars = []
     dummy_output = []
-    trainer_id = config.get_role_id()
     pserver_endpoints = config.get_ps_endpoints()
     optimizer_block[-1].append_op(
        type="send",
@@ -555,7 +573,6 @@
     # append the listen_and_serv op
     heter_program.global_block().append_op(
        type="listen_and_serv", inputs={'X': []}, outputs={}, attrs=attrs)
-
     check_heter_compile_time_strategy(program, config, send_grad_var_list)
 
 
@@ -574,6 +591,16 @@ def check_heter_compile_time_strategy(program, config, send_grad_var_list):
 
 
 def create_trainer_program(program, config, heter_ops, block_var_detail):
+    # This function mainly includes the following contents:
+    # 1. For every heter block in origin program
+    #     a) delete heter op and related variables
+    #     b) add send&recv op
+    #     c) add communicate ops as follows:
+    #        origin_var -> reshape -> concat -> joint_var.0_1
+    #        send&recv op(send joint_var.0_1; recv joint_var.1_2)
+    #        joint_var.1_2 -> slice -> reshape -> origin_var
+    #     d) remove send op whose related var@grad is not in trainer program
+    # 2. check every op's device
     for device in heter_ops.keys():
         for heter_block_index in sorted(heter_ops[device]):
             replace_ops_by_communicate_op(program, config, heter_block_index,
@@ -932,19 +959,19 @@ def insert_reshape_op(program,
                       var_name,
                       new_var_name,
                       new_var_shape=None):
-    input_var = program.global_block().vars[var_name]
+    input_var = block.vars[var_name]
 
-    if new_var_name not in program.global_block().vars:
-        out = program.global_block().create_var(
+    if new_var_name not in block.vars:
+        out = block.create_var(
            name=new_var_name,
            shape=new_var_shape,
            dtype=input_var.dtype,
            type=input_var.type)
     else:
-        out = program.global_block().vars[new_var_name]
+        out = block.vars[new_var_name]
         new_var_shape = out.shape
 
-    x_shape = program.global_block().create_var(
+    x_shape = block.create_var(
        name="{}.xshape@Heter".format(var_name), dtype=input_var.dtype)
     block._insert_op(
        index=index,
@@ -957,9 +984,7 @@
 
 def insert_send_concat_op(program, block, index, var_name_list, new_var_name,
                           new_var_shape):
-    input_var_list = [
-        program.global_block().vars[var_name] for var_name in var_name_list
-    ]
+    input_var_list = [block.vars[var_name] for var_name in var_name_list]
 
     out = program.global_block().create_var(
        name=new_var_name,
@@ -987,14 +1012,14 @@ def insert_recv_slice_op(program, block, index, var_name, var_shape, dtype,
 
     out_list = []
     for i in range(len(new_var_name_list)):
-        if new_var_name_list[i] not in program.global_block().vars:
-            out = program.global_block().create_var(
+        if new_var_name_list[i] not in block.vars:
+            out = block.create_var(
                name=new_var_name_list[i],
                shape=new_var_shape_list[i],
                dtype=input_var.dtype,
                type=input_var.type)
         else:
-            out = program.global_block().vars[new_var_name_list[i]]
+            out = block.vars[new_var_name_list[i]]
         out_list.append(out)
 
     start_index = 0
@@ -1037,21 +1062,33 @@ def deleter_trainer_useless_var(program):
 
 
 def block_append_op(program, origin_program, block, op):
-    inputs = _get_input_map_from_op(origin_program.global_block().vars, op)
+    merge_ordereddict = origin_program.global_block().vars.copy()
+    merge_ordereddict.update(block.vars)
+    inputs = _get_input_map_from_op(merge_ordereddict, op)
     for key, varlist in six.iteritems(inputs):
         if not isinstance(varlist, list):
             varlist = [varlist]
         for var in varlist:
-            if var.name not in program.global_block().vars:
-                program.global_block()._clone_variable(var)
+            if var.name not in program.global_block(
+            ).vars and var.name not in block.vars:
+                if var.persistable:
+                    program.global_block()._clone_variable(
+                        var, force_persistable=False)
+                else:
+                    block._clone_variable(var, force_persistable=False)
 
     outputs = _get_output_map_from_op(origin_program.global_block().vars, op)
     for key, varlist in six.iteritems(outputs):
         if not isinstance(varlist, list):
             varlist = [varlist]
         for var in varlist:
-            if var.name not in program.global_block().vars:
-                program.global_block()._clone_variable(var)
+            if var.name not in program.global_block(
+            ).vars and var.name not in block.vars:
+                if var.persistable:
+                    program.global_block()._clone_variable(
+                        var, force_persistable=False)
+                else:
+                    block._clone_variable(var, force_persistable=False)
 
     if "_grad" not in op.type:
         # for forward op
@@ -1076,21 +1113,15 @@ def block_append_op(program, origin_program, block, op):
     block._sync_with_cpp()
 
 
-def add_vars_by_op_map(var_map, program):
-    for key, varlist in six.iteritems(var_map):
-        if not isinstance(varlist, list):
-            varlist = [varlist]
-        for i in range(len(varlist)):
-            var = varlist[i]
-            if var.name not in program.global_block().vars:
-                program.global_block()._clone_variable(var)
-
-
-def add_vars_by_var_list(var_name_list, origin_program, program):
+def add_vars_by_var_list(var_name_list, origin_program, program, block):
     for var_name in var_name_list:
         if var_name not in program.global_block().vars:
             var = origin_program.global_block().vars[var_name]
-            program.global_block()._clone_variable(var)
+            if var.persistable:
+                program.global_block()._clone_variable(
+                    var, force_persistable=False)
+            else:
+                block._clone_variable(var, force_persistable=False)
 
 
 def get_varlist_from_op_map(var_map):
diff --git a/python/paddle/fluid/tests/unittests/ctr_dataset_reader.py b/python/paddle/fluid/tests/unittests/ctr_dataset_reader.py
index 863c001f226..15e98481c26 100644
--- a/python/paddle/fluid/tests/unittests/ctr_dataset_reader.py
+++ b/python/paddle/fluid/tests/unittests/ctr_dataset_reader.py
@@ -153,7 +153,7 @@ def gen_fake_line(dnn_data_num=7,
     return line
 
 
-def prepare_fake_data(file_nums=8, file_lines=1000):
+def prepare_fake_data(file_nums=9, file_lines=1000):
     """
     Create fake data with same type as avazu_ctr_data
     """
diff --git a/python/paddle/fluid/tests/unittests/dist_fleet_heter_ctr.py b/python/paddle/fluid/tests/unittests/dist_fleet_heter_ctr.py
index 0de898d6dde..7a4e7534f07 100644
--- a/python/paddle/fluid/tests/unittests/dist_fleet_heter_ctr.py
+++ b/python/paddle/fluid/tests/unittests/dist_fleet_heter_ctr.py
@@ -177,7 +177,7 @@ class TestHeterPsCTR2x2(FleetDistHeterRunnerBase):
         fleet.init_worker()
         exe.run(fluid.default_startup_program())
 
-        thread_num = 1
+        thread_num = int(os.getenv("CPU_NUM", 2))
         batch_size = 128
         filelist = fleet_util.get_file_shard(train_file_list)
         print("filelist: {}".format(filelist))
diff --git a/python/paddle/fluid/tests/unittests/test_dist_fleet_heter_ctr.py b/python/paddle/fluid/tests/unittests/test_dist_fleet_heter_ctr.py
index c3ffd50dc8d..02a739c060c 100644
--- a/python/paddle/fluid/tests/unittests/test_dist_fleet_heter_ctr.py
+++ b/python/paddle/fluid/tests/unittests/test_dist_fleet_heter_ctr.py
@@ -36,13 +36,45 @@ class TestDistHeterDatasetAsync2x2(TestFleetHeterBase):
            "LD_LIBRARY_PATH": os.getenv("LD_LIBRARY_PATH", ""),
            "FLAGS_rpc_deadline": "5000",  # 5sec to fail fast
            "http_proxy": "",
-            "CPU_NUM": "1"
+            "CPU_NUM": "3"
        }
 
         required_envs.update(need_envs)
 
         if check_error_log:
-            required_envs["GLOG_v"] = "4"
+            required_envs["GLOG_v"] = "3"
+            required_envs["GLOG_logtostderr"] = "1"
+
+        tr0_losses, tr1_losses = self._run_cluster(model_file, required_envs)
+
+    def test_dist_train(self):
+        self.check_with_place(
+            "dist_fleet_heter_ctr.py", delta=1e-5, check_error_log=True)
+
+
+class TestDistHeterPyreaderAsync2x2(TestFleetHeterBase):
+    def _setup_config(self):
+        self._mode = "async"
+        self._reader = "pyreader"
+
+    def check_with_place(self,
+                         model_file,
+                         delta=1e-3,
+                         check_error_log=False,
+                         need_envs={}):
+        required_envs = {
+            "PATH": os.getenv("PATH", ""),
+            "PYTHONPATH": os.getenv("PYTHONPATH", ""),
+            "LD_LIBRARY_PATH": os.getenv("LD_LIBRARY_PATH", ""),
+            "FLAGS_rpc_deadline": "5000",  # 5sec to fail fast
+            "http_proxy": "",
+            "CPU_NUM": "3"
+        }
+
+        required_envs.update(need_envs)
+
+        if check_error_log:
+            required_envs["GLOG_v"] = "3"
             required_envs["GLOG_logtostderr"] = "1"
 
         tr0_losses, tr1_losses = self._run_cluster(model_file, required_envs)
-- 
GitLab
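
For reference, a minimal standalone sketch of the joint-variable round trip described by the create_heter_program / create_trainer_program comments in the patch above (origin_var -> reshape -> concat -> joint_var on the sending side, joint_var -> slice -> reshape -> origin_var on the receiving side). This is plain NumPy with hypothetical shapes and names; it is an illustration only, not code from trainer_pass.py:

# Illustrative sketch only -- hypothetical shapes, not part of the patch.
import numpy as np

batch = 4
# Two "entrance" variables with different trailing shapes.
origin_a = np.random.rand(batch, 3, 2)   # e.g. an embedding output
origin_b = np.random.rand(batch, 5)      # e.g. a dense feature

# Sending side: origin_var -> reshape -> concat -> joint_var.0_1
reshaped_a = origin_a.reshape(batch, -1)                       # (4, 6)
reshaped_b = origin_b.reshape(batch, -1)                       # (4, 5)
joint_0_1 = np.concatenate([reshaped_a, reshaped_b], axis=1)   # (4, 11)

# Receiving side: joint_var.0_1 -> slice -> reshape -> origin_var
dims = [reshaped_a.shape[1], reshaped_b.shape[1]]
slice_a = joint_0_1[:, :dims[0]]
slice_b = joint_0_1[:, dims[0]:dims[0] + dims[1]]
recovered_a = slice_a.reshape(origin_a.shape)
recovered_b = slice_b.reshape(origin_b.shape)

# The round trip is lossless, which is what lets the heter block and the
# trainer exchange many variables through a single transport variable.
assert np.allclose(recovered_a, origin_a)
assert np.allclose(recovered_b, origin_b)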