未验证 提交 8caee2ad 编写于 作者: L lilong12 提交者: GitHub

【paddle.fleet】add the support for multi-node training for pipeline (#25907)

* add the support for multi-node training
上级 bf2db646
......@@ -52,10 +52,12 @@ class CCommInitOp : public framework::OperatorBase {
int nranks = Attr<int>("nranks");
int rank_id = Attr<int>("rank");
int rid = Attr<int>("ring_id");
int device_id = BOOST_GET_CONST(platform::CUDAPlace, place).device;
if (Attr<int>("device_id") >= 0) {
device_id = Attr<int>("device_id");
}
platform::NCCLCommContext::Instance().CreateNCCLComm(
nccl_id, nranks, rank_id,
BOOST_GET_CONST(platform::CUDAPlace, place).device, rid);
nccl_id, nranks, rank_id, device_id, rid);
#else
PADDLE_THROW("PaddlePaddle should compile with GPU.");
#endif
......@@ -74,6 +76,11 @@ Initialize collective communicatoin context within this trainer
AddAttr<int>("nranks", "(int) The number of ranks of distributed trainers");
AddAttr<int>("rank",
"(int) The rank of the trainer in distributed training.");
AddAttr<int>("device_id",
"(int) The deivce_id on which to initialize the communicator."
"Now, you only have to set this attr manually for pipeline "
"training. Otherwise, make it as default.")
.SetDefault(-1);
AddAttr<int>("ring_id", "(int default 0) user specified ring id")
.SetDefault(0);
}
......
......@@ -11,12 +11,84 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
from __future__ import print_function
import paddle.fluid as fluid
from paddle.fluid import core, unique_name
from ..base.private_helper_function import wait_server_ready
from paddle.fluid.optimizer import PipelineOptimizer as PO
from .meta_optimizer_base import MetaOptimizerBase
from .common import OpRole, OP_ROLE_KEY, OP_ROLE_VAR_KEY, CollectiveHelper, is_update_op, is_loss_grad_op, is_backward_op, is_optimizer_op
__all__ = ["PipelineOptimizer"]
class PipelineHelper(CollectiveHelper):
def __init__(self, role_maker, nrings=1, wait_port='6174'):
super(PipelineHelper, self).__init__(role_maker, nrings, wait_port)
def _init_communicator(self, program, current_endpoint, endpoints, rank,
ring_id, wait_port):
nranks = len(endpoints)
other_endpoints = endpoints[:]
other_endpoints.remove(current_endpoint)
if rank == 0 and wait_port:
wait_server_ready(other_endpoints)
block = program.global_block()
nccl_id_var = block.create_var(
name=unique_name.generate('nccl_id'),
persistable=True,
type=core.VarDesc.VarType.RAW)
block.append_op(
type='c_gen_nccl_id',
inputs={},
outputs={'Out': nccl_id_var},
attrs={
'rank': rank,
'endpoint': current_endpoint,
'other_endpoints': other_endpoints,
OP_ROLE_KEY: OpRole.Forward
})
block.append_op(
type='c_comm_init',
inputs={'X': nccl_id_var},
outputs={},
attrs={
'nranks': nranks,
'rank': rank,
'ring_id': ring_id,
OP_ROLE_KEY: OpRole.Forward,
'device_id': OpRole.Forward
})
def _broadcast_params(self):
block = self.startup_program.global_block()
ring_id = 0
for param in block.iter_parameters():
if param.is_distributed:
continue
block.append_op(
type='c_broadcast',
inputs={'X': param},
outputs={'Out': param},
attrs={
'ring_id': ring_id,
'root': 0,
OP_ROLE_KEY: OpRole.Forward
})
for ring_id in range(self.nrings):
block.append_op(
type='c_sync_comm_stream',
inputs={'X': param},
outputs={'Out': param},
attrs={'ring_id': ring_id,
OP_ROLE_KEY: OpRole.Forward})
class PipelineOptimizer(MetaOptimizerBase):
def __init__(self, optimizer):
super(PipelineOptimizer, self).__init__(optimizer)
......@@ -40,15 +112,6 @@ class PipelineOptimizer(MetaOptimizerBase):
dist_strategy.pipeline = False
dist_strategy.pipeline_configs = {"micro_batch": 1}
def backward(self,
loss,
startup_program=None,
parameter_list=None,
no_grad_set=None,
callbacks=None):
return self.wrapped_opt.backward(loss, startup_program, parameter_list,
no_grad_set, callbacks)
def minimize_impl(self,
loss,
startup_program=None,
......@@ -57,4 +120,105 @@ class PipelineOptimizer(MetaOptimizerBase):
optimize_ops, params_grads, prog_list = \
self.wrapped_opt.minimize(loss, startup_program,
parameter_list, no_grad_set)
if self.role_maker.worker_num() == 1:
return optimize_ops, params_grads
endpoints = self.role_maker.get_trainer_endpoints()
current_endpoint = endpoints[self.role_maker.worker_index()]
self.startup_program = startup_program
if startup_program is None:
self.startup_program = fluid.default_startup_program()
assert prog_list
self.main_program_list = prog_list
self.main_program = loss.block.program
nranks = len(endpoints)
self.nranks = nranks
self.nrings = len(self.main_program_list)
self.rank = self.role_maker.worker_index()
self.endpoints = endpoints
self.current_endpoint = current_endpoint
pipeline_helper = PipelineHelper(self.role_maker, nrings=self.nrings)
pipeline_helper.update_startup_program(self.startup_program)
self._transpile_main_program()
return optimize_ops, params_grads
def _transpile_main_program(self):
self._insert_loss_grad_ops()
for ring_id in range(self.nrings):
self._insert_allreduce_ops(ring_id)
def _insert_loss_grad_ops(self):
"""
In order to keep the learning rate consistent in different numbers of
training workers, we scale the loss grad by the number of workers
"""
block = self.main_program_list[self.nrings - 1]['program'].global_block(
)
for idx, op in reversed(list(enumerate(block.ops))):
if is_loss_grad_op(op):
loss_grad_var = block.vars[op.output_arg_names[0]]
block._insert_op(
idx + 1,
type='scale',
inputs={'X': loss_grad_var},
outputs={'Out': loss_grad_var},
attrs={
'scale': 1.0 / self.nranks,
OP_ROLE_KEY: OpRole.Backward
})
def _insert_allreduce_ops(self, ring_id):
block = self.main_program_list[ring_id]['program'].global_block()
origin_block = self.main_program.global_block()
grad = None
for idx, op in reversed(list(enumerate(block.ops))):
if is_backward_op(op) and \
OP_ROLE_VAR_KEY in op.attr_names:
op_role_var = op.all_attrs()[OP_ROLE_VAR_KEY]
if len(op_role_var) == 0:
continue
assert len(op_role_var) % 2 == 0
offset = idx
for i in range(0, len(op_role_var), 2):
param = block.vars[op_role_var[i]]
grad = block.vars[op_role_var[i + 1]]
origin_param = origin_block.vars[op_role_var[i]]
if origin_param.is_distributed:
continue
if offset == idx:
offset += 1
block._insert_op(
offset,
type='c_sync_calc_stream',
inputs={'X': grad},
outputs={'Out': grad},
attrs={OP_ROLE_KEY: OpRole.Backward})
offset += 1
block._insert_op(
offset,
type='c_sync_calc_stream',
inputs={'X': grad},
outputs={'Out': grad},
attrs={
'ring_id': ring_id,
OP_ROLE_KEY: OpRole.Backward
})
if grad is None:
return
for idx, op in enumerate(block.ops):
if is_optimizer_op(op):
block._insert_op(
idx + ring_id,
type='c_sync_comm_stream',
inputs={'X': grad},
outputs={'Out': grad},
attrs={'ring_id': ring_id,
OP_ROLE_KEY: OpRole.Backward})
break
......@@ -22,6 +22,7 @@ list(APPEND MIXED_DIST_TEST_OPS test_simple_dist_transpiler)
list(APPEND MIXED_DIST_TEST_OPS test_recv_save_op)
list(APPEND MIXED_DIST_TEST_OPS test_transpiler_ops)
list(APPEND MIXED_DIST_TEST_OPS test_launch)
list(APPEND MIXED_DIST_TEST_OPS test_c_comm_init_op)
list(APPEND MIXED_DIST_TEST_OPS test_launch_ps)
list(APPEND MIXED_DIST_TEST_OPS test_communicator_async)
list(APPEND MIXED_DIST_TEST_OPS test_communicator_geo)
......@@ -403,6 +404,7 @@ if(WITH_DISTRIBUTE)
if(WITH_GPU)
# NOTE. test_launch only work in gpu collective mode
bash_test_modules(test_launch START_BASH test_launch.sh ENVS PADDLE_BINARY_DIR=${PADDLE_BINARY_DIR})
bash_test_modules(test_c_comm_init_op START_BASH test_c_comm_init_op.sh ENVS PADDLE_BINARY_DIR=${PADDLE_BINARY_DIR})
py_test_modules(test_fleet_checkpoint MODULES test_fleet_checkpoint)
endif()
......
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import unittest
import os
import paddle.fluid.core as core
import paddle.fluid as fluid
from paddle.fleet.base.private_helper_function import wait_server_ready
class TestCCommInitOp(unittest.TestCase):
def setUp(self):
self.endpoints = os.getenv("PADDLE_TRAINER_ENDPOINTS").split(',')
self.current_endpoint = os.getenv("PADDLE_CURRENT_ENDPOINT")
self.nranks = len(self.endpoints)
self.rank = self.endpoints.index(self.current_endpoint)
self.gpu_id = int(os.getenv("FLAGS_selected_gpus"))
self.place = fluid.CUDAPlace(self.gpu_id)
self.exe = fluid.Executor(self.place)
self.endpoints.remove(self.current_endpoint)
self.other_endpoints = self.endpoints
if self.rank == 0:
wait_server_ready(self.other_endpoints)
def test_specifying_devices(self):
program = fluid.Program()
block = program.global_block()
nccl_id_var = block.create_var(
name=fluid.unique_name.generate('nccl_id'),
persistable=True,
type=fluid.core.VarDesc.VarType.RAW)
block.append_op(
type='c_gen_nccl_id',
inputs={},
outputs={'Out': nccl_id_var},
attrs={
'rank': self.rank,
'endpoint': self.current_endpoint,
'other_endpoints': self.other_endpoints
})
block.append_op(
type='c_comm_init',
inputs={'X': nccl_id_var},
outputs={},
attrs={
'nranks': self.nranks,
'rank': self.rank,
'ring_id': 0,
'device_id': self.gpu_id
})
self.exe.run(program)
if __name__ == "__main__":
unittest.main()
#!/bin/bash
set -e
# use default values
# FIXME: random fails on Unknown command lines -c (or -m).
launch_py=${PADDLE_BINARY_DIR}/python/paddle/distributed/launch.py
CUDA_VISIBLE_DEVICES=0,1 python ${launch_py} c_comm_init_op.py
......@@ -19,7 +19,9 @@ import os
class TestFleetMetaOptimizer(unittest.TestCase):
def setUp(self):
os.environ["PADDLE_TRAINER_ENDPOINTS"] = "127.0.0.1:36001"
os.environ["PADDLE_TRAINER_ID"] = "1"
os.environ[
"PADDLE_TRAINER_ENDPOINTS"] = "127.0.0.1:36001,127.0.0.1:36002"
def test_pipeline_optimizer(self):
import paddle.fleet as fleet
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册