diff --git a/paddle/fluid/framework/executor.cc b/paddle/fluid/framework/executor.cc
index 636f67e4a7a212f9547b625bb0bab0a550bcde96..6fd19e804afd674f292ffe2112988bf9d166f12a 100644
--- a/paddle/fluid/framework/executor.cc
+++ b/paddle/fluid/framework/executor.cc
@@ -55,11 +55,13 @@ static void CreateTensor(Variable* var, proto::VarType::Type var_type) {
     var->GetMutable<platform::PlaceList>();
   } else if (var_type == proto::VarType::READER) {
     var->GetMutable<ReaderHolder>();
+  } else if (var_type == proto::VarType::NCCL_COM) {
+    // GetMutable will be called in ncclInit
   } else {
     PADDLE_THROW(
         "Variable type %d is not in "
         "[LOD_TENSOR, SELECTED_ROWS, FEED_MINIBATCH, FETCH_LIST, "
-        "LOD_RANK_TABLE, PLACE_LIST, READER]",
+        "LOD_RANK_TABLE, PLACE_LIST, READER, NCCL_COM]",
         var_type);
   }
 }
@@ -120,14 +122,13 @@ void Executor::Run(const ProgramDesc& pdesc, Scope* scope, int block_id,
   for (auto& op_desc : block.AllOps()) {
     auto op = paddle::framework::OpRegistry::CreateOp(*op_desc);
-    VLOG(4) << op->DebugStringEx(local_scope);
 
     platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance();
     platform::RecordEvent record_event(op->Type(), pool.Get(place_));
 
+    VLOG(3) << place_ << " " << op->DebugStringEx(local_scope);
     op->Run(*local_scope, place_);
-    // Wait current device context.
-    VLOG(3) << op->DebugStringEx(local_scope);
+
     if (FLAGS_benchmark) {
       VLOG(2) << "Memory used after operator " + op->Type() + " running: "
               << memory::memory_usage(place_);
diff --git a/paddle/fluid/framework/framework.proto b/paddle/fluid/framework/framework.proto
index 724d9793e585c0af807bd582242dbe0b79281a4a..4eb18b4e4d685111d02387d5ab944146c9217e62 100644
--- a/paddle/fluid/framework/framework.proto
+++ b/paddle/fluid/framework/framework.proto
@@ -113,6 +113,7 @@ message VarType {
     PLACE_LIST = 14;
     READER = 15;
     CHANNEL = 16;
+    NCCL_COM = 17;
   }
 
   required Type type = 1;
diff --git a/paddle/fluid/operators/nccl_op.cc b/paddle/fluid/operators/nccl_op.cc
index 5ae50590dde4b0c2d769f09a1b7c152180e6984d..0994bba782b42be994ae479f4c9c4de5a2e384ed 100644
--- a/paddle/fluid/operators/nccl_op.cc
+++ b/paddle/fluid/operators/nccl_op.cc
@@ -14,10 +14,13 @@ limitations under the License. */
 
 #include "paddle/fluid/framework/op_registry.h"
 #include "paddle/fluid/operators/nccl/nccl_gpu_common.h"
+#include "paddle/fluid/operators/nccl/nccl_gpu_common.h"
 
 namespace paddle {
 namespace operators {
 
+static constexpr char kParallelScopes[] = "parallel_scopes";
+
 // NCCLinitOp
 class NCCLInitOp : public framework::OperatorBase {
  public:
@@ -29,11 +32,22 @@ class NCCLInitOp : public framework::OperatorBase {
  private:
   void RunImpl(const framework::Scope &scope,
                const platform::Place &place) const override {
+    PADDLE_ENFORCE_NOT_NULL(scope.FindVar(Input(kParallelScopes)),
+                            "Can not find variable '%s' in the scope.",
+                            kParallelScopes);
     const auto &name = Output("Communicator");
     PADDLE_ENFORCE_NOT_NULL(scope.FindVar(name),
                             "Can not find variable '%s' in the scope.", name);
-    std::vector<int> gpus = Attr<std::vector<int>>("gpus");
-    PADDLE_ENFORCE(!gpus.empty(), "Attr(gpus) should not be empty.");
+    // A parallel do may not use all the gpus. For example, the batch size is 7
+    // in the last batch while we have 8 gpus. In this case, parallel_do will
+    // create 7 parallel scopes, so ncclInitOp should create 7 gpu peers.
+    auto &parallel_scopes = scope.FindVar(Input(kParallelScopes))
+                                ->Get<std::vector<framework::Scope *>>();
+    std::vector<int> gpus(parallel_scopes.size());
+    for (int i = 0; i < static_cast<int>(parallel_scopes.size()); ++i) {
+      gpus[i] = i;
+    }
+    PADDLE_ENFORCE(!gpus.empty(), "NCCL init with 0 gpus.");
 
     if (scope.FindVar(name) == nullptr) {
       PADDLE_THROW("Output(Communicator) is needed for ncclInit operator.");
@@ -45,17 +59,29 @@ class NCCLInitOp : public framework::OperatorBase {
   }
 };
 
+class NCCLInitOpVarTypeInference : public framework::VarTypeInference {
+ public:
+  void operator()(const framework::OpDesc &op_desc,
+                  framework::BlockDesc *block) const override {
+    auto out_var_name = op_desc.Output("Communicator").front();
+    auto &out_var = block->FindRecursiveOrCreateVar(out_var_name);
+    auto var_type = framework::proto::VarType::NCCL_COM;
+    out_var.SetType(var_type);
+  }
+};
+
+class NCCLInitOpShapeInference : public framework::InferShapeBase {
+ public:
+  void operator()(framework::InferShapeContext *ctx) const override {}
+};
+
 class NCCLInitOpMaker : public framework::OpProtoAndCheckerMaker {
  public:
   NCCLInitOpMaker(OpProto *proto, OpAttrChecker *op_checker)
       : OpProtoAndCheckerMaker(proto, op_checker) {
+    AddInput(kParallelScopes, "The working place of parallel do.");
     AddOutput("Communicator",
               "Create Communicator for communicating between gpus");
-    AddAttr<std::vector<int>>("gpus", "(vector<int>) GPU id lists");
-    AddAttr<int>("dtype",
-                 "(int, default 5 (FP32)) "
-                 "Output data type")
-        .SetDefault(framework::proto::VarType::FP32);
     AddComment(R"DOC(
 NCCLInit Operator.
@@ -78,7 +104,7 @@ class NCCLAllReduceOp : public framework::OperatorWithKernel {
         ctx->HasInput("Communicator"),
         " Input(Communicator) of AllReduce op input should not be NULL");
     PADDLE_ENFORCE(ctx->HasOutput("Out"),
-                   " Input(X) of AllReduce op input should not be NULL");
+                   " Output(Out) of AllReduce op output should not be NULL");
 
     auto x_dims = ctx->GetInputsDim("X");
@@ -215,7 +241,9 @@ Bcast the tensors.
 
 namespace ops = paddle::operators;
 REGISTER_OPERATOR(ncclInit, ops::NCCLInitOp,
-                  paddle::framework::EmptyGradOpMaker, ops::NCCLInitOpMaker);
+                  paddle::framework::EmptyGradOpMaker, ops::NCCLInitOpMaker,
+                  ops::NCCLInitOpVarTypeInference,
+                  ops::NCCLInitOpShapeInference);
 
 REGISTER_OP_WITHOUT_GRADIENT(ncclAllReduce, ops::NCCLAllReduceOp,
                              ops::NCCLAllReduceOpMaker);
diff --git a/paddle/fluid/operators/parallel_do_op.cc b/paddle/fluid/operators/parallel_do_op.cc
index b21f9937ef5db13cc612668efccd93c3e6f78c48..6436efe42f1eb69f217aeba87dce89ea496edd84 100644
--- a/paddle/fluid/operators/parallel_do_op.cc
+++ b/paddle/fluid/operators/parallel_do_op.cc
@@ -30,6 +30,7 @@ static constexpr char kOutputs[] = "outputs";
 static constexpr char kParallelScopes[] = "parallel_scopes";
 
 static constexpr char kParallelBlock[] = "sub_block";
+static constexpr char kUseNCCL[] = "use_nccl";
 
 using LoDTensor = framework::LoDTensor;
 using SelectedRows = framework::SelectedRows;
@@ -194,6 +195,8 @@ class ParallelDoOpProtoMaker : public framework::OpProtoAndCheckerMaker {
     AddOutput(kOutputs, "").AsDuplicable();
     AddOutput(kParallelScopes, "");
     AddAttr<framework::BlockDesc *>(kParallelBlock, "");
+    AddAttr<bool>(kUseNCCL, "true if we use nccl on backward")
+        .SetDefault(false);
     AddComment(R"DOC(
 ParallelDo Operator.
)DOC"); @@ -216,7 +219,6 @@ class ParallelDoGradOp : public framework::OperatorBase { auto &sub_scopes = scope.FindVar(Input(kParallelScopes)) ->Get>(); - auto &places = scope.FindVar(Input(kPlaces))->Get(); // feed output@grad @@ -243,7 +245,24 @@ class ParallelDoGradOp : public framework::OperatorBase { } WaitOnPlaces(places); - AccumulateGrad(scope, place, sub_scopes, places); + // NCCL allreduce op will be added by backward, + // so no need to explicitly accumulate grad + if (!(Attr(kUseNCCL))) { + AccumulateGrad(scope, place, sub_scopes, places); + } else { + for (auto &place : places) { + PADDLE_ENFORCE(platform::is_gpu_place(place), + "NCCL only supports cuda place"); + } + } + for (auto &s : Outputs(framework::GradVarName(kParameters))) { + if (s == "@EMPTY@") { + continue; + } + VLOG(3) << "Moving " << s; + CopyOrShare(*sub_scopes[0]->FindVar(s), place, scope.FindVar(s)); + } + WaitOnPlaces(places); } void AccumulateGrad(const framework::Scope &scope, @@ -251,6 +270,9 @@ class ParallelDoGradOp : public framework::OperatorBase { const std::vector &sub_scopes, const platform::PlaceList &places) const { for (auto &s : Outputs(framework::GradVarName(kParameters))) { + if (s == "@EMPTY@") { + continue; + } VLOG(3) << "Accumulating " << s; if (s == framework::kEmptyVarName) continue; std::string tmp_name; diff --git a/paddle/fluid/pybind/protobuf.cc b/paddle/fluid/pybind/protobuf.cc index 99716ccb24017b027e430f7071d90f5f4069ea9b..131971099ef3febc3cfaff30e918fa74cfc6cfe4 100644 --- a/paddle/fluid/pybind/protobuf.cc +++ b/paddle/fluid/pybind/protobuf.cc @@ -239,7 +239,8 @@ void BindVarDsec(py::module &m) { .value("LOD_RANK_TABLE", proto::VarType::LOD_RANK_TABLE) .value("LOD_TENSOR_ARRAY", proto::VarType::LOD_TENSOR_ARRAY) .value("PLACE_LIST", proto::VarType::PLACE_LIST) - .value("READER", proto::VarType::READER); + .value("READER", proto::VarType::READER) + .value("NCCL_COM", proto::VarType::NCCL_COM); } void BindOpDesc(py::module &m) { diff --git a/python/paddle/v2/fluid/backward.py b/python/paddle/v2/fluid/backward.py index 26b35cfc19001df67dc8a4004e6d90c8e4dbdfa1..33ff43f69304ddd9330c61114dba85994b5f1bdd 100644 --- a/python/paddle/v2/fluid/backward.py +++ b/python/paddle/v2/fluid/backward.py @@ -199,12 +199,76 @@ def _remove_no_grad_branch_(op_descs, no_grad_set): return op_descs +import proto.framework_pb2 as framework_pb2 + + +def serialize_op_decs(op_desc): + protostr = op_desc.serialize_to_string() + proto = framework_pb2.OpDesc.FromString(str(protostr)) + return proto.__str__() + + +def _callback_lookup_(op): + """ + Only used in _append_backward_ops_ + Build and returns a callback function for certain op. 
+
+    parallel_do: AllReduce
+
+    :param op:
+    :return: callback function
+    """
+    if op.type == 'parallel_do' and op.attr('use_nccl'):
+        param_names = set(op.input('parameters'))
+        param_grad_names = [n + "@GRAD" for n in param_names]
+
+        class ParallelDoCallBack(object):
+            def __init__(self, param_grad_names, parallel_scopes_name):
+                self.has_inserted_nccl_init = False
+                self.param_grad_names = param_grad_names
+                self.parallel_scopes_name = parallel_scopes_name
+
+            def __call__(self, block, context):
+                if not self.has_inserted_nccl_init:
+                    op_desc = _create_op_desc_(
+                        "ncclInit",
+                        {"parallel_scopes": self.parallel_scopes_name},
+                        {"Communicator": ['nccl_com__do_not_change_']}, {})
+                    block.program.global_block().desc.append_op().copy_from(
+                        op_desc)
+                    self.has_inserted_nccl_init = True
+
+                current_op_desc = context["__current_op_desc__"]
+                for o_param in current_op_desc.output_names():
+                    for o_argu in current_op_desc.output(o_param):
+                        if o_argu in self.param_grad_names:
+                            allreduce_out_name = o_argu + "__nccl_all_reduce__"
+                            op_desc = _create_op_desc_(
+                                "ncclAllReduce", {
+                                    "X": [o_argu],
+                                    "Communicator":
+                                    ['nccl_com__do_not_change_']
+                                }, {"Out": [allreduce_out_name]},
+                                {"reduction": "ncclSum"})
+                            block.desc.append_op().copy_from(op_desc)
+
+                            op_desc = _create_op_desc_(
+                                "assign", {"X": [allreduce_out_name]},
+                                {"Out": [o_argu]}, {})
+                            block.desc.append_op().copy_from(op_desc)
+
+        return ParallelDoCallBack(param_grad_names,
+                                  op.output("parallel_scopes"))
+    else:
+        return None
+
+
 def _append_backward_ops_(block,
                           ops,
                           target_block,
                           no_grad_dict,
                           grad_to_var,
-                          callback=None):
+                          callbacks=None):
     """
     Create all grad ops, and insert them into given block
 
@@ -220,14 +284,11 @@ def _append_backward_ops_(block,
             val(str): corresponding forward variable name
         callback(callable object): a callable object used to decorate new generated grad ops
     """
-    if callback is None:
-
-        def empty_callback(block, context):
-            pass
-
-        callback = empty_callback
-    elif not hasattr(callback, '__call__'):
-        raise ValueError("'callback' must be a callable object.")
+    if callbacks is not None:
+        assert (isinstance(callbacks, list))
+        for cb in callbacks:
+            if not hasattr(cb, '__call__'):
+                raise ValueError("'callback' must be a callable object.")
 
     # grad_op_descs holds created grad_op, and will be appended to target_block
     grad_op_descs = []
@@ -238,8 +299,17 @@ def _append_backward_ops_(block,
         if op.has_attr("sub_block"):
             sub_block = program.block(op.block_attr("sub_block"))
             grad_sub_block = program.create_block(parent_idx=sub_block.idx)
-            _append_backward_ops_(sub_block, sub_block.ops, grad_sub_block,
-                                  no_grad_dict, grad_to_var)
+            cb = _callback_lookup_(op)
+            if cb is not None:
+                if callbacks is None:
+                    new_callbacks = [cb]
+                else:
+                    new_callbacks = callbacks + [_callback_lookup_(op)]
+                _append_backward_ops_(sub_block, sub_block.ops, grad_sub_block,
+                                      no_grad_dict, grad_to_var, new_callbacks)
+            else:
+                _append_backward_ops_(sub_block, sub_block.ops, grad_sub_block,
+                                      no_grad_dict, grad_to_var, callbacks)
             grad_sub_block_list.append(grad_sub_block.desc)
 
         # Getting op's corresponding grad_op
@@ -258,7 +328,11 @@ def _append_backward_ops_(block,
     for op_desc in grad_op_descs:
         new_op_desc = target_block.desc.append_op()
         new_op_desc.copy_from(op_desc)
-        callback(block=target_block, context=grad_to_var)
+        grad_to_var["__current_op_desc__"] = new_op_desc
+        if callbacks is not None:
+            assert (isinstance(callbacks, list))
+            for cb in callbacks:
+                cb(block=target_block, context=grad_to_var)
 
 
 def _append_backward_vars_(block, start_op_idx, grad_to_var, grad_info_map):
@@ -296,6 +370,9 @@ def _append_backward_vars_(block, start_op_idx, grad_to_var, grad_info_map):
         # infer_shape and infer_type
         op_desc.infer_var_type(block.desc)
         op_desc.infer_shape(block.desc)
+        # ncclInit doesn't need to set data_type
+        if op_desc.type() == 'ncclInit':
+            continue
         for arg in op_desc.output_arg_names():
             if arg in new_vars:
                 _infer_var_data_type_(arg, block)
@@ -335,7 +412,8 @@ def _get_stop_gradients_(program):
     return no_grad_dict
 
 
-def append_backward(loss, parameter_list=None, no_grad_set=None, callback=None):
+def append_backward(loss, parameter_list=None, no_grad_set=None,
+                    callbacks=None):
     """
     Append backward part to main_program
 
@@ -351,6 +429,8 @@ def append_backward(loss, parameter_list=None, no_grad_set=None, callback=None):
         (list[(Variable,Variable)]): list of (parameter, gradient) pair.
     """
     assert isinstance(loss, framework.Variable)
+    if callbacks is not None:
+        isinstance(callbacks, list)
 
     program = loss.block.program
     if no_grad_set is None:
@@ -378,7 +458,7 @@ def append_backward(loss, parameter_list=None, no_grad_set=None, callback=None):
     no_grad_dict[0].update(map(_append_grad_suffix_, block_no_grad_set))
 
     _append_backward_ops_(root_block, op_path, root_block, no_grad_dict,
-                          grad_to_var, callback)
+                          grad_to_var, callbacks)
 
     # Because calc_gradient may be called multiple times,
     # we need rename the internal gradient variables so that they have
diff --git a/python/paddle/v2/fluid/framework.py b/python/paddle/v2/fluid/framework.py
index fb4cd5b75ad74847287a5d6c9ea4feee6fd46659..0e11709296a4fc7121611c1f9928314810f35783 100644
--- a/python/paddle/v2/fluid/framework.py
+++ b/python/paddle/v2/fluid/framework.py
@@ -490,7 +490,7 @@ class Operator(object):
                 'feed', 'fetch', 'save', 'load', 'recurrent',
                 'rnn_memory_helper_grad', 'conditional_block', 'while', 'send',
                 'recv', 'listen_and_serv', 'parallel_do', 'save_combine',
-                'load_combine'
+                'load_combine', 'ncclInit'
             }
             if type not in no_kernel_op_set:
                 self.desc.infer_var_type(self.block.desc)
diff --git a/python/paddle/v2/fluid/layers/control_flow.py b/python/paddle/v2/fluid/layers/control_flow.py
index b56a391618b6c493965ff6b41bfb69e74bcb7881..b9ab28a86a226c3027b2a449fd645d500d39f14b 100644
--- a/python/paddle/v2/fluid/layers/control_flow.py
+++ b/python/paddle/v2/fluid/layers/control_flow.py
@@ -237,12 +237,13 @@ class ParallelDo(object):
     ParallelDo class is used to create a ParallelDo.
     """
 
-    def __init__(self, places, name=None):
+    def __init__(self, places, use_nccl=False, name=None):
         self.helper = LayerHelper("parallel_do", name=name)
         self.inputs = []
         self.places = places
         self.outputs = []
         self.status = StaticRNN.BEFORE_RNN_BLOCK
+        self.use_nccl = use_nccl
 
     def do(self):
         return BlockGuardWithCompletion(self)
@@ -325,7 +326,8 @@ class ParallelDo(object):
             },
             outputs={'outputs': outputs,
                      'parallel_scopes': [step_scope]},
-            attrs={'sub_block': current_block})
+            attrs={'sub_block': current_block,
+                   'use_nccl': self.use_nccl})
 
 
 class BlockGuardWithCompletion(BlockGuard):
diff --git a/python/paddle/v2/fluid/optimizer.py b/python/paddle/v2/fluid/optimizer.py
index 39391eb8e40ceea1404352271e8b4a04dc85f535..ecc42f6215bdd13f6ea4284dcd67b6026ad33129 100644
--- a/python/paddle/v2/fluid/optimizer.py
+++ b/python/paddle/v2/fluid/optimizer.py
@@ -225,7 +225,7 @@ class Optimizer(object):
         `create_optimization_pass()` into one.
""" params_grads = append_backward(loss, parameter_list, no_grad_set, - error_clip_callback) + [error_clip_callback]) params_grads = append_gradient_clip_ops(params_grads) diff --git a/python/paddle/v2/fluid/tests/test_error_clip.py b/python/paddle/v2/fluid/tests/test_error_clip.py index b331f16913d0ffa5846f36a0e2d10f8f03728b15..d577d0014dc136ee5ef92155e37009df60d9bf62 100644 --- a/python/paddle/v2/fluid/tests/test_error_clip.py +++ b/python/paddle/v2/fluid/tests/test_error_clip.py @@ -43,7 +43,7 @@ prog_clip.block(0).var(hidden1.name).set_error_clip( avg_cost_clip = prog_clip.block(0).var(avg_cost.name) fluid.backward.append_backward(loss=avg_cost) fluid.backward.append_backward( - loss=avg_cost_clip, callback=fluid.clip.error_clip_callback) + loss=avg_cost_clip, callbacks=[fluid.clip.error_clip_callback]) hidden1_grad = prog.block(0).var(hidden1.name + "@GRAD") hidden1_grad_clip = prog_clip.block(0).var(hidden1.name + "@GRAD") diff --git a/python/paddle/v2/fluid/tests/unittests/test_parallel_op.py b/python/paddle/v2/fluid/tests/unittests/test_parallel_op.py index 0d377ae70ccf709cf42fc40bf4c4cfcc6264382e..d65752608b204454d9d3e529dad366084f9b2c0e 100644 --- a/python/paddle/v2/fluid/tests/unittests/test_parallel_op.py +++ b/python/paddle/v2/fluid/tests/unittests/test_parallel_op.py @@ -67,12 +67,25 @@ class BaseParallelForTest(unittest.TestCase): fetch=fetch, place=gpu, use_parallel=True) + result_gpu_nccl = self._run_test_impl_( + callback=callback, + feed=feed, + fetch=fetch, + place=gpu, + use_parallel=True, + use_nccl=True) self._assert_same_(fetch, result_cpu, result_cpu_parallel, - result_gpu, result_gpu_parallel) + result_gpu, result_gpu_parallel, result_gpu_nccl) else: self._assert_same_(fetch, result_cpu, result_cpu_parallel) - def _run_test_impl_(self, callback, feed, fetch, place, use_parallel=False): + def _run_test_impl_(self, + callback, + feed, + fetch, + place, + use_parallel=False, + use_nccl=False): """ Run a single test, returns the fetch values Args: @@ -96,7 +109,7 @@ class BaseParallelForTest(unittest.TestCase): # Automatically insert parallel do if use_parallel = True if use_parallel: places = fluid.layers.get_places() - pd = fluid.layers.ParallelDo(places) + pd = fluid.layers.ParallelDo(places, use_nccl=use_nccl) data = next(generator) if isinstance(data, fluid.Variable): @@ -137,7 +150,9 @@ class BaseParallelForTest(unittest.TestCase): """ def _impl_(a, b, fetch_id, item_id): - item_str = ['CPU', 'ParallelCPU', 'GPU', 'ParallelGPU'] + item_str = [ + 'CPU', 'ParallelCPU', 'GPU', 'ParallelGPU', 'ParallelGPUNCCL' + ] flag = numpy.allclose(a, b, rtol=0.1, atol=1e-3) self.assertTrue(flag, "The {0} are different in {1}, {2} vs {3}".format( @@ -198,5 +213,5 @@ class ParallelOpTestMultipleInput(BaseParallelForTest): fetch=['fc1.w@GRAD', 'fc2.w@GRAD', 'fc3.w@GRAD']) -#if __name__ == '__main__': -# unittest.main() +if __name__ == '__main__': + unittest.main()