未验证 提交 0b8067c0 编写于 作者: W Wu Yi 提交者: GitHub

fix dist train reduce mode (#13068)

* fix dist train reduce mode

* fix previous fix
上级 823c4f87
......@@ -744,7 +744,7 @@ void MultiDevSSAGraphBuilder::CreateDistTrainOp(ir::Graph *result,
.emplace(varname, op_dev_id);
}
} else {
PADDLE_ENFORCE(
PADDLE_THROW(
"the distribute training related op should be in [split_byref, "
"concat].");
}
......
......@@ -82,8 +82,18 @@ class TestDistRunnerBase(object):
strategy = fluid.ExecutionStrategy()
strategy.num_threads = 1
strategy.allow_op_delay = False
build_stra = fluid.BuildStrategy()
if args.use_reduce:
build_stra.reduce_strategy = fluid.BuildStrategy.ReduceStrategy.Reduce
else:
build_stra.reduce_strategy = fluid.BuildStrategy.ReduceStrategy.AllReduce
exe = fluid.ParallelExecutor(
True, loss_name=avg_cost.name, exec_strategy=strategy)
True,
loss_name=avg_cost.name,
exec_strategy=strategy,
build_strategy=build_stra)
feed_var_list = [
var for var in trainer_prog.global_block().vars.values()
......@@ -123,6 +133,7 @@ def runtime_main(test_class):
'--current_endpoint', type=str, required=False, default="")
parser.add_argument('--sync_mode', action='store_true')
parser.add_argument('--mem_opt', action='store_true')
parser.add_argument('--use_reduce', action='store_true')
args = parser.parse_args()
......@@ -149,20 +160,25 @@ class TestDistBase(unittest.TestCase):
self._python_interp = "python"
self._sync_mode = True
self._mem_opt = False
self._use_reduce = False
self._setup_config()
def start_pserver(self, model_file, check_error_log):
ps0_ep, ps1_ep = self._ps_endpoints.split(",")
ps_cmd = "%s %s --role pserver --endpoints %s --trainer_id 0 --current_endpoint %s --trainers %d --is_dist %s %s"
sync_mode_str = "--sync_mode" if self._sync_mode else ""
mem_opt_str = "--mem_opt" if self._mem_opt else ""
ps_cmd = "%s %s --role pserver --endpoints %s --trainer_id 0 --current_endpoint %s --trainers %d --is_dist"
ps0_cmd = ps_cmd % \
(self._python_interp, model_file, self._ps_endpoints, ps0_ep,
self._trainers, sync_mode_str, mem_opt_str)
self._trainers)
ps1_cmd = ps_cmd % \
(self._python_interp, model_file, self._ps_endpoints, ps1_ep,
self._trainers, sync_mode_str, mem_opt_str)
self._trainers)
if self._sync_mode:
ps0_cmd += " --sync_mode"
ps1_cmd += " --sync_mode"
if self._mem_opt:
ps0_cmd += " --mem_opt"
ps1_cmd += " --mem_opt"
ps0_pipe = subprocess.PIPE
ps1_pipe = subprocess.PIPE
......@@ -242,17 +258,23 @@ class TestDistBase(unittest.TestCase):
self._wait_ps_ready(ps1.pid)
ps0_ep, ps1_ep = self._ps_endpoints.split(",")
tr_cmd = "%s %s --role trainer --endpoints %s --trainer_id %d --current_endpoint %s --trainers %d --is_dist %s %s"
sync_mode_str = "--sync_mode" if self._sync_mode else ""
mem_opt_str = "--mem_opt" if self._mem_opt else ""
tr_cmd = "%s %s --role trainer --endpoints %s --trainer_id %d --current_endpoint %s --trainers %d --is_dist"
tr0_cmd = tr_cmd % \
(self._python_interp, model_file, self._ps_endpoints,
0, ps0_ep,
self._trainers, sync_mode_str, mem_opt_str)
0, ps0_ep, self._trainers)
tr1_cmd = tr_cmd % \
(self._python_interp, model_file, self._ps_endpoints,
1, ps1_ep,
self._trainers, sync_mode_str, mem_opt_str)
1, ps1_ep, self._trainers)
if self._sync_mode:
tr0_cmd += " --sync_mode"
tr1_cmd += " --sync_mode"
if self._mem_opt:
tr0_cmd += " --mem_opt"
tr1_cmd += " --mem_opt"
if self._use_reduce:
tr0_cmd += " --use_reduce"
tr1_cmd += " --use_reduce"
env0 = {"CUDA_VISIBLE_DEVICES": "0"}
env1 = {"CUDA_VISIBLE_DEVICES": "1"}
......@@ -303,6 +325,8 @@ class TestDistBase(unittest.TestCase):
# FIXME: use terminate() instead of sigkill.
os.kill(ps0.pid, signal.SIGKILL)
os.kill(ps1.pid, signal.SIGKILL)
ps0.terminate()
ps1.terminate()
ps0.wait()
ps1.wait()
FNULL.close()
......
......@@ -20,6 +20,7 @@ from test_dist_base import TestDistBase
class TestDistMnist2x2(TestDistBase):
def _setup_config(self):
self._sync_mode = True
self._use_reduce = False
def test_se_resnext(self):
self.check_with_place("dist_mnist.py", delta=1e-7)
......@@ -37,10 +38,30 @@ class TestDistMnist2x2WithMemopt(TestDistBase):
class TestDistMnistAsync(TestDistBase):
def _setup_config(self):
self._sync_mode = False
self._use_reduce = False
def test_se_resnext(self):
self.check_with_place("dist_mnist.py", delta=200)
# FIXME(typhoonzero): enable these tests once we have 4
# 4 GPUs on CI machine, and the base class should be updated.
#
# class TestDistMnist2x2ReduceMode(TestDistBase):
# def _setup_config(self):
# self._sync_mode = True
# self._use_reduce = True
# def test_se_resnext(self):
# self.check_with_place("dist_mnist.py", delta=1e-7)
# class TestDistMnistAsyncReduceMode(TestDistBase):
# def _setup_config(self):
# self._sync_mode = False
# self._use_reduce = True
# def test_se_resnext(self):
# self.check_with_place("dist_mnist.py", delta=200)
if __name__ == "__main__":
unittest.main()
......@@ -273,6 +273,10 @@ class DistributeTranspiler(object):
name=framework.generate_control_dev_var_name())
grad_name_to_send_dummy_out[grad_varname] = dummy_output
# get send op_role_var, if not splited, the grad should have .trainer suffix
# if splited, grad should be the original grad var name (split_by_ref and send
# will be on the same place). ParallelExecutor
# will use op_role_var to get expected device place to run this op.
program.global_block()._insert_op(
index=index + 1,
type="send",
......@@ -281,8 +285,10 @@ class DistributeTranspiler(object):
attrs={
"epmap": eplist,
RPC_OP_ROLE_ATTR_NAME: RPC_OP_ROLE_ATTR_VALUE,
OP_ROLE_VAR_ATTR_NAME:
[self.grad_name_to_param_name[grad_varname], grad_varname],
OP_ROLE_VAR_ATTR_NAME: [
self.grad_name_to_param_name[grad_varname],
splited_grad_varname
],
"sync_mode": not self.sync_mode,
})
for _, var in enumerate(splited_vars):
......@@ -326,6 +332,15 @@ class DistributeTranspiler(object):
recv_dep_in = grad_name_to_send_dummy_out[
self.param_name_to_grad_name[param_varname]]
all_recv_outputs.extend(splited_var)
# get recv op_role_var, if not splited, the grad should have .trainer suffix
# if splited, grad should be the original grad var name. ParallelExecutor
# will use op_role_var to get expected device place to run this op.
orig_grad_name = self.param_name_to_grad_name[param_varname]
recv_op_role_var_name = orig_grad_name
splited_trainer_grad = self.grad_var_mapping[orig_grad_name]
if len(splited_trainer_grad) == 1:
recv_op_role_var_name = splited_trainer_grad[0].name
program.global_block().append_op(
type="recv",
inputs={"X": [recv_dep_in]},
......@@ -333,10 +348,8 @@ class DistributeTranspiler(object):
attrs={
"epmap": eps,
RPC_OP_ROLE_ATTR_NAME: RPC_OP_ROLE_ATTR_VALUE,
OP_ROLE_VAR_ATTR_NAME: [
param_varname,
self.param_name_to_grad_name[param_varname]
],
OP_ROLE_VAR_ATTR_NAME:
[param_varname, recv_op_role_var_name],
"sync_mode": not self.sync_mode
})
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册