提交 03b45b69 编写于 作者: W Wu Yi 提交者: typhoonzero

fix dist train reduce mode (#13068)

* fix dist train reduce mode

* fix previous fix
上级 f517d01a
...@@ -744,7 +744,7 @@ void MultiDevSSAGraphBuilder::CreateDistTrainOp(ir::Graph *result, ...@@ -744,7 +744,7 @@ void MultiDevSSAGraphBuilder::CreateDistTrainOp(ir::Graph *result,
.emplace(varname, op_dev_id); .emplace(varname, op_dev_id);
} }
} else { } else {
PADDLE_ENFORCE( PADDLE_THROW(
"the distribute training related op should be in [split_byref, " "the distribute training related op should be in [split_byref, "
"concat]."); "concat].");
} }
......
...@@ -76,8 +76,18 @@ class TestDistRunnerBase(object): ...@@ -76,8 +76,18 @@ class TestDistRunnerBase(object):
strategy = fluid.ExecutionStrategy() strategy = fluid.ExecutionStrategy()
strategy.num_threads = 1 strategy.num_threads = 1
strategy.allow_op_delay = False strategy.allow_op_delay = False
build_stra = fluid.BuildStrategy()
if args.use_reduce:
build_stra.reduce_strategy = fluid.BuildStrategy.ReduceStrategy.Reduce
else:
build_stra.reduce_strategy = fluid.BuildStrategy.ReduceStrategy.AllReduce
exe = fluid.ParallelExecutor( exe = fluid.ParallelExecutor(
True, loss_name=avg_cost.name, exec_strategy=strategy) True,
loss_name=avg_cost.name,
exec_strategy=strategy,
build_strategy=build_stra)
feed_var_list = [ feed_var_list = [
var for var in trainer_prog.global_block().vars.values() var for var in trainer_prog.global_block().vars.values()
...@@ -106,16 +116,20 @@ def runtime_main(test_class): ...@@ -106,16 +116,20 @@ def runtime_main(test_class):
import paddle.fluid as fluid import paddle.fluid as fluid
import paddle.fluid.core as core import paddle.fluid.core as core
if len(sys.argv) != 7: parser = argparse.ArgumentParser(description='Run dist test.')
print( parser.add_argument(
"Usage: python dist_se_resnext.py [pserver/trainer] [endpoints] [trainer_id] [current_endpoint] [trainers] [is_dist]" '--role', type=str, required=True, choices=['pserver', 'trainer'])
) parser.add_argument('--endpoints', type=str, required=False, default="")
role = sys.argv[1] parser.add_argument('--is_dist', action='store_true')
endpoints = sys.argv[2] parser.add_argument('--trainer_id', type=int, required=False, default=0)
trainer_id = int(sys.argv[3]) parser.add_argument('--trainers', type=int, required=False, default=1)
current_endpoint = sys.argv[4] parser.add_argument(
trainers = int(sys.argv[5]) '--current_endpoint', type=str, required=False, default="")
is_dist = True if sys.argv[6] == "TRUE" else False parser.add_argument('--sync_mode', action='store_true')
parser.add_argument('--mem_opt', action='store_true')
parser.add_argument('--use_reduce', action='store_true')
args = parser.parse_args()
model = test_class() model = test_class()
if role == "pserver": if role == "pserver":
...@@ -135,16 +149,28 @@ class TestDistBase(unittest.TestCase): ...@@ -135,16 +149,28 @@ class TestDistBase(unittest.TestCase):
self._pservers = 2 self._pservers = 2
self._ps_endpoints = "127.0.0.1:9123,127.0.0.1:9124" self._ps_endpoints = "127.0.0.1:9123,127.0.0.1:9124"
self._python_interp = "python" self._python_interp = "python"
self._sync_mode = True
self._mem_opt = False
self._use_reduce = False
self._setup_config()
def start_pserver(self, model_file, check_error_log): def start_pserver(self, model_file, check_error_log):
ps0_ep, ps1_ep = self._ps_endpoints.split(",") ps0_ep, ps1_ep = self._ps_endpoints.split(",")
ps0_cmd = "%s %s pserver %s 0 %s %d TRUE" % \ ps_cmd = "%s %s --role pserver --endpoints %s --trainer_id 0 --current_endpoint %s --trainers %d --is_dist"
ps0_cmd = ps_cmd % \
(self._python_interp, model_file, self._ps_endpoints, ps0_ep, (self._python_interp, model_file, self._ps_endpoints, ps0_ep,
self._trainers) self._trainers)
ps1_cmd = "%s %s pserver %s 0 %s %d TRUE" % \ ps1_cmd = ps_cmd % \
(self._python_interp, model_file, self._ps_endpoints, ps1_ep, (self._python_interp, model_file, self._ps_endpoints, ps1_ep,
self._trainers) self._trainers)
if self._sync_mode:
ps0_cmd += " --sync_mode"
ps1_cmd += " --sync_mode"
if self._mem_opt:
ps0_cmd += " --mem_opt"
ps1_cmd += " --mem_opt"
ps0_pipe = subprocess.PIPE ps0_pipe = subprocess.PIPE
ps1_pipe = subprocess.PIPE ps1_pipe = subprocess.PIPE
if check_error_log: if check_error_log:
...@@ -226,12 +252,23 @@ class TestDistBase(unittest.TestCase): ...@@ -226,12 +252,23 @@ class TestDistBase(unittest.TestCase):
self._wait_ps_ready(ps1.pid) self._wait_ps_ready(ps1.pid)
ps0_ep, ps1_ep = self._ps_endpoints.split(",") ps0_ep, ps1_ep = self._ps_endpoints.split(",")
tr0_cmd = "%s %s trainer %s 0 %s %d TRUE" % \ tr_cmd = "%s %s --role trainer --endpoints %s --trainer_id %d --current_endpoint %s --trainers %d --is_dist"
(self._python_interp, model_file, self._ps_endpoints, ps0_ep, tr0_cmd = tr_cmd % \
self._trainers) (self._python_interp, model_file, self._ps_endpoints,
tr1_cmd = "%s %s trainer %s 1 %s %d TRUE" % \ 0, ps0_ep, self._trainers)
(self._python_interp, model_file, self._ps_endpoints, ps1_ep, tr1_cmd = tr_cmd % \
self._trainers) (self._python_interp, model_file, self._ps_endpoints,
1, ps1_ep, self._trainers)
if self._sync_mode:
tr0_cmd += " --sync_mode"
tr1_cmd += " --sync_mode"
if self._mem_opt:
tr0_cmd += " --mem_opt"
tr1_cmd += " --mem_opt"
if self._use_reduce:
tr0_cmd += " --use_reduce"
tr1_cmd += " --use_reduce"
env0 = {"CUDA_VISIBLE_DEVICES": "0"} env0 = {"CUDA_VISIBLE_DEVICES": "0"}
env1 = {"CUDA_VISIBLE_DEVICES": "1"} env1 = {"CUDA_VISIBLE_DEVICES": "1"}
...@@ -282,6 +319,10 @@ class TestDistBase(unittest.TestCase): ...@@ -282,6 +319,10 @@ class TestDistBase(unittest.TestCase):
# FIXME: use terminate() instead of sigkill. # FIXME: use terminate() instead of sigkill.
os.kill(ps0.pid, signal.SIGKILL) os.kill(ps0.pid, signal.SIGKILL)
os.kill(ps1.pid, signal.SIGKILL) os.kill(ps1.pid, signal.SIGKILL)
ps0.terminate()
ps1.terminate()
ps0.wait()
ps1.wait()
FNULL.close() FNULL.close()
self.assertAlmostEqual(local_first_loss, dist_first_loss, delta=delta) self.assertAlmostEqual(local_first_loss, dist_first_loss, delta=delta)
......
...@@ -17,10 +17,51 @@ import unittest ...@@ -17,10 +17,51 @@ import unittest
from test_dist_base import TestDistBase from test_dist_base import TestDistBase
class TestDistSeResneXt2x2(TestDistBase): class TestDistMnist2x2(TestDistBase):
def _setup_config(self):
self._sync_mode = True
self._use_reduce = False
def test_se_resnext(self):
self.check_with_place("dist_mnist.py", delta=1e-7)
class TestDistMnist2x2WithMemopt(TestDistBase):
def _setup_config(self):
self._sync_mode = True
self._mem_opt = True
def test_se_resnext(self): def test_se_resnext(self):
self.check_with_place("dist_mnist.py", delta=1e-7) self.check_with_place("dist_mnist.py", delta=1e-7)
class TestDistMnistAsync(TestDistBase):
def _setup_config(self):
self._sync_mode = False
self._use_reduce = False
def test_se_resnext(self):
self.check_with_place("dist_mnist.py", delta=200)
# FIXME(typhoonzero): enable these tests once we have 4
# 4 GPUs on CI machine, and the base class should be updated.
#
# class TestDistMnist2x2ReduceMode(TestDistBase):
# def _setup_config(self):
# self._sync_mode = True
# self._use_reduce = True
# def test_se_resnext(self):
# self.check_with_place("dist_mnist.py", delta=1e-7)
# class TestDistMnistAsyncReduceMode(TestDistBase):
# def _setup_config(self):
# self._sync_mode = False
# self._use_reduce = True
# def test_se_resnext(self):
# self.check_with_place("dist_mnist.py", delta=200)
if __name__ == "__main__": if __name__ == "__main__":
unittest.main() unittest.main()
...@@ -265,6 +265,10 @@ class DistributeTranspiler(object): ...@@ -265,6 +265,10 @@ class DistributeTranspiler(object):
name=framework.generate_control_dev_var_name()) name=framework.generate_control_dev_var_name())
grad_name_to_send_dummy_out[grad_varname] = dummy_output grad_name_to_send_dummy_out[grad_varname] = dummy_output
# get send op_role_var, if not splited, the grad should have .trainer suffix
# if splited, grad should be the original grad var name (split_by_ref and send
# will be on the same place). ParallelExecutor
# will use op_role_var to get expected device place to run this op.
program.global_block()._insert_op( program.global_block()._insert_op(
index=index + 1, index=index + 1,
type="send", type="send",
...@@ -273,8 +277,10 @@ class DistributeTranspiler(object): ...@@ -273,8 +277,10 @@ class DistributeTranspiler(object):
attrs={ attrs={
"epmap": eplist, "epmap": eplist,
RPC_OP_ROLE_ATTR_NAME: RPC_OP_ROLE_ATTR_VALUE, RPC_OP_ROLE_ATTR_NAME: RPC_OP_ROLE_ATTR_VALUE,
OP_ROLE_VAR_ATTR_NAME: OP_ROLE_VAR_ATTR_NAME: [
[self.grad_name_to_param_name[grad_varname], grad_varname], self.grad_name_to_param_name[grad_varname],
splited_grad_varname
],
"sync_mode": not self.sync_mode, "sync_mode": not self.sync_mode,
}) })
for _, var in enumerate(splited_vars): for _, var in enumerate(splited_vars):
...@@ -318,6 +324,15 @@ class DistributeTranspiler(object): ...@@ -318,6 +324,15 @@ class DistributeTranspiler(object):
recv_dep_in = grad_name_to_send_dummy_out[ recv_dep_in = grad_name_to_send_dummy_out[
self.param_name_to_grad_name[param_varname]] self.param_name_to_grad_name[param_varname]]
all_recv_outputs.extend(splited_var) all_recv_outputs.extend(splited_var)
# get recv op_role_var, if not splited, the grad should have .trainer suffix
# if splited, grad should be the original grad var name. ParallelExecutor
# will use op_role_var to get expected device place to run this op.
orig_grad_name = self.param_name_to_grad_name[param_varname]
recv_op_role_var_name = orig_grad_name
splited_trainer_grad = self.grad_var_mapping[orig_grad_name]
if len(splited_trainer_grad) == 1:
recv_op_role_var_name = splited_trainer_grad[0].name
program.global_block().append_op( program.global_block().append_op(
type="recv", type="recv",
inputs={"X": [recv_dep_in]}, inputs={"X": [recv_dep_in]},
...@@ -325,10 +340,8 @@ class DistributeTranspiler(object): ...@@ -325,10 +340,8 @@ class DistributeTranspiler(object):
attrs={ attrs={
"epmap": eps, "epmap": eps,
RPC_OP_ROLE_ATTR_NAME: RPC_OP_ROLE_ATTR_VALUE, RPC_OP_ROLE_ATTR_NAME: RPC_OP_ROLE_ATTR_VALUE,
OP_ROLE_VAR_ATTR_NAME: [ OP_ROLE_VAR_ATTR_NAME:
param_varname, [param_varname, recv_op_role_var_name],
self.param_name_to_grad_name[param_varname]
],
"sync_mode": not self.sync_mode "sync_mode": not self.sync_mode
}) })
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册