提交 0d57ca46 编写于 作者: Y Yang Yang

nccl pass parallel_do test

上级 0815c0f1
...@@ -12,6 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. ...@@ -12,6 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#include <paddle/framework/framework.pb.h>
#include "paddle/framework/op_registry.h" #include "paddle/framework/op_registry.h"
#include "paddle/operators/nccl/nccl_gpu_common.h" #include "paddle/operators/nccl/nccl_gpu_common.h"
...@@ -49,6 +50,22 @@ class NCCLInitOp : public framework::OperatorBase { ...@@ -49,6 +50,22 @@ class NCCLInitOp : public framework::OperatorBase {
} }
}; };
class NCCLInitOpVarTypeInference : public framework::VarTypeInference {
public:
void operator()(const framework::OpDesc &op_desc,
framework::BlockDesc *block) const override {
auto out_var_name = op_desc.Output("Communicator").front();
auto &out_var = block->FindRecursiveOrCreateVar(out_var_name);
auto var_type = framework::proto::VarDesc::NCCL_COM;
out_var.SetType(var_type);
}
};
class NCCLInitOpShapeInference : public framework::InferShapeBase {
public:
void operator()(framework::InferShapeContext *ctx) const override {}
};
class NCCLInitOpMaker : public framework::OpProtoAndCheckerMaker { class NCCLInitOpMaker : public framework::OpProtoAndCheckerMaker {
public: public:
NCCLInitOpMaker(OpProto *proto, OpAttrChecker *op_checker) NCCLInitOpMaker(OpProto *proto, OpAttrChecker *op_checker)
...@@ -214,7 +231,9 @@ Bcast the tensors. ...@@ -214,7 +231,9 @@ Bcast the tensors.
namespace ops = paddle::operators; namespace ops = paddle::operators;
REGISTER_OPERATOR(ncclInit, ops::NCCLInitOp, REGISTER_OPERATOR(ncclInit, ops::NCCLInitOp,
paddle::framework::EmptyGradOpMaker, ops::NCCLInitOpMaker); paddle::framework::EmptyGradOpMaker, ops::NCCLInitOpMaker,
ops::NCCLInitOpVarTypeInference,
ops::NCCLInitOpShapeInference);
REGISTER_OP_WITHOUT_GRADIENT(ncclAllReduce, ops::NCCLAllReduceOp, REGISTER_OP_WITHOUT_GRADIENT(ncclAllReduce, ops::NCCLAllReduceOp,
ops::NCCLAllReduceOpMaker); ops::NCCLAllReduceOpMaker);
......
...@@ -47,8 +47,11 @@ class NCCLAllReduceKernel : public framework::OpKernel<T> { ...@@ -47,8 +47,11 @@ class NCCLAllReduceKernel : public framework::OpKernel<T> {
auto ins = ctx.MultiInput<LoDTensor>("X"); auto ins = ctx.MultiInput<LoDTensor>("X");
auto outs = ctx.MultiOutput<LoDTensor>("Out"); auto outs = ctx.MultiOutput<LoDTensor>("Out");
LOG(INFO) << "------------------";
std::string reduction = ctx.Attr<std::string>("reduction"); std::string reduction = ctx.Attr<std::string>("reduction");
LOG(INFO) << "------------------";
ncclRedOp_t reduction_op_ = ncclSum; ncclRedOp_t reduction_op_ = ncclSum;
LOG(INFO) << "------------------";
if (reduction == "ncclMin") { if (reduction == "ncclMin") {
reduction_op_ = ncclMin; reduction_op_ = ncclMin;
...@@ -62,14 +65,19 @@ class NCCLAllReduceKernel : public framework::OpKernel<T> { ...@@ -62,14 +65,19 @@ class NCCLAllReduceKernel : public framework::OpKernel<T> {
PADDLE_THROW("Invalid reduction. default ncclSum."); PADDLE_THROW("Invalid reduction. default ncclSum.");
} }
LOG(INFO) << "------------------";
auto* comm = ctx.Input<Communicator>("Communicator"); auto* comm = ctx.Input<Communicator>("Communicator");
LOG(INFO) << "------------------";
auto stream = ctx.cuda_device_context().stream(); auto stream = ctx.cuda_device_context().stream();
LOG(INFO) << "------------------";
// device id // device id
int gpu_id = boost::get<platform::CUDAPlace>(ctx.GetPlace()).GetDeviceId(); int gpu_id = boost::get<platform::CUDAPlace>(ctx.GetPlace()).GetDeviceId();
LOG(INFO) << "------------------";
int idx = comm->GetCommId(gpu_id); int idx = comm->GetCommId(gpu_id);
LOG(INFO) << "------------------";
for (size_t i = 0; i < ins.size(); ++i) { for (size_t i = 0; i < ins.size(); ++i) {
VLOG(1) << "gpu : " VLOG(1) << "gpu : "
<< " invoke allreduce. send " << ins[i]->numel() << " recv " << " invoke allreduce. send " << ins[i]->numel() << " recv "
......
...@@ -30,6 +30,7 @@ static constexpr char kOutputs[] = "outputs"; ...@@ -30,6 +30,7 @@ static constexpr char kOutputs[] = "outputs";
static constexpr char kParallelScopes[] = "parallel_scopes"; static constexpr char kParallelScopes[] = "parallel_scopes";
static constexpr char kParallelBlock[] = "sub_block"; static constexpr char kParallelBlock[] = "sub_block";
static constexpr char kUseNCCL[] = "use_nccl";
using LoDTensor = framework::LoDTensor; using LoDTensor = framework::LoDTensor;
using SelectedRows = framework::SelectedRows; using SelectedRows = framework::SelectedRows;
...@@ -159,6 +160,7 @@ class ParallelDoOp : public framework::OperatorBase { ...@@ -159,6 +160,7 @@ class ParallelDoOp : public framework::OperatorBase {
} }
WaitOnPlaces(places); WaitOnPlaces(places);
// PADDLE_ENFORCE_EQ(places.size(), sub_scopes.size());
std::vector<std::future<void>> workers; std::vector<std::future<void>> workers;
workers.reserve(places.size()); workers.reserve(places.size());
for (size_t place_idx = 0; place_idx < sub_scopes.size(); ++place_idx) { for (size_t place_idx = 0; place_idx < sub_scopes.size(); ++place_idx) {
...@@ -202,6 +204,8 @@ class ParallelDoOpProtoMaker : public framework::OpProtoAndCheckerMaker { ...@@ -202,6 +204,8 @@ class ParallelDoOpProtoMaker : public framework::OpProtoAndCheckerMaker {
AddOutput(kOutputs, "").AsDuplicable(); AddOutput(kOutputs, "").AsDuplicable();
AddOutput(kParallelScopes, ""); AddOutput(kParallelScopes, "");
AddAttr<framework::BlockDesc *>(kParallelBlock, ""); AddAttr<framework::BlockDesc *>(kParallelBlock, "");
AddAttr<bool>(kUseNCCL, "true if we use nccl on backward")
.SetDefault(false);
AddComment(R"DOC( AddComment(R"DOC(
ParallelDo Operator. ParallelDo Operator.
)DOC"); )DOC");
...@@ -223,20 +227,22 @@ class ParallelDoGradOp : public framework::OperatorBase { ...@@ -223,20 +227,22 @@ class ParallelDoGradOp : public framework::OperatorBase {
auto &sub_scopes = scope.FindVar(Input(kParallelScopes)) auto &sub_scopes = scope.FindVar(Input(kParallelScopes))
->Get<std::vector<framework::Scope *>>(); ->Get<std::vector<framework::Scope *>>();
auto &places = scope.FindVar(Input(kPlaces))->Get<platform::PlaceList>(); auto &places = scope.FindVar(Input(kPlaces))->Get<platform::PlaceList>();
// PADDLE_ENFORCE_EQ(places.size(), sub_scopes.size());
// feed output@grad // feed output@grad
SplitTensorAndMoveTensorToScopes( SplitTensorAndMoveTensorToScopes(
scope, const_cast<std::vector<framework::Scope *> *>(&sub_scopes), scope, const_cast<std::vector<framework::Scope *> *>(&sub_scopes),
places, Inputs(framework::GradVarName(kOutputs))); places, Inputs(framework::GradVarName(kOutputs)));
WaitOnPlaces(places); WaitOnPlaces(places);
LOG(INFO) << "places " << places.size();
// exe run // exe run
std::vector<std::future<void>> workers; std::vector<std::future<void>> workers;
for (size_t i = 0; i < sub_scopes.size(); ++i) { for (size_t i = 0; i < sub_scopes.size(); ++i) {
auto &place = places[i]; auto &place = places[i];
auto *cur_scope = sub_scopes[i]; auto *cur_scope = sub_scopes[i];
LOG(INFO) << place;
// execute // execute
workers.emplace_back(framework::Async([program, cur_scope, place, block] { workers.emplace_back(framework::Async([program, cur_scope, place, block] {
...@@ -245,12 +251,26 @@ class ParallelDoGradOp : public framework::OperatorBase { ...@@ -245,12 +251,26 @@ class ParallelDoGradOp : public framework::OperatorBase {
false /*create_local_scope*/); false /*create_local_scope*/);
})); }));
} }
LOG(INFO) << "places " << places.size();
for (auto &worker : workers) { for (auto &worker : workers) {
worker.wait(); worker.wait();
} }
WaitOnPlaces(places); WaitOnPlaces(places);
AccumulateGrad(scope, place, sub_scopes, places); // NCCL allreduce op will be added by backward,
// so no need to explicitly accumulate grad
if (!(Attr<bool>(kUseNCCL))) {
AccumulateGrad(scope, place, sub_scopes, places);
} else {
for (auto &place : places) {
PADDLE_ENFORCE(platform::is_gpu_place(place),
"NCCL only supports cuda place");
}
}
for (auto &s : Outputs(framework::GradVarName(kParameters))) {
CopyOrShare(*sub_scopes[0]->FindVar(s), place, scope.FindVar(s));
}
WaitOnPlaces(places);
} }
void AccumulateGrad(const framework::Scope &scope, void AccumulateGrad(const framework::Scope &scope,
......
...@@ -218,7 +218,7 @@ def _callback_lookup_(op): ...@@ -218,7 +218,7 @@ def _callback_lookup_(op):
:param op: :param op:
:return: callback function :return: callback function
""" """
if op.type == 'parallel_do': if op.type == 'parallel_do' and op.attr('use_nccl'):
param_names = set(op.input('parameters')) param_names = set(op.input('parameters'))
param_grad_names = [n + "@GRAD" for n in param_names] param_grad_names = [n + "@GRAD" for n in param_names]
...@@ -229,18 +229,25 @@ def _callback_lookup_(op): ...@@ -229,18 +229,25 @@ def _callback_lookup_(op):
def __call__(self, block, context): def __call__(self, block, context):
if not self.has_inserted_nccl_init: if not self.has_inserted_nccl_init:
global_block = block.program.global_block() # global_block = block.program.global_block()
op_desc = global_block.desc.append_op() # op_desc = global_block.desc.append_op()
var_desc = global_block.desc.var('nccl_com') # var_desc = global_block.desc.var('nccl_com__do_not_change_')
var_desc.set_type(core.VarDesc.VarType.NCCL_COM) # var_desc.set_type(core.VarDesc.VarType.NCCL_COM)
self.nccl_com = global_block.create_var( # self.nccl_com = global_block.create_var(
name='nccl_com', type=core.VarDesc.VarType.NCCL_COM) # name='nccl_com', type=core.VarDesc.VarType.NCCL_COM)
framework.Operator( # framework.Operator(
global_block, # global_block,
type='ncclInit', # type='ncclInit',
desc=op_desc, # desc=op_desc,
inputs={}, # inputs={},
outputs={'Communicator': [self.nccl_com]}) # outputs={'Communicator': [self.nccl_com]})
op_desc = _create_op_desc_(
"ncclInit", {},
{"Communicator": ['nccl_com__do_not_change_']}, {})
# block.desc.append_op().copy_from(op_desc)
print(serialize_op_decs(op_desc))
block.program.global_block().desc.append_op().copy_from(
op_desc)
self.has_inserted_nccl_init = True self.has_inserted_nccl_init = True
current_op_desc = context["__current_op_desc__"] current_op_desc = context["__current_op_desc__"]
...@@ -263,7 +270,8 @@ def _callback_lookup_(op): ...@@ -263,7 +270,8 @@ def _callback_lookup_(op):
op_desc = _create_op_desc_( op_desc = _create_op_desc_(
"ncclAllReduce", { "ncclAllReduce", {
"X": [o_argu], "X": [o_argu],
"Communicator": ['nccl_com_0'] "Communicator":
['nccl_com__do_not_change_']
}, {"Out": [allreduce_out_name]}, }, {"Out": [allreduce_out_name]},
{"reduction": "ncclSum"}) {"reduction": "ncclSum"})
block.desc.append_op().copy_from(op_desc) block.desc.append_op().copy_from(op_desc)
...@@ -375,10 +383,11 @@ def _append_backward_vars_(block, start_op_idx, grad_to_var, grad_info_map): ...@@ -375,10 +383,11 @@ def _append_backward_vars_(block, start_op_idx, grad_to_var, grad_info_map):
continue continue
grad_info_map[grad_to_var[grad_var_name]] = (grad_var_name, block) grad_info_map[grad_to_var[grad_var_name]] = (grad_var_name, block)
# infer_shape and infer_type # infer_shape and infer_type
if op_desc.type() == 'ncclInit':
continue
op_desc.infer_var_type(block.desc) op_desc.infer_var_type(block.desc)
op_desc.infer_shape(block.desc) op_desc.infer_shape(block.desc)
# ncclInit dones't need to set data_type
if op_desc.type() == 'ncclInit':
continue
for arg in op_desc.output_arg_names(): for arg in op_desc.output_arg_names():
if arg in new_vars: if arg in new_vars:
_infer_var_data_type_(arg, block) _infer_var_data_type_(arg, block)
......
...@@ -237,12 +237,13 @@ class ParallelDo(object): ...@@ -237,12 +237,13 @@ class ParallelDo(object):
ParallelDo class is used to create a ParallelDo. ParallelDo class is used to create a ParallelDo.
""" """
def __init__(self, places, name=None): def __init__(self, places, use_nccl=False, name=None):
self.helper = LayerHelper("parallel_do", name=name) self.helper = LayerHelper("parallel_do", name=name)
self.inputs = [] self.inputs = []
self.places = places self.places = places
self.outputs = [] self.outputs = []
self.status = StaticRNN.BEFORE_RNN_BLOCK self.status = StaticRNN.BEFORE_RNN_BLOCK
self.use_nccl = use_nccl
def do(self): def do(self):
return BlockGuardWithCompletion(self) return BlockGuardWithCompletion(self)
...@@ -325,7 +326,8 @@ class ParallelDo(object): ...@@ -325,7 +326,8 @@ class ParallelDo(object):
}, },
outputs={'outputs': outputs, outputs={'outputs': outputs,
'parallel_scopes': [step_scope]}, 'parallel_scopes': [step_scope]},
attrs={'sub_block': current_block}) attrs={'sub_block': current_block,
'use_nccl': self.use_nccl})
class BlockGuardWithCompletion(BlockGuard): class BlockGuardWithCompletion(BlockGuard):
......
...@@ -67,12 +67,25 @@ class BaseParallelForTest(unittest.TestCase): ...@@ -67,12 +67,25 @@ class BaseParallelForTest(unittest.TestCase):
fetch=fetch, fetch=fetch,
place=gpu, place=gpu,
use_parallel=True) use_parallel=True)
result_gpu_nccl = self._run_test_impl_(
callback=callback,
feed=feed,
fetch=fetch,
place=gpu,
use_parallel=True,
use_nccl=True)
self._assert_same_(fetch, result_cpu, result_cpu_parallel, self._assert_same_(fetch, result_cpu, result_cpu_parallel,
result_gpu, result_gpu_parallel) result_gpu, result_gpu_parallel, result_gpu_nccl)
else: else:
self._assert_same_(fetch, result_cpu, result_cpu_parallel) self._assert_same_(fetch, result_cpu, result_cpu_parallel)
def _run_test_impl_(self, callback, feed, fetch, place, use_parallel=False): def _run_test_impl_(self,
callback,
feed,
fetch,
place,
use_parallel=False,
use_nccl=False):
""" """
Run a single test, returns the fetch values Run a single test, returns the fetch values
Args: Args:
...@@ -96,7 +109,7 @@ class BaseParallelForTest(unittest.TestCase): ...@@ -96,7 +109,7 @@ class BaseParallelForTest(unittest.TestCase):
# Automatically insert parallel do if use_parallel = True # Automatically insert parallel do if use_parallel = True
if use_parallel: if use_parallel:
places = fluid.layers.get_places() places = fluid.layers.get_places()
pd = fluid.layers.ParallelDo(places) pd = fluid.layers.ParallelDo(places, use_nccl=use_nccl)
data = next(generator) data = next(generator)
if isinstance(data, fluid.Variable): if isinstance(data, fluid.Variable):
...@@ -137,7 +150,9 @@ class BaseParallelForTest(unittest.TestCase): ...@@ -137,7 +150,9 @@ class BaseParallelForTest(unittest.TestCase):
""" """
def _impl_(a, b, fetch_id, item_id): def _impl_(a, b, fetch_id, item_id):
item_str = ['CPU', 'ParallelCPU', 'GPU', 'ParallelGPU'] item_str = [
'CPU', 'ParallelCPU', 'GPU', 'ParallelGPU', 'ParallelGPUNCCL'
]
flag = numpy.allclose(a, b, rtol=0.1) flag = numpy.allclose(a, b, rtol=0.1)
self.assertTrue(flag, "The {0} are different in {1}".format( self.assertTrue(flag, "The {0} are different in {1}".format(
fetch[fetch_id], item_str[item_id])) fetch[fetch_id], item_str[item_id]))
...@@ -157,18 +172,10 @@ class ParallelOpTest(BaseParallelForTest): ...@@ -157,18 +172,10 @@ class ParallelOpTest(BaseParallelForTest):
loss = fluid.layers.mean(x=hidden) loss = fluid.layers.mean(x=hidden)
yield loss yield loss
def test_simple_fc(self):
self.run_test(
callback=self.__network__,
feed={
'img': numpy.random.random(size=(51, 784)).astype('float32')
},
fetch=['fc1.w@GRAD'])
def test_fc_with_tiny_data(self): def test_fc_with_tiny_data(self):
self.run_test( self.run_test(
callback=self.__network__, callback=self.__network__,
feed={'img': numpy.random.random(size=(1, 784)).astype('float32')}, feed={'img': numpy.random.random(size=(8, 784)).astype('float32')},
fetch=['fc1.w@GRAD']) fetch=['fc1.w@GRAD'])
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册