From b7128bac5f12138062ec2518a0f856915c752a69 Mon Sep 17 00:00:00 2001 From: HaoRen Date: Thu, 27 Jun 2019 10:17:10 +0800 Subject: [PATCH] supports collective communicated training (#18175) * fix prepare context redundant code problem, optimize executor by caching create_varaiables test=develop * supports collective training in executor * make fetch_list runable with variables, add more unittest for use_program_cache test=develop * fix comment test=develop * use unique name for nccl_id * supports output to stream in program_to_code * insert sync_comm_stream before regularization; add skip_op_callstack capability in program_to_code * set op role in collective training * add collective op role * remove orig file * add build optimizer by strategy * add collective strategy * refine collective strategy * add multi-process role maker * refine strategy building factory so that we can easily plugin more strategy * scale loss grad in collective sgd transpiler * add support for distributed fc * code format * revert some features for dist fc * add support for distributed fc training * fix prepare context redundant code problem, optimize executor by caching create_varaiables test=develop * supports collective training in executor * make fetch_list runable with variables, add more unittest for use_program_cache test=develop * use unique name for nccl_id * supports output to stream in program_to_code * insert sync_comm_stream before regularization; add skip_op_callstack capability in program_to_code * set op role in collective training * add collective op role * fix comment test=develop * remove orig file * add build optimizer by strategy * add collective strategy * refine collective strategy * add multi-process role maker * refine strategy building factory so that we can easily plugin more strategy * scale loss grad in collective sgd transpiler * add support for distributed fc * code format * revert some features for dist fc * add support for distributed fc training * test=develop add collective op unittest standard * test=develop remove the test_collective directory * test=develop remove the test_collective directory * remove slicegather test * code format for reducescatter * update attr of shard_index_op * Modify macro nccl_helper * remove test without distribute * macro collective_helper * marcro update * test=develop update support python3.5 * test=develop change gpu memory use to 0.1 when test * test=develop update ut equal func * test=develop set flags to 1.5 * test=develop fix pickle dumple py35 * test=develop fix divide in slice and add sync_comm_stream update atol and rtol to 1e-05 rm shard_index op and test modify read input from file to read from memory remove origin_program in framework and add i/o in c_sync_calc_stream * test=develop update unittest sync operator I/O --- paddle/fluid/framework/CMakeLists.txt | 2 +- paddle/fluid/framework/op_proto_maker.cc | 2 + paddle/fluid/framework/op_proto_maker.h | 3 + paddle/fluid/operators/CMakeLists.txt | 1 + .../fluid/operators/collective/CMakeLists.txt | 39 ++ .../operators/collective/c_allgather_op.cc | 89 +++++ .../operators/collective/c_allgather_op.cu.cc | 25 ++ .../operators/collective/c_allgather_op.h | 75 ++++ .../operators/collective/c_allreduce_op.cc | 83 +++++ .../operators/collective/c_allreduce_op.cu.cc | 25 ++ .../operators/collective/c_allreduce_op.h | 86 +++++ .../operators/collective/c_broadcast_op.cc | 74 ++++ .../operators/collective/c_broadcast_op.cu.cc | 25 ++ .../operators/collective/c_broadcast_op.h | 92 +++++ .../operators/collective/c_comm_init_op.cc | 86 +++++ .../operators/collective/c_gen_nccl_id_op.cc | 146 ++++++++ .../collective/c_reducescatter_op.cc | 93 +++++ .../collective/c_reducescatter_op.cu.cc | 26 ++ .../operators/collective/c_reducescatter_op.h | 75 ++++ .../collective/c_sync_calc_stream_op.cc | 76 ++++ .../collective/c_sync_comm_stream_op.cc | 77 ++++ .../operators/distributed_ops/CMakeLists.txt | 1 - paddle/fluid/operators/one_hot_op.cc | 5 + paddle/fluid/operators/one_hot_op.cu | 2 +- paddle/fluid/operators/one_hot_op.h | 36 +- paddle/fluid/platform/CMakeLists.txt | 6 +- paddle/fluid/platform/collective_helper.cc | 101 ++++++ paddle/fluid/platform/collective_helper.h | 109 ++++++ paddle/fluid/platform/dynload/nccl.h | 1 + paddle/fluid/pybind/const_value.cc | 1 + python/paddle/fluid/framework.py | 4 +- .../fluid/incubate/fleet/base/fleet_base.py | 1 - .../fluid/incubate/fleet/base/role_maker.py | 40 ++ .../incubate/fleet/collective/__init__.py | 194 +++++++++- .../fluid/tests/unittests/CMakeLists.txt | 8 + .../unittests/collective_allgather_op.py | 69 ++++ .../unittests/collective_allreduce_op.py | 70 ++++ .../unittests/collective_broadcast_op.py | 70 ++++ .../unittests/collective_reducescatter_op.py | 70 ++++ .../fluid/tests/unittests/test_allgather.py | 31 ++ .../fluid/tests/unittests/test_allreduce.py | 31 ++ .../fluid/tests/unittests/test_broadcast.py | 31 ++ .../tests/unittests/test_collective_base.py | 273 ++++++++++++++ .../tests/unittests/test_reducescatter.py | 32 ++ python/paddle/fluid/transpiler/collective.py | 343 ++++++++++++++++++ .../fluid/transpiler/details/program_utils.py | 28 +- .../fluid/transpiler/distribute_transpiler.py | 58 ++- 47 files changed, 2786 insertions(+), 29 deletions(-) create mode 100644 paddle/fluid/operators/collective/CMakeLists.txt create mode 100644 paddle/fluid/operators/collective/c_allgather_op.cc create mode 100644 paddle/fluid/operators/collective/c_allgather_op.cu.cc create mode 100644 paddle/fluid/operators/collective/c_allgather_op.h create mode 100644 paddle/fluid/operators/collective/c_allreduce_op.cc create mode 100644 paddle/fluid/operators/collective/c_allreduce_op.cu.cc create mode 100644 paddle/fluid/operators/collective/c_allreduce_op.h create mode 100644 paddle/fluid/operators/collective/c_broadcast_op.cc create mode 100644 paddle/fluid/operators/collective/c_broadcast_op.cu.cc create mode 100644 paddle/fluid/operators/collective/c_broadcast_op.h create mode 100644 paddle/fluid/operators/collective/c_comm_init_op.cc create mode 100644 paddle/fluid/operators/collective/c_gen_nccl_id_op.cc create mode 100644 paddle/fluid/operators/collective/c_reducescatter_op.cc create mode 100644 paddle/fluid/operators/collective/c_reducescatter_op.cu.cc create mode 100644 paddle/fluid/operators/collective/c_reducescatter_op.h create mode 100644 paddle/fluid/operators/collective/c_sync_calc_stream_op.cc create mode 100644 paddle/fluid/operators/collective/c_sync_comm_stream_op.cc create mode 100644 paddle/fluid/platform/collective_helper.cc create mode 100644 paddle/fluid/platform/collective_helper.h create mode 100644 python/paddle/fluid/tests/unittests/collective_allgather_op.py create mode 100644 python/paddle/fluid/tests/unittests/collective_allreduce_op.py create mode 100644 python/paddle/fluid/tests/unittests/collective_broadcast_op.py create mode 100644 python/paddle/fluid/tests/unittests/collective_reducescatter_op.py create mode 100644 python/paddle/fluid/tests/unittests/test_allgather.py create mode 100644 python/paddle/fluid/tests/unittests/test_allreduce.py create mode 100644 python/paddle/fluid/tests/unittests/test_broadcast.py create mode 100644 python/paddle/fluid/tests/unittests/test_collective_base.py create mode 100644 python/paddle/fluid/tests/unittests/test_reducescatter.py create mode 100644 python/paddle/fluid/transpiler/collective.py diff --git a/paddle/fluid/framework/CMakeLists.txt b/paddle/fluid/framework/CMakeLists.txt index 65367a2120..65bbbd77a9 100644 --- a/paddle/fluid/framework/CMakeLists.txt +++ b/paddle/fluid/framework/CMakeLists.txt @@ -179,7 +179,7 @@ if(WITH_DISTRIBUTE) data_feed.cc device_worker.cc hogwild_worker.cc downpour_worker.cc pull_dense_worker.cc section_worker.cc device_worker_factory.cc data_set.cc DEPS op_registry device_context scope framework_proto trainer_desc_proto glog fs shell fleet_wrapper lodtensor_printer - lod_rank_table feed_fetch_method sendrecvop_rpc ${GLOB_DISTRIBUTE_DEPS} + lod_rank_table feed_fetch_method sendrecvop_rpc collective_helper ${GLOB_DISTRIBUTE_DEPS} graph_to_program_pass variable_helper data_feed_proto ${NGRAPH_EXE_DEPS} timer) set(DISTRIBUTE_COMPILE_FLAGS "-Wno-non-virtual-dtor -Wno-error=non-virtual-dtor -Wno-error=delete-non-virtual-dtor") set_source_files_properties(executor.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS}) diff --git a/paddle/fluid/framework/op_proto_maker.cc b/paddle/fluid/framework/op_proto_maker.cc index 2311614c33..27922c7304 100644 --- a/paddle/fluid/framework/op_proto_maker.cc +++ b/paddle/fluid/framework/op_proto_maker.cc @@ -13,6 +13,7 @@ limitations under the License. */ #include "paddle/fluid/framework/op_proto_maker.h" #include +#include #include namespace paddle { @@ -73,6 +74,7 @@ void OpProtoAndCheckerMaker::operator()(proto::OpProto* proto, static_cast(OpRole::kBackward), static_cast(OpRole::kOptimize) | static_cast(OpRole::kLRSched), + static_cast(OpRole::kCollective), static_cast(OpRole::kNotSpecified)}) .SetDefault(static_cast(OpRole::kNotSpecified)); AddAttr>(OpRoleVarAttrName(), diff --git a/paddle/fluid/framework/op_proto_maker.h b/paddle/fluid/framework/op_proto_maker.h index 5f3ce60e1d..bf6528b237 100644 --- a/paddle/fluid/framework/op_proto_maker.h +++ b/paddle/fluid/framework/op_proto_maker.h @@ -34,6 +34,9 @@ enum class OpRole { kDist = 0x0008, // Tag all learning rate scheduler operators. kLRSched = 0x0010, + // Collective role is for all collective operators and other operators used + // for collective training + kCollective = 0x0020, kLoss = 0x0100, // The default value of op's role. This should be only used for unittests and diff --git a/paddle/fluid/operators/CMakeLists.txt b/paddle/fluid/operators/CMakeLists.txt index 3356c1e669..41658dec85 100644 --- a/paddle/fluid/operators/CMakeLists.txt +++ b/paddle/fluid/operators/CMakeLists.txt @@ -21,6 +21,7 @@ add_subdirectory(jit) if(WITH_DISTRIBUTE) add_subdirectory(distributed) add_subdirectory(distributed_ops) + add_subdirectory(collective) endif() add_subdirectory(reader) diff --git a/paddle/fluid/operators/collective/CMakeLists.txt b/paddle/fluid/operators/collective/CMakeLists.txt new file mode 100644 index 0000000000..89103f63d0 --- /dev/null +++ b/paddle/fluid/operators/collective/CMakeLists.txt @@ -0,0 +1,39 @@ +include(operators) + +set(COLLECTIVE_DEPS "") +if(WITH_GRPC) + set(COLLECTIVE_DEPS sendrecvop_rpc parameter_send parameter_recv communicator async_sparse_param_update_recorder grpc++_unsecure grpc_unsecure gpr cares zlib protobuf node) +else() + set(COLLECTIVE_DEPS sendrecvop_rpc parameter_send parameter_recv communicator async_sparse_param_update_recorder brpc leveldb snappystream snappy protobuf ssl crypto zlib node) + if(WITH_BRPC_RDMA) + find_library(IBVERBS_LIBRARY NAMES ibverbs) + ADD_LIBRARY(ibverbs SHARED IMPORTED GLOBAL) + SET_PROPERTY(TARGET ibverbs PROPERTY IMPORTED_LOCATION ${IBVERBS_LIBRARY}) + + + find_library(RDMACM_LIBRARY NAMES rdmacm) + ADD_LIBRARY(rdmacm SHARED IMPORTED GLOBAL) + SET_PROPERTY(TARGET rdmacm PROPERTY IMPORTED_LOCATION ${RDMACM_LIBRARY}) + + set(COLLECTIVE_DEPS ${COLLECTIVE_DEPS} ibverbs rdmacm) + endif() +endif() + +set(COLLECTIVE_COMPILE_FLAGS "-Wno-non-virtual-dtor -Wno-error=non-virtual-dtor -Wno-error=delete-non-virtual-dtor") + +file(GLOB OPS RELATIVE "${CMAKE_CURRENT_SOURCE_DIR}" "*_op.cc") +list(REMOVE_DUPLICATES OPS) + +foreach(src ${OPS}) + set_source_files_properties(${src} PROPERTIES COMPILE_FLAGS ${COLLECTIVE_COMPILE_FLAGS}) +endforeach() + +register_operators(EXCLUDES c_gen_nccl_id_op DEPS ${COLLECTIVE_DEPS}) + +if(WITH_GPU AND NOT WIN32) + set(COLLECTIVE_DEPS ${COLLECTIVE_DEPS} nccl_common collective_helper) + op_library(c_gen_nccl_id_op DEPS ${COLLECTIVE_DEPS} nccl_common) +endif() + +set(OPERATOR_DEPS ${OPERATOR_DEPS} ${COLLECTIVE_DEPS} PARENT_SCOPE) +set(GLOB_COLLECTIVE_DEPS ${COLLECTIVE_DEPS} CACHE INTERNAL "collective dependency") diff --git a/paddle/fluid/operators/collective/c_allgather_op.cc b/paddle/fluid/operators/collective/c_allgather_op.cc new file mode 100644 index 0000000000..6f915953da --- /dev/null +++ b/paddle/fluid/operators/collective/c_allgather_op.cc @@ -0,0 +1,89 @@ +/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/fluid/operators/collective/c_allgather_op.h" +#include // NOLINT +#include +#include + +namespace paddle { +namespace operators { + +class CAllGatherOp : public framework::OperatorWithKernel { + public: + using framework::OperatorWithKernel::OperatorWithKernel; + void InferShape(framework::InferShapeContext *ctx) const override { + PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) should not be null"); + PADDLE_ENFORCE(ctx->HasOutput("Out"), + "Output(Out) of SyncFCGather op should not be null."); + int nranks = ctx->Attrs().Get("nranks"); + PADDLE_ENFORCE_GE(nranks, 2, "nranks should be >=2"); + framework::DDim dim = ctx->GetInputDim("X"); + dim[0] = dim[0] * nranks; + ctx->SetOutputDim("Out", dim); + } +}; + +class CAllGatherOpMaker : public framework::OpProtoAndCheckerMaker { + public: + void Make() { + AddInput("X", "(Tensor) tensor to be allgather"); + AddOutput("Out", "(Tensor) the allgather result"); + AddAttr("ring_id", "(int default 0) communication ring id.") + .SetDefault(0); + AddAttr( + "use_calc_stream", + "(bool default false) eject CUDA operations to calculation stream.") + .SetDefault(false); + AddAttr("nranks", + "Total trainer count of the distributed training job"); + AddComment(R"DOC( +***CAllGather Operator*** +each rank receives the aggregation of data from all ranks in the order of the ranks + +Call NCCL collective AllGather internally.https://docs.nvidia.com/deeplearning/sdk/nccl-developer-guide/docs/api/colls.html#c.ncclAllGather +)DOC"); + } +}; + +class CAllGatherOpGradMaker : public framework::SingleGradOpDescMaker { + public: + using framework::SingleGradOpDescMaker::SingleGradOpDescMaker; + + protected: + std::unique_ptr Apply() const override { + std::unique_ptr retv(new framework::OpDesc()); + retv->SetType("c_reducescatter"); + retv->SetInput("X", OutputGrad("Out")); + retv->SetOutput("Out", InputGrad("X")); + retv->SetAttrMap(Attrs()); + return retv; + } +}; + +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; +namespace plat = paddle::platform; + +REGISTER_OPERATOR(c_allgather, ops::CAllGatherOp, ops::CAllGatherOpGradMaker, + ops::CAllGatherOpMaker); + +REGISTER_OP_CPU_KERNEL( + c_allgather, ops::CAllGatherOpKernel, + ops::CAllGatherOpKernel, + ops::CAllGatherOpKernel, + ops::CAllGatherOpKernel, + ops::CAllGatherOpKernel); diff --git a/paddle/fluid/operators/collective/c_allgather_op.cu.cc b/paddle/fluid/operators/collective/c_allgather_op.cu.cc new file mode 100644 index 0000000000..8b13ceeb40 --- /dev/null +++ b/paddle/fluid/operators/collective/c_allgather_op.cu.cc @@ -0,0 +1,25 @@ +/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/fluid/operators/collective/c_allgather_op.h" + +namespace ops = paddle::operators; +namespace plat = paddle::platform; + +REGISTER_OP_CUDA_KERNEL( + c_allgather, ops::CAllGatherOpKernel, + ops::CAllGatherOpKernel, + ops::CAllGatherOpKernel, + ops::CAllGatherOpKernel, + ops::CAllGatherOpKernel); diff --git a/paddle/fluid/operators/collective/c_allgather_op.h b/paddle/fluid/operators/collective/c_allgather_op.h new file mode 100644 index 0000000000..8becbba018 --- /dev/null +++ b/paddle/fluid/operators/collective/c_allgather_op.h @@ -0,0 +1,75 @@ +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once +#include +#include +#include + +#include "paddle/fluid/framework/data_type.h" +#include "paddle/fluid/framework/ddim.h" +#include "paddle/fluid/framework/lod_tensor.h" +#include "paddle/fluid/framework/op_registry.h" + +#if defined(PADDLE_WITH_CUDA) && !defined(_WIN32) +#include "paddle/fluid/platform/collective_helper.h" +#include "paddle/fluid/platform/nccl_helper.h" +#endif + +namespace paddle { +namespace operators { + +template +class CAllGatherOpKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& ctx) const override { + auto place = ctx.GetPlace(); + PADDLE_ENFORCE(is_gpu_place(place), + "CAllGatherOp can run on gpu place only for now."); +#if defined(PADDLE_WITH_CUDA) && !defined(_WIN32) + auto in = ctx.Input("X"); + auto out = ctx.Output("Out"); + ncclDataType_t dtype = platform::ToNCCLDataType(in->type()); + + int rid = ctx.Attr("ring_id"); + auto comm = platform::NCCLCommContext::Instance().Get(rid); + int nranks = comm->nranks(); + + framework::DDim out_dims = in->dims(); + out_dims[0] *= nranks; + out->mutable_data(out_dims, place); + + int64_t send_numel = in->numel(); + const T* send_buff = in->data(); + T* recv_buff = out->data(); + + cudaStream_t stream = nullptr; + if (ctx.Attr("use_calc_stream")) { + auto dev_ctx = platform::DeviceContextPool::Instance().Get(place); + stream = static_cast(dev_ctx)->stream(); + } else { + stream = comm->stream(); + } + + PADDLE_ENFORCE(platform::dynload::ncclAllGather( + send_buff, recv_buff, send_numel, static_cast(dtype), + comm->comm(), stream)); +#else + PADDLE_THROW("PaddlePaddle should compile with GPU."); +#endif + } +}; + +} // namespace operators +} // namespace paddle diff --git a/paddle/fluid/operators/collective/c_allreduce_op.cc b/paddle/fluid/operators/collective/c_allreduce_op.cc new file mode 100644 index 0000000000..8af1135701 --- /dev/null +++ b/paddle/fluid/operators/collective/c_allreduce_op.cc @@ -0,0 +1,83 @@ +/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include // NOLINT +#include + +#include "paddle/fluid/operators/collective/c_allreduce_op.h" + +namespace paddle { +namespace operators { + +class CAllReduceOp : public framework::OperatorWithKernel { + public: + using framework::OperatorWithKernel::OperatorWithKernel; + + void InferShape(framework::InferShapeContext* ctx) const override { + ctx->SetOutputDim("Out", ctx->GetInputDim("X")); + } + + protected: + framework::OpKernelType GetExpectedKernelType( + const framework::ExecutionContext& ctx) const override { + return framework::OpKernelType(ctx.Input("X")->type(), + ctx.GetPlace()); + } +}; + +class CAllReduceOpMaker : public framework::OpProtoAndCheckerMaker { + public: + void Make() { + AddInput("X", "(Tensor), tensor to be allreduced."); + AddOutput("Out", "(Tensor) the allreduced result."); + AddAttr("reduce_type", "(int default 0) determin the reduce type.") + .SetDefault(0); + AddAttr("ring_id", "(int default 0) communication ring id.") + .SetDefault(0); + AddAttr( + "use_calc_stream", + "(bool default false) eject CUDA operations to calculation stream.") + .SetDefault(false); + AddComment(R"DOC( +***CAllReduce Operator*** + +Call NCCL collective AllReduce internally. Note that this op must be used when one +thread is managing one GPU device. + +For speed reasons, reduce_type should be an integer: + +0: sum +1: prod +2: max +3: min +If input and output are the same variable, in-place allreduce will be used. +)DOC"); + } +}; + +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; +namespace plat = paddle::platform; + +REGISTER_OP_WITHOUT_GRADIENT(c_allreduce, ops::CAllReduceOp, + ops::CAllReduceOpMaker); + +REGISTER_OP_CPU_KERNEL( + c_allreduce, ops::CAllReduceOpKernel, + ops::CAllReduceOpKernel, + ops::CAllReduceOpKernel, + ops::CAllReduceOpKernel, + ops::CAllReduceOpKernel); diff --git a/paddle/fluid/operators/collective/c_allreduce_op.cu.cc b/paddle/fluid/operators/collective/c_allreduce_op.cu.cc new file mode 100644 index 0000000000..8b3246d95a --- /dev/null +++ b/paddle/fluid/operators/collective/c_allreduce_op.cu.cc @@ -0,0 +1,25 @@ +/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/fluid/operators/collective/c_allreduce_op.h" + +namespace ops = paddle::operators; +namespace plat = paddle::platform; + +REGISTER_OP_CUDA_KERNEL( + c_allreduce, ops::CAllReduceOpKernel, + ops::CAllReduceOpKernel, + ops::CAllReduceOpKernel, + ops::CAllReduceOpKernel, + ops::CAllReduceOpKernel); diff --git a/paddle/fluid/operators/collective/c_allreduce_op.h b/paddle/fluid/operators/collective/c_allreduce_op.h new file mode 100644 index 0000000000..0cd4b857ff --- /dev/null +++ b/paddle/fluid/operators/collective/c_allreduce_op.h @@ -0,0 +1,86 @@ +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once +#include +#include +#include + +#include "paddle/fluid/framework/data_type.h" +#include "paddle/fluid/framework/lod_tensor.h" +#include "paddle/fluid/framework/op_registry.h" + +#if defined(PADDLE_WITH_CUDA) && !defined(_WIN32) +#include "paddle/fluid/platform/collective_helper.h" +#include "paddle/fluid/platform/nccl_helper.h" +#endif + +namespace paddle { +namespace operators { + +template +class CAllReduceOpKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& ctx) const override { + auto place = ctx.GetPlace(); + PADDLE_ENFORCE(is_gpu_place(place), + "CAllReduce op can run on gpu place only for now."); +#if defined(PADDLE_WITH_CUDA) && !defined(_WIN32) + auto in = ctx.Input("X"); + auto out = ctx.Output("Out"); + + ncclDataType_t dtype = platform::ToNCCLDataType(in->type()); + int64_t numel = in->numel(); + const void* sendbuff = in->data(); + out->Resize(in->dims()); + void* recvbuff = out->mutable_data(place); + + int rid = ctx.Attr("ring_id"); + auto comm = platform::NCCLCommContext::Instance().Get(rid); + + int reduce_type = ctx.Attr("reduce_type"); + ncclRedOp_t red_type = ncclSum; + switch (reduce_type) { + case 0: + red_type = ncclSum; + break; + case 1: + red_type = ncclProd; + break; + case 2: + red_type = ncclMax; + break; + case 3: + red_type = ncclMin; + break; + } + + cudaStream_t stream = nullptr; + if (ctx.Attr("use_calc_stream")) { + auto dev_ctx = platform::DeviceContextPool::Instance().Get(place); + stream = static_cast(dev_ctx)->stream(); + } else { + stream = comm->stream(); + } + + PADDLE_ENFORCE(platform::dynload::ncclAllReduce( + sendbuff, recvbuff, numel, dtype, red_type, comm->comm(), stream)); +#else + PADDLE_THROW("PaddlePaddle should compile with GPU."); +#endif + } +}; + +} // namespace operators +} // namespace paddle diff --git a/paddle/fluid/operators/collective/c_broadcast_op.cc b/paddle/fluid/operators/collective/c_broadcast_op.cc new file mode 100644 index 0000000000..ab8ed3d869 --- /dev/null +++ b/paddle/fluid/operators/collective/c_broadcast_op.cc @@ -0,0 +1,74 @@ +/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include // NOLINT +#include + +#include "paddle/fluid/operators/collective/c_broadcast_op.h" + +namespace paddle { +namespace operators { + +class CBroadcastOp : public framework::OperatorWithKernel { + public: + using framework::OperatorWithKernel::OperatorWithKernel; + + void InferShape(framework::InferShapeContext* ctx) const override { + ctx->SetOutputDim("Out", ctx->GetInputDim("X")); + } + + protected: + framework::OpKernelType GetExpectedKernelType( + const framework::ExecutionContext& ctx) const override { + return framework::OpKernelType(ctx.Input("X")->type(), + ctx.GetPlace()); + } +}; + +class CBroadcastOpMaker : public framework::OpProtoAndCheckerMaker { + public: + void Make() { + AddInput("X", "(Tensor) tensor to be broadcasted."); + AddOutput("Out", "(Tensor) the result of broadcast."); + AddAttr("ring_id", "(int default 0) nccl communication ring id.") + .SetDefault(0); + AddAttr("root", "(int default 0) root id for broadcasting.") + .SetDefault(0); + AddAttr( + "use_calc_stream", + "(bool default false) eject CUDA operations to calculation stream.") + .SetDefault(false); + AddComment(R"DOC( +***CBroadcast Operator*** + +Call ncclBcast internally. +)DOC"); + } +}; + +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; +namespace plat = paddle::platform; + +REGISTER_OP_WITHOUT_GRADIENT(c_broadcast, ops::CBroadcastOp, + ops::CBroadcastOpMaker); + +REGISTER_OP_CPU_KERNEL( + c_broadcast, ops::CBroadcastOpKernel, + ops::CBroadcastOpKernel, + ops::CBroadcastOpKernel, + ops::CBroadcastOpKernel, + ops::CBroadcastOpKernel); diff --git a/paddle/fluid/operators/collective/c_broadcast_op.cu.cc b/paddle/fluid/operators/collective/c_broadcast_op.cu.cc new file mode 100644 index 0000000000..23b0fb01ec --- /dev/null +++ b/paddle/fluid/operators/collective/c_broadcast_op.cu.cc @@ -0,0 +1,25 @@ +/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/fluid/operators/collective/c_broadcast_op.h" + +namespace ops = paddle::operators; +namespace plat = paddle::platform; + +REGISTER_OP_CUDA_KERNEL( + c_broadcast, ops::CBroadcastOpKernel, + ops::CBroadcastOpKernel, + ops::CBroadcastOpKernel, + ops::CBroadcastOpKernel, + ops::CBroadcastOpKernel); diff --git a/paddle/fluid/operators/collective/c_broadcast_op.h b/paddle/fluid/operators/collective/c_broadcast_op.h new file mode 100644 index 0000000000..c93c459b75 --- /dev/null +++ b/paddle/fluid/operators/collective/c_broadcast_op.h @@ -0,0 +1,92 @@ +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once + +#include +#include +#include + +#include "paddle/fluid/framework/data_type.h" +#include "paddle/fluid/framework/lod_tensor.h" +#include "paddle/fluid/framework/op_registry.h" + +#if defined(PADDLE_WITH_CUDA) && !defined(_WIN32) +#include "paddle/fluid/platform/collective_helper.h" +#include "paddle/fluid/platform/nccl_helper.h" +#endif + +namespace paddle { +namespace operators { + +template +class CBroadcastOpKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& ctx) const override { + auto place = ctx.GetPlace(); + PADDLE_ENFORCE(is_gpu_place(place), + "CBroadcastOp can run on gpu place only for now."); +#if defined(PADDLE_WITH_CUDA) && !defined(_WIN32) + auto x = ctx.Input("X"); + auto out = ctx.Output("Out"); + int numel = x->numel(); + ncclDataType_t dtype = platform::ToNCCLDataType(x->type()); + + int rid = ctx.Attr("ring_id"); + auto comm = platform::NCCLCommContext::Instance().Get(rid); + + cudaStream_t stream = nullptr; + if (ctx.Attr("use_calc_stream")) { + auto dev_ctx = platform::DeviceContextPool::Instance().Get(place); + stream = static_cast(dev_ctx)->stream(); + } else { + stream = comm->stream(); + } + + int root = ctx.Attr("root"); + int nranks = comm->nranks(); + PADDLE_ENFORCE(root >= 0 && root < nranks, + "Expected root in range of [0,%d),but get %d", nranks, root); + if (root == comm->rank()) { + PADDLE_ENFORCE(platform::dynload::ncclBcast( + reinterpret_cast(const_cast(x->data())), numel, dtype, + root, comm->comm(), stream)); + VLOG(3) << "rank " << comm->rank() << " invoke Bcast. sent " + << x->numel(); + + if (out != x) { + // TODO(liuyi05): check inplace + framework::TensorCopy( + *static_cast(x), place, + *platform::DeviceContextPool::Instance().Get(place), + static_cast(out)); + } + } else { + PADDLE_ENFORCE(platform::dynload::ncclBcast(out->mutable_data(place), + numel, dtype, root, + comm->comm(), stream)); + VLOG(3) << "rank " << comm->rank() << " invoke Bcast. recieved " + << framework::product(out->dims()); + } + + out->Resize(x->dims()); + out->set_lod(x->lod()); +#else + PADDLE_THROW("PaddlePaddle should compile with GPU."); +#endif + } +}; + +} // namespace operators +} // namespace paddle diff --git a/paddle/fluid/operators/collective/c_comm_init_op.cc b/paddle/fluid/operators/collective/c_comm_init_op.cc new file mode 100644 index 0000000000..9dace1725f --- /dev/null +++ b/paddle/fluid/operators/collective/c_comm_init_op.cc @@ -0,0 +1,86 @@ +/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ +#if defined(PADDLE_WITH_CUDA) && !defined(_WIN32) +#include +#endif +#include +#include +#include + +#include "paddle/fluid/framework/executor.h" +#include "paddle/fluid/framework/lod_tensor.h" +#include "paddle/fluid/framework/op_registry.h" +#include "paddle/fluid/framework/threadpool.h" +#include "paddle/fluid/operators/distributed/distributed.h" +#include "paddle/fluid/operators/distributed/request_handler_impl.h" +#if defined(PADDLE_WITH_CUDA) && !defined(_WIN32) +#include "paddle/fluid/platform/collective_helper.h" +#include "paddle/fluid/platform/nccl_helper.h" +#endif +namespace paddle { +namespace operators { + +class CCommInitOp : public framework::OperatorBase { + public: + CCommInitOp(const std::string& type, const framework::VariableNameMap& inputs, + const framework::VariableNameMap& outputs, + const framework::AttributeMap& attrs) + : OperatorBase(type, inputs, outputs, attrs) {} + + void RunImpl(const framework::Scope& scope, + const platform::Place& place) const override { + PADDLE_ENFORCE(is_gpu_place(place), + "CCommInitOp can run on gpu place only."); + + auto var = scope.FindVar(Input("X")); + PADDLE_ENFORCE_NOT_NULL(var); +#if defined(PADDLE_WITH_CUDA) && !defined(_WIN32) + ncclUniqueId* nccl_id = var->GetMutable(); + + int nranks = Attr("nranks"); + int rank_id = Attr("rank"); + int rid = Attr("ring_id"); + + platform::NCCLCommContext::Instance().CreateNCCLComm( + nccl_id, nranks, rank_id, boost::get(place).device, + rid); +#else + PADDLE_THROW("PaddlePaddle should compile with GPU."); +#endif + } +}; + +class CCommInitOpMaker : public framework::OpProtoAndCheckerMaker { + public: + void Make() override { + AddInput("X", "Raw variable contains a NCCL UniqueId instaces."); + AddComment(R"DOC( +CCommInit operator + +Initialize collective communicatoin context within this trainer +)DOC"); + AddAttr("nranks", "(int) The number of ranks of distributed trainers"); + AddAttr("rank", + "(int) The rank of the trainer in distributed training."); + AddAttr("ring_id", "(int default 0) user specified ring id") + .SetDefault(0); + } +}; + +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; + +REGISTER_OPERATOR(c_comm_init, ops::CCommInitOp, ops::CCommInitOpMaker); diff --git a/paddle/fluid/operators/collective/c_gen_nccl_id_op.cc b/paddle/fluid/operators/collective/c_gen_nccl_id_op.cc new file mode 100644 index 0000000000..a19a3fe1a3 --- /dev/null +++ b/paddle/fluid/operators/collective/c_gen_nccl_id_op.cc @@ -0,0 +1,146 @@ +/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ +#if defined(PADDLE_WITH_CUDA) && !defined(_WIN32) +#include +#endif +#include +#include +#include + +#include "paddle/fluid/framework/executor.h" +#include "paddle/fluid/framework/lod_tensor.h" +#include "paddle/fluid/framework/op_registry.h" +#include "paddle/fluid/framework/threadpool.h" +#include "paddle/fluid/operators/distributed/distributed.h" +#include "paddle/fluid/operators/distributed/request_handler_impl.h" +#if defined(PADDLE_WITH_CUDA) && !defined(_WIN32) +#include "paddle/fluid/platform/nccl_helper.h" +#endif +namespace paddle { +namespace operators { + +class CGenNCCLIdOp : public framework::OperatorBase { + public: + CGenNCCLIdOp(const std::string& type, + const framework::VariableNameMap& inputs, + const framework::VariableNameMap& outputs, + const framework::AttributeMap& attrs) + : OperatorBase(type, inputs, outputs, attrs) {} + + void RunImpl(const framework::Scope& scope, + const platform::Place& dev_place) const override { + platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance(); + // put nccl id in CPUPlace + auto& dev_ctx = *pool.Get(platform::CPUPlace()); + int rank = Attr("rank"); + framework::Scope& local_scope = scope.NewScope(); + + if (rank == 0) { + GenerateAndSend(&local_scope, dev_ctx); + } else { + GetIdByServer(&local_scope, dev_ctx); + } + scope.DeleteScope(&local_scope); + } + + private: + void GenerateAndSend(framework::Scope* scope, + const platform::DeviceContext& dev_ctx) const { + std::string var_name = Output("Out"); + auto var = scope->FindVar(var_name); + PADDLE_ENFORCE_NOT_NULL(var); + auto id = var->GetMutable(); + PADDLE_ENFORCE(platform::dynload::ncclGetUniqueId(id)); + + std::vector endpoint_list = + Attr>("other_endpoints"); + distributed::RPCClient* client = + distributed::RPCClient::GetInstance(0); + + for (auto& ep : endpoint_list) { + VLOG(3) << "sending nccl id to " << ep; + client->AsyncSendVar(ep, dev_ctx, *scope, var_name); + } + client->Wait(); + for (auto& ep : endpoint_list) { + client->AsyncSendBatchBarrier(ep); + } + client->Wait(); + VLOG(3) << "sending completed..."; + } + + void GetIdByServer(framework::Scope* scope, + const platform::DeviceContext& dev_ctx) const { + std::string endpoint = Attr("endpoint"); + // NOTE: Can not use unique_ptr here because the default + // deleter will call GRPC Server's base class's dtor and + // that will cause a wired crash. + distributed::RequestSendHandler rpc_h(true); + std::unique_ptr rpc_service( + new RPCSERVER_T(endpoint, 1)); + + rpc_service->RegisterRPC(distributed::kRequestSend, &rpc_h); + rpc_h.SetRPCServer(rpc_service.get()); + + framework::ProgramDesc empty_program; + framework::Executor executor(dev_ctx.GetPlace()); + rpc_h.SetScope(scope); + rpc_h.SetDevCtx(&dev_ctx); + rpc_h.SetProgram(&empty_program); + rpc_h.SetExecutor(&executor); + + std::thread server_thread( + std::bind(&distributed::RPCServer::StartServer, rpc_service.get())); + + rpc_service->SetCond(distributed::kRequestSend); + VLOG(3) << "start getting nccl id from trainer 0..."; + rpc_service->WaitBarrier(distributed::kRequestSend); + VLOG(3) << "got nccl id and stop server..."; + rpc_service->ShutDown(); + VLOG(3) << "rpc server stopped"; + server_thread.join(); + } +}; + +class CGenNCCLIdOpMaker : public framework::OpProtoAndCheckerMaker { + public: + void Make() override { + AddOutput("Out", "Raw variable contains a NCCL UniqueId instaces."); + AddComment(R"DOC( +CGenNCCLId operator + +For trainer 0: generate a new UniqueId and send it to all the other trainers. +For trainer 1~n: start a gRPC server to get the UniqueId, once got, stop the server. +)DOC"); + AddAttr("endpoint", + "(string), e.g. 127.0.0.1:6175 " + "current listen endpoint"); + AddAttr>( + "other_endpoints", + "['trainer1_ip:port', 'trainer2_ip:port', ...] " + "list of other trainer endpoints") + .SetDefault({}); + AddAttr("rank", + "(int default 0) " + "The rank of the trainer in distributed training.") + .SetDefault(0); + } +}; + +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; + +REGISTER_OPERATOR(c_gen_nccl_id, ops::CGenNCCLIdOp, ops::CGenNCCLIdOpMaker); diff --git a/paddle/fluid/operators/collective/c_reducescatter_op.cc b/paddle/fluid/operators/collective/c_reducescatter_op.cc new file mode 100644 index 0000000000..feb9dcd5a4 --- /dev/null +++ b/paddle/fluid/operators/collective/c_reducescatter_op.cc @@ -0,0 +1,93 @@ +/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/fluid/operators/collective/c_reducescatter_op.h" +#include // NOLINT +#include +#include + +namespace paddle { +namespace operators { + +class CReduceScatterOp : public framework::OperatorWithKernel { + public: + using framework::OperatorWithKernel::OperatorWithKernel; + + void InferShape(framework::InferShapeContext *ctx) const override { + PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) should not be null"); + PADDLE_ENFORCE(ctx->HasOutput("Out"), "Output(Out) should not be null."); + int nranks = ctx->Attrs().Get("nranks"); + framework::DDim dim = ctx->GetInputDim("X"); + if (dim[0] > 0 || dim[0] < -1) { + PADDLE_ENFORCE(dim[0] % nranks == 0, + "dim[0] (%d) is not divisible by nranks(%d)", dim[0], + nranks); + dim[0] /= nranks; + } + ctx->SetOutputDim("Out", dim); + } +}; + +class CReduceScatterOpMaker : public framework::OpProtoAndCheckerMaker { + public: + void Make() { + AddInput("X", "(Tensor) tensor to be allgather"); + AddOutput("Out", "(Tensor) the allgather result"); + AddAttr("ring_id", "(int default 0) communication ring id.") + .SetDefault(0); + AddAttr("nranks", + "Total trainer count of the distributed training job") + .SetDefault(1); + AddAttr( + "use_calc_stream", + "(bool default false) eject CUDA operations to calculation stream.") + .SetDefault(false); + AddComment(R"DOC( +***CReduceScatter Operator*** + +Call NCCL collective ReduceScatter internally. +)DOC"); + } +}; + +class CReduceScatterOpGradMaker : public framework::SingleGradOpDescMaker { + public: + using framework::SingleGradOpDescMaker::SingleGradOpDescMaker; + + protected: + std::unique_ptr Apply() const override { + std::unique_ptr retv(new framework::OpDesc()); + retv->SetType("c_allgather"); + retv->SetInput("X", OutputGrad("Out")); + retv->SetOutput("Out", InputGrad("X")); + retv->SetAttrMap(Attrs()); + return retv; + } +}; + +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; +namespace plat = paddle::platform; + +REGISTER_OPERATOR(c_reducescatter, ops::CReduceScatterOp, + ops::CReduceScatterOpMaker); + +REGISTER_OP_CPU_KERNEL( + c_reducescatter, ops::CReduceScatterOpKernel, + ops::CReduceScatterOpKernel, + ops::CReduceScatterOpKernel, + ops::CReduceScatterOpKernel, + ops::CReduceScatterOpKernel); diff --git a/paddle/fluid/operators/collective/c_reducescatter_op.cu.cc b/paddle/fluid/operators/collective/c_reducescatter_op.cu.cc new file mode 100644 index 0000000000..ef9eed2aab --- /dev/null +++ b/paddle/fluid/operators/collective/c_reducescatter_op.cu.cc @@ -0,0 +1,26 @@ +/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/fluid/operators/collective/c_reducescatter_op.h" + +namespace ops = paddle::operators; +namespace plat = paddle::platform; + +REGISTER_OP_CUDA_KERNEL( + c_reducescatter, + ops::CReduceScatterOpKernel, + ops::CReduceScatterOpKernel, + ops::CReduceScatterOpKernel, + ops::CReduceScatterOpKernel, + ops::CReduceScatterOpKernel); diff --git a/paddle/fluid/operators/collective/c_reducescatter_op.h b/paddle/fluid/operators/collective/c_reducescatter_op.h new file mode 100644 index 0000000000..93d623ff2e --- /dev/null +++ b/paddle/fluid/operators/collective/c_reducescatter_op.h @@ -0,0 +1,75 @@ +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once +#include +#include +#include + +#include "paddle/fluid/framework/data_type.h" +#include "paddle/fluid/framework/ddim.h" +#include "paddle/fluid/framework/lod_tensor.h" +#include "paddle/fluid/framework/op_registry.h" + +#if defined(PADDLE_WITH_CUDA) && !defined(_WIN32) +#include "paddle/fluid/platform/collective_helper.h" +#include "paddle/fluid/platform/nccl_helper.h" +#endif + +namespace paddle { +namespace operators { + +template +class CReduceScatterOpKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& ctx) const override { + auto place = ctx.GetPlace(); + PADDLE_ENFORCE(is_gpu_place(place), + "CAllReduce op can run on gpu place only for now."); +#if defined(PADDLE_WITH_CUDA) && !defined(_WIN32) + auto in = ctx.Input("X"); + auto out = ctx.Output("Out"); + + int rid = ctx.Attr("ring_id"); + auto comm = platform::NCCLCommContext::Instance().Get(rid); + int nranks = comm->nranks(); + + auto out_dims = in->dims(); + out_dims[0] = out_dims[0] / nranks; + out->mutable_data(out_dims, place); + + int64_t recv_numel = in->numel() / nranks; + const T* send_buff = in->data(); + T* recv_buff = out->data(); + int dtype = platform::ToNCCLDataType(in->type()); + + cudaStream_t stream = nullptr; + if (ctx.Attr("use_calc_stream")) { + auto dev_ctx = platform::DeviceContextPool::Instance().Get(place); + stream = static_cast(dev_ctx)->stream(); + } else { + stream = comm->stream(); + } + + PADDLE_ENFORCE(platform::dynload::ncclReduceScatter( + send_buff, recv_buff, recv_numel, static_cast(dtype), + ncclSum, comm->comm(), stream)); +#else + PADDLE_THROW("PaddlePaddle should compile with GPU."); +#endif + } +}; + +} // namespace operators +} // namespace paddle diff --git a/paddle/fluid/operators/collective/c_sync_calc_stream_op.cc b/paddle/fluid/operators/collective/c_sync_calc_stream_op.cc new file mode 100644 index 0000000000..965761dc15 --- /dev/null +++ b/paddle/fluid/operators/collective/c_sync_calc_stream_op.cc @@ -0,0 +1,76 @@ +/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#if defined(PADDLE_WITH_CUDA) && !defined(_WIN32) +#include +#endif +#include +#include +#include + +#include "paddle/fluid/framework/lod_tensor.h" +#include "paddle/fluid/framework/op_registry.h" +#if defined(PADDLE_WITH_CUDA) && !defined(_WIN32) +#include "paddle/fluid/platform/collective_helper.h" +#endif + +namespace paddle { +namespace operators { + +class CSyncCalcStreamOp : public framework::OperatorBase { + public: + CSyncCalcStreamOp(const std::string& type, + const framework::VariableNameMap& inputs, + const framework::VariableNameMap& outputs, + const framework::AttributeMap& attrs) + : OperatorBase(type, inputs, outputs, attrs) {} + + void RunImpl(const framework::Scope& scope, + const platform::Place& place) const override { + PADDLE_ENFORCE(is_gpu_place(place), + "Sync stream op can run on gpu place only for now."); + +#if defined(PADDLE_WITH_CUDA) && !defined(_WIN32) + auto dev_ctx = static_cast( + platform::DeviceContextPool::Instance().Get(place)); + cudaError_t e_sync = cudaStreamSynchronize(dev_ctx->stream()); + if (e_sync != 0) { + LOG(FATAL) << "Fail to sync cuda stream: " << cudaGetErrorString(e_sync); + } +#else + PADDLE_THROW("PaddlePaddle should compile with GPU."); +#endif + } +}; + +class CSyncCalcStreamOpMaker : public framework::OpProtoAndCheckerMaker { + public: + void Make() { + AddInput("X", "(Tensor) Dependency of last param need to sync"); + AddOutput("Out", "(Tensor) Dependency of last param need to sync"); + AddComment(R"DOC( +***Sync Operator*** + +Call cuda stream synchronize. +)DOC"); + } +}; + +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; + +REGISTER_OPERATOR(c_sync_calc_stream, ops::CSyncCalcStreamOp, + ops::CSyncCalcStreamOpMaker); diff --git a/paddle/fluid/operators/collective/c_sync_comm_stream_op.cc b/paddle/fluid/operators/collective/c_sync_comm_stream_op.cc new file mode 100644 index 0000000000..6fbb5b8cb1 --- /dev/null +++ b/paddle/fluid/operators/collective/c_sync_comm_stream_op.cc @@ -0,0 +1,77 @@ +/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ +#if defined(PADDLE_WITH_CUDA) && !defined(_WIN32) +#include +#endif +#include +#include +#include + +#include "paddle/fluid/framework/lod_tensor.h" +#include "paddle/fluid/framework/op_registry.h" +#if defined(PADDLE_WITH_CUDA) && !defined(_WIN32) +#include "paddle/fluid/platform/collective_helper.h" +#include "paddle/fluid/platform/nccl_helper.h" +#endif + +namespace paddle { +namespace operators { + +class CSyncCommStreamOp : public framework::OperatorBase { + public: + CSyncCommStreamOp(const std::string& type, + const framework::VariableNameMap& inputs, + const framework::VariableNameMap& outputs, + const framework::AttributeMap& attrs) + : OperatorBase(type, inputs, outputs, attrs) {} + + void RunImpl(const framework::Scope& scope, + const platform::Place& place) const override { + PADDLE_ENFORCE(is_gpu_place(place), + "Sync stream op can run on gpu place only for now."); + +#if defined(PADDLE_WITH_CUDA) && !defined(_WIN32) + int ring_id = Attr("ring_id"); + auto stream = platform::NCCLCommContext::Instance().Get(ring_id)->stream(); + cudaError_t e_sync = cudaStreamSynchronize(stream); + if (e_sync != 0) { + LOG(FATAL) << "Fail to sync nccl stream: " << cudaGetErrorString(e_sync); + } +#else + PADDLE_THROW("PaddlePaddle should compile with GPU."); +#endif + } +}; + +class CSyncCommStreamOpMaker : public framework::OpProtoAndCheckerMaker { + public: + void Make() { + AddInput("X", "(Tensor) Dependency of last param need to sync"); + AddOutput("Out", "(Tensor) Dependency of last param need to sync"); + AddAttr("ring_id", "(int default 0) ring id.").SetDefault(0); + AddComment(R"DOC( +***Sync Operator*** + +Call nccl stream synchronize. +)DOC"); + } +}; + +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; + +REGISTER_OPERATOR(c_sync_comm_stream, ops::CSyncCommStreamOp, + ops::CSyncCommStreamOpMaker); diff --git a/paddle/fluid/operators/distributed_ops/CMakeLists.txt b/paddle/fluid/operators/distributed_ops/CMakeLists.txt index 1096f3773c..829e67a53b 100644 --- a/paddle/fluid/operators/distributed_ops/CMakeLists.txt +++ b/paddle/fluid/operators/distributed_ops/CMakeLists.txt @@ -21,7 +21,6 @@ endif() set(DISTRIBUTE_COMPILE_FLAGS "-Wno-non-virtual-dtor -Wno-error=non-virtual-dtor -Wno-error=delete-non-virtual-dtor") - file(GLOB OPS RELATIVE "${CMAKE_CURRENT_SOURCE_DIR}" "*_op.cc") list(REMOVE_DUPLICATES OPS) diff --git a/paddle/fluid/operators/one_hot_op.cc b/paddle/fluid/operators/one_hot_op.cc index cbb0c4028b..6042b97bf5 100644 --- a/paddle/fluid/operators/one_hot_op.cc +++ b/paddle/fluid/operators/one_hot_op.cc @@ -86,6 +86,11 @@ class OneHotOpMaker : public framework::OpProtoAndCheckerMaker { "An integer to specify the data type of one-hot " "vector. The default value is FP32.") .SetDefault(paddle::framework::proto::VarType::FP32); + AddAttr("allow_out_of_range", + "If it is set true and the input data is out of range, " + "the output tensor will be filled zeros. The default value " + "is false.") + .SetDefault(false); AddComment(R"DOC( One Hot Operator. This operator creates the one-hot representations for input index values. The following example will help to explain the function of this diff --git a/paddle/fluid/operators/one_hot_op.cu b/paddle/fluid/operators/one_hot_op.cu index b9fe0bf2e9..bffd1d5305 100644 --- a/paddle/fluid/operators/one_hot_op.cu +++ b/paddle/fluid/operators/one_hot_op.cu @@ -24,7 +24,7 @@ template __global__ void FillOutputKernel(const InT* p_in_data, OutT* p_out_data, const int64_t numel, const int depth) { int idx = blockIdx.x * blockDim.x + threadIdx.x; - if (idx < numel) { + if (idx < numel && p_in_data[idx] >= 0 && p_in_data[idx] < depth) { *(p_out_data + (idx * depth) + p_in_data[idx]) = 1.0; } } diff --git a/paddle/fluid/operators/one_hot_op.h b/paddle/fluid/operators/one_hot_op.h index 7273080927..0e2284941b 100644 --- a/paddle/fluid/operators/one_hot_op.h +++ b/paddle/fluid/operators/one_hot_op.h @@ -25,10 +25,16 @@ struct OneHotOpFunctor { framework::LoDTensor* out_; int depth_; const DeviceContext& ctx_; + bool allow_out_of_range_; OneHotOpFunctor(const framework::LoDTensor* in, framework::LoDTensor* out, - int depth, const DeviceContext& ctx) - : in_(in), out_(out), depth_(depth), ctx_(ctx) {} + int depth, const DeviceContext& ctx, + bool allow_out_of_range = false) + : in_(in), + out_(out), + depth_(depth), + ctx_(ctx), + allow_out_of_range_(allow_out_of_range) {} template void apply() const { @@ -37,13 +43,21 @@ struct OneHotOpFunctor { auto* p_out_data = out_->mutable_data(ctx_.GetPlace()); math::set_constant(ctx_, out_, 0.0); - for (int i = 0; i < numel; ++i) { - PADDLE_ENFORCE_GE(p_in_data[i], 0, - "Illegal index value, should be at least 0."); - PADDLE_ENFORCE_LT(p_in_data[i], depth_, - "Illegal index value, should be less than depth (%d).", - depth_); - *(p_out_data + i * depth_ + p_in_data[i]) = 1.0; + if (allow_out_of_range_) { + for (int i = 0; i < numel; ++i) { + if (p_in_data[i] >= 0 && p_in_data[i] < depth_) { + *(p_out_data + i * depth_ + p_in_data[i]) = 1.0; + } + } + } else { + for (int i = 0; i < numel; ++i) { + PADDLE_ENFORCE_GE(p_in_data[i], 0, + "Illegal index value, should be at least 0."); + PADDLE_ENFORCE_LT( + p_in_data[i], depth_, + "Illegal index value, should be less than depth (%d).", depth_); + *(p_out_data + i * depth_ + p_in_data[i]) = 1.0; + } } } }; @@ -57,6 +71,7 @@ class OneHotKernel : public framework::OpKernel { auto* in = context.Input("X"); auto* out = context.Output("Out"); int depth = context.Attr("depth"); + bool allow_out_of_range = context.Attr("allow_out_of_range"); if (context.HasInput("depth_tensor")) { auto* depth_tensor = context.Input("depth_tensor"); auto* depth_data = depth_tensor->data(); @@ -71,7 +86,8 @@ class OneHotKernel : public framework::OpKernel { static_cast( context.Attr("dtype")), OneHotOpFunctor( - in, out, depth, context.template device_context())); + in, out, depth, context.template device_context(), + allow_out_of_range)); } }; diff --git a/paddle/fluid/platform/CMakeLists.txt b/paddle/fluid/platform/CMakeLists.txt index 5de00db55a..93f29b57c8 100644 --- a/paddle/fluid/platform/CMakeLists.txt +++ b/paddle/fluid/platform/CMakeLists.txt @@ -74,6 +74,10 @@ cc_library(device_context SRCS device_context.cc init.cc DEPS simple_threadpool place eigen3 stringpiece cpu_helper cpu_info framework_proto ${GPU_CTX_DEPS} ${MKLDNN_CTX_DEPS} temp_allocator ${dgc_deps}) +if (WITH_DISTRIBUTE) + cc_library(collective_helper SRCS collective_helper.cc DEPS framework_proto) +endif() + if(WIN32) if(WITH_GPU AND NOT WITH_DSO) get_property(cuda_modules GLOBAL PROPERTY CUDA_MODULES) @@ -97,7 +101,7 @@ cc_test(lodtensor_printer_test SRCS lodtensor_printer_test.cc DEPS lodtensor_pri cc_library(device_tracer SRCS device_tracer.cc DEPS boost profiler_proto framework_proto ${GPU_CTX_DEPS}) if(WITH_GPU) - nv_library(profiler SRCS profiler.cc profiler.cu DEPS device_tracer gpu_info enforce) + nv_library(profiler SRCS profiler.cc profiler.cu DEPS device_tracer gpu_info enforce) else() cc_library(profiler SRCS profiler.cc DEPS device_tracer enforce) endif() diff --git a/paddle/fluid/platform/collective_helper.cc b/paddle/fluid/platform/collective_helper.cc new file mode 100644 index 0000000000..49f3e0c736 --- /dev/null +++ b/paddle/fluid/platform/collective_helper.cc @@ -0,0 +1,101 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// #ifndef _WIN32 +#if defined(PADDLE_WITH_CUDA) && !defined(_WIN32) +#include "paddle/fluid/platform/collective_helper.h" + +#include + +#include "paddle/fluid/platform/dynload/nccl.h" + +namespace paddle { +namespace platform { + +class NCCLCommImpl : public NCCLComm { + public: + void set_ring_id(int ring_id) { ring_id_ = ring_id; } + int ring_id() const override { return ring_id_; } + + void set_nranks(int nranks) { nranks_ = nranks; } + int nranks() const override { return nranks_; } + + void set_rank(int rank) { rank_ = rank; } + int rank() const override { return rank_; } + + void set_local_rank(int local_rank) { local_rank_ = local_rank; } + int local_rank() const override { return local_rank_; } + + void set_comm(ncclComm_t comm) { comm_ = comm; } + ncclComm_t comm() const override { return comm_; } + + void set_dev_ctx(CUDADeviceContext* dev_ctx) { dev_ctx_ = dev_ctx; } + CUDADeviceContext* DevCtx() const override { return dev_ctx_; } + + cudaStream_t stream() const override { return dev_ctx_->stream(); } + + private: + int ring_id_; + int nranks_; + int rank_; + int local_rank_; + ncclComm_t comm_; + CUDADeviceContext* dev_ctx_; +}; + +// NOTE: not thread-safe +NCCLComm* NCCLCommContext::CreateNCCLComm(ncclUniqueId* nccl_id, int nranks, + int rank, int dev_id, int ring_id) { + PADDLE_ENFORCE_NOT_NULL(nccl_id); + PADDLE_ENFORCE_GT(nranks, 1); + PADDLE_ENFORCE(rank >= 0 && rank < nranks, + "Expected rank id range [0, %d), but get %d", nranks, rank); + PADDLE_ENFORCE_GE(dev_id, 0); + + if (dev_ctx_map_.count(dev_id) == 0) { + dev_ctx_map_.emplace(dev_id, std::unique_ptr( + new CUDADeviceContext(CUDAPlace(dev_id)))); + } + + ncclComm_t comm = nullptr; + PADDLE_ENFORCE(cudaSetDevice(dev_id)); + PADDLE_ENFORCE( + platform::dynload::ncclCommInitRank(&comm, nranks, *nccl_id, rank)); + + NCCLCommImpl* communicator = new NCCLCommImpl; + communicator->set_ring_id(ring_id); + communicator->set_nranks(nranks); + communicator->set_rank(rank); + communicator->set_local_rank(dev_id); + communicator->set_comm(comm); + communicator->set_dev_ctx(dev_ctx_map_.at(dev_id).get()); + + comm_map_.emplace(ring_id, std::unique_ptr(communicator)); + + VLOG(0) << "nccl communicator of rank " << rank << " in ring " << ring_id + << " has been created"; + + return comm_map_.at(ring_id).get(); +} + +NCCLCommContext::~NCCLCommContext() { + for (auto& p : comm_map_) { + PADDLE_ENFORCE(platform::dynload::ncclCommDestroy(p.second->comm())); + } +} + +} // namespace platform +} // namespace paddle + +#endif diff --git a/paddle/fluid/platform/collective_helper.h b/paddle/fluid/platform/collective_helper.h new file mode 100644 index 0000000000..97d9417592 --- /dev/null +++ b/paddle/fluid/platform/collective_helper.h @@ -0,0 +1,109 @@ +// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// #ifndef _WIN32 +#if defined(PADDLE_WITH_CUDA) && !defined(_WIN32) +#pragma once + +#include +#include +#include +#include + +#include "boost/variant.hpp" +#include "paddle/fluid/framework/data_type.h" +#include "paddle/fluid/platform/device_context.h" +#include "paddle/fluid/platform/enforce.h" + +namespace paddle { +namespace platform { + +// In order to apply hierarchical communication with NCCL, we need +// a communication ring contains NCCL communicators associated to a global +// ncclUniqueId. E.g. for a hierarchical case, +// +// 11 - 12 21 - 22 +// | | | | +// 13 - 14 - 23 - 24 +// | | +// 31 - 32 - 41 - 42 +// | | | | +// 33 - 34 43 - 44 +// +// we group (14,23,32,41) as the top, and (11,12,13,14), (21,22,23,24), +// (31,32,33,34), (41,42,43,44) as bottoms respectively. +// +// We could also use a single communication ring for the flatten case +// +// The NCCLComm instance is created and reversed in the NCCLCommContext +// singleton with a global user specified group id. +class NCCLComm { + public: + virtual int ring_id() const = 0; + virtual int nranks() const = 0; + virtual int rank() const = 0; + virtual int local_rank() const = 0; + virtual ncclComm_t comm() const = 0; + virtual cudaStream_t stream() const = 0; + virtual CUDADeviceContext* DevCtx() const = 0; + virtual ~NCCLComm() = default; +}; + +// a singleton NCCL communicator context reserves communication ring ids +// Assume multiprocessing mode +class NCCLCommContext { + public: + static NCCLCommContext& Instance() { + static NCCLCommContext comm_ctx; + return comm_ctx; + } + ~NCCLCommContext(); + + NCCLComm* CreateNCCLComm(ncclUniqueId* nccl_id, int nranks, int rank, + int dev_id, int ring_id = 0); + + CUDADeviceContext* DevCtx(int dev_id) const { + PADDLE_ENFORCE(dev_ctx_map_.count(dev_id), + "CUDADeviceContext at device %d has not been initialized"); + return dev_ctx_map_.at(dev_id).get(); + } + + CUDADeviceContext* DevCtx(platform::Place p) const { + return DevCtx(boost::get(p).device); + } + + // retrieve a communicator by the ring id + NCCLComm* Get(int ring_id) const { + PADDLE_ENFORCE(comm_map_.count(ring_id), + "comunicator in ring id %d has not been initialized", + ring_id); + return comm_map_.at(ring_id).get(); + } + + private: + // ring id to NCCLComm + std::unordered_map> comm_map_; + + // device id to CUDADeviceContext + std::unordered_map> dev_ctx_map_; + + NCCLCommContext() = default; + NCCLCommContext(const NCCLCommContext& other) = delete; + NCCLCommContext& operator=(const NCCLCommContext& other) = delete; +}; + +} // namespace platform +} // namespace paddle + +#endif diff --git a/paddle/fluid/platform/dynload/nccl.h b/paddle/fluid/platform/dynload/nccl.h index 331ca9908e..06ee478efd 100644 --- a/paddle/fluid/platform/dynload/nccl.h +++ b/paddle/fluid/platform/dynload/nccl.h @@ -66,6 +66,7 @@ extern void* nccl_dso_handle; __macro(ncclGroupStart); \ __macro(ncclGroupEnd); \ __macro(ncclReduce); \ + __macro(ncclReduceScatter); \ __macro(ncclGetErrorString); NCCL_RAND_ROUTINE_EACH(DECLARE_DYNAMIC_LOAD_NCCL_WRAP) diff --git a/paddle/fluid/pybind/const_value.cc b/paddle/fluid/pybind/const_value.cc index 633e3259ad..3f0fe62fec 100644 --- a/paddle/fluid/pybind/const_value.cc +++ b/paddle/fluid/pybind/const_value.cc @@ -46,6 +46,7 @@ void BindConstValue(pybind11::module* m) { .value("Loss", framework::OpRole::kLoss) .value("RPC", framework::OpRole::kRPC) .value("Dist", framework::OpRole::kDist) + .value("Collective", framework::OpRole::kCollective) .value("LRSched", framework::OpRole::kLRSched); op_proto_and_checker_maker.def( diff --git a/python/paddle/fluid/framework.py b/python/paddle/fluid/framework.py index 07b50c3303..c14cd00239 100644 --- a/python/paddle/fluid/framework.py +++ b/python/paddle/fluid/framework.py @@ -1005,7 +1005,9 @@ class Operator(object): OP_WITHOUT_KERNEL_SET = { 'feed', 'fetch', 'recurrent', 'go', 'rnn_memory_helper_grad', 'conditional_block', 'while', 'send', 'recv', 'listen_and_serv', - 'ncclInit', 'select', 'checkpoint_notify', 'gen_nccl_id' + 'ncclInit', 'select', 'checkpoint_notify', 'gen_nccl_id', + 'c_gen_nccl_id', 'c_comm_init', 'c_sync_calc_stream', + 'c_sync_comm_stream' } def __init__(self, diff --git a/python/paddle/fluid/incubate/fleet/base/fleet_base.py b/python/paddle/fluid/incubate/fleet/base/fleet_base.py index acabec3e82..abba985848 100644 --- a/python/paddle/fluid/incubate/fleet/base/fleet_base.py +++ b/python/paddle/fluid/incubate/fleet/base/fleet_base.py @@ -190,7 +190,6 @@ class Fleet(object): self._role_maker = role_maker self._role_maker.generate_role() - self._is_initialized = True @abc.abstractmethod diff --git a/python/paddle/fluid/incubate/fleet/base/role_maker.py b/python/paddle/fluid/incubate/fleet/base/role_maker.py index a5802ac1fe..ae6bf82e62 100644 --- a/python/paddle/fluid/incubate/fleet/base/role_maker.py +++ b/python/paddle/fluid/incubate/fleet/base/role_maker.py @@ -103,6 +103,46 @@ class RoleMakerBase(object): return self._server_endpoints +class MultiProcessRoleMaker(RoleMakerBase): + """ + MultiProcessRoleMaker is a default role maker for multi-process + GPU training. It works with paddle.distributed.lanuch.py by-design + """ + + def __init__(self): + super(MultiProcessRoleMaker, self).__init__() + self._role_is_generated = False + + def generate_role(self): + import os + if not self._role_is_generated: + self._current_id = int(os.getenv("PADDLE_TRAINER_ID", "0")) + self._num_trainers = 1 + self._training_role = os.getenv("PADDLE_TRAINING_ROLE", "TRAINER") + assert (self._training_role == "TRAINER") + self._worker_endpoints = os.getenv("PADDLE_TRAINER_ENDPOINTS") + self._current_endpoint = os.getenv("PADDLE_CURRENT_ENDPOINT") + if self._worker_endpoints: + self._worker_endpoints = self._worker_endpoints.split(",") + self._num_trainers = len(self._worker_endpoints) + self._role_is_generated = True + + def is_worker(self): + return True + + def is_server(self): + return False + + def is_first_worker(self): + return self._current_id == 0 + + def worker_index(self): + return self._current_id + + def worker_num(self): + return self._worker_num + + class MPIRoleMaker(RoleMakerBase): """ MPIRoleMaker is a MPI-API based role maker which is a counter-part of K8SRoleMaker diff --git a/python/paddle/fluid/incubate/fleet/collective/__init__.py b/python/paddle/fluid/incubate/fleet/collective/__init__.py index 100474244c..b9da38fa8a 100644 --- a/python/paddle/fluid/incubate/fleet/collective/__init__.py +++ b/python/paddle/fluid/incubate/fleet/collective/__init__.py @@ -22,6 +22,116 @@ from paddle.fluid.incubate.fleet.base.fleet_base import Mode from paddle.fluid.incubate.fleet.base.fleet_base import DistributedOptimizer +class DistributedStrategy(object): + def __init__(self): + # precision configs + self.use_fp16 = False + self.use_fp32 = True + # algorithmic communication + self.local_sgd = False + self.dgc = False + # communication topology configs + self.h_allreduce = False + + def build(self): + # make sure we set single precision config True + if self.use_fp32 and self.use_fp16: + self.use_fp16 = False + # make sure we set single algorithmic communication True + if self.local_sgd and self.dgc: + self.local_sgd = False + self.strategy_map["fp16"] = self.use_fp16 + self.strategy_map["fp32"] = self.use_fp32 + self.strategy_map["localsgd"] = self.local_sgd + self.strategy_map["dgc"] = self.dgc + self.strategy_map["h_allreduce"] = self.h_allreduce + + +class DistributedOptimizerFactory(object): + def strategy_to_optimizer_map(self): + pattern = {} + pattern["fp16"] = [ + "MixedPrecisionOptimizer", "MixedPrecisionLocalSGDOptimizer" + ] + pattern["fp32"] = ["FullPrecisionOptimizer", "LocalSGDOptimizer"] + pattern["localsgd"] = [ + "MixedPrecisionLocalSGDOptimizer", "LocalSGDOptimizer" + ] + pattern["h_allreduce"] = [ + "FullPrecisionOptimizer", + "LocalSGDOptimizer", + "MixedPrecisionOptimizer", + "MixedPrecisionLocalSGDOptimizer", + ] + self.pattern = pattern + + def create_by_strategy(self, optimizer, strategy): + if strategy == None: + strategy = DistributedStrategy() + strategy.build() + strategy_list = [] + for key in strategy.strategy_map: + if strategy.strategy_map[key]: + strategy_list.append(self.pattern[key]) + classname = list(set.intersection(*map(set, strategy_list)))[0] + return globals()[classname](optimizer, strategy) + + +class DistributedStrategy(object): + def __init__(self): + # precision configs + self.use_fp16 = False + self.use_fp32 = True + # algorithmic communication + self.local_sgd = False + self.dgc = False + # communication topology configs + self.h_allreduce = False + + def build(self): + # make sure we set single precision config True + if self.use_fp32 and self.use_fp16: + self.use_fp16 = False + # make sure we set single algorithmic communication True + if self.local_sgd and self.dgc: + self.local_sgd = False + self.strategy_map["fp16"] = self.use_fp16 + self.strategy_map["fp32"] = self.use_fp32 + self.strategy_map["localsgd"] = self.local_sgd + self.strategy_map["dgc"] = self.dgc + self.strategy_map["h_allreduce"] = self.h_allreduce + + +class DistributedOptimizerFactory(object): + def strategy_to_optimizer_map(self): + pattern = {} + pattern["fp16"] = [ + "MixedPrecisionOptimizer", "MixedPrecisionLocalSGDOptimizer" + ] + pattern["fp32"] = ["FullPrecisionOptimizer", "LocalSGDOptimizer"] + pattern["localsgd"] = [ + "MixedPrecisionLocalSGDOptimizer", "LocalSGDOptimizer" + ] + pattern["h_allreduce"] = [ + "FullPrecisionOptimizer", + "LocalSGDOptimizer", + "MixedPrecisionOptimizer", + "MixedPrecisionLocalSGDOptimizer", + ] + self.pattern = pattern + + def create_by_strategy(self, optimizer, strategy): + if strategy == None: + strategy = DistributedStrategy() + strategy.build() + strategy_list = [] + for key in strategy.strategy_map: + if strategy.strategy_map[key]: + strategy_list.append(self.pattern[key]) + classname = list(set.intersection(*map(set, strategy_list)))[0] + return globals()[classname](optimizer, strategy) + + class Collective(Fleet): def __init__(self): super(Collective, self).__init__(Mode.COLLECTIVE) @@ -48,7 +158,8 @@ class Collective(Fleet): "You should not call 'stop_worker' method for collective mode.") def distributed_optimizer(self, optimizer, strategy=None): - self._optimizer = CollectiveOptimizer(optimizer, strategy) + self._optimizer = \ + DistributedOptimizerFactory.create_by_strategy(optimizer, strategy) return self._optimizer def save_inference_model(self, @@ -69,6 +180,85 @@ class Collective(Fleet): fleet = Collective() +class CollectiveOpBasedOptimizer(DistributedOptimizer): + """ + TBA + """ + + def __init__(self, optimizer, strategy=None): + super(CollectiveOpBasedOptimizer, self).__init__(optimizer, strategy) + + def _transpile_program(self, startup_program=None): + startup_program = startup_program if startup_program else \ + fluid.framework.default_startup_program() + worker_endpoints = fleet.worker_endpoints() + trainer_id = fleet.worker_index() + current_endpoint = fleet.worker_endpoints()[trainer_id] + # call transpiler + config = dist_transpiler.DistributeTranspilerConfig() + config.mode = "collective" + config.collective_mode = "sgd" + t = dist_transpiler.DistributeTranspiler(config=config) + t.transpile( + trainer_id, + trainers=','.join(worker_endpoints), + startup_program=startup_program, + current_endpoint=current_endpoint) + + def backward(self, + loss, + startup_program=None, + parameter_list=None, + no_grad_set=None, + callbacks=None): + return self._optimizer.backward(loss, startup_program, parameter_list, + no_grad_set, callbacks) + + def apply_gradients(self, params_grads): + return self._optimizer.apply_gradients(params_grads) + + +class MixedPrecisionOptimizer(CollectiveOpBasedOptimizer): + """ + TBA + """ + + def minimize(self, + loss, + startup_program=None, + parameter_list=None, + no_grad_set=None): + pass + + +class FullPrecisionOptimizer(CollectiveOpBasedOptimizer): + """ + TBA + """ + + def __init__(self, optimizer, strategy=None): + super(FullPrecisionOptimizer, self).__init__(optimizer, strategy) + + def minimize(self, + loss, + startup_program=None, + parameter_list=None, + no_grad_set=None): + self._transpile_program(startup_program) + + train_program = loss.block.program + param_grads = self.backward(loss) + train_program.global_block().append_op(type='c_sync_compute_stream') + data_parallel_param_grads = [] + for p, g in param_grads: + # NOTE: scale will be done on loss scale + # in multi_devices_graph_pass using nranks. + reduced_g = fluid.layers.collective._allreduce(g, g) + data_parallel_param_grads.append([p, reduced_g]) + train_program.global_block().append_op(type='c_sync_comm_stream') + self.apply_gradients(data_parallel_param_grads) + + class CollectiveOptimizer(DistributedOptimizer): """ DistributedOptimizer is a wrapper for paddle.fluid.optimizer @@ -82,7 +272,7 @@ class CollectiveOptimizer(DistributedOptimizer): def __init__(self, optimizer, strategy=None): super(CollectiveOptimizer, self).__init__(optimizer, strategy) - assert strategy is None, "You cannot set 'strategy' for collective." + self.strategy = strategy def backward(self, loss, diff --git a/python/paddle/fluid/tests/unittests/CMakeLists.txt b/python/paddle/fluid/tests/unittests/CMakeLists.txt index 4babeb2974..30e5fae281 100644 --- a/python/paddle/fluid/tests/unittests/CMakeLists.txt +++ b/python/paddle/fluid/tests/unittests/CMakeLists.txt @@ -20,6 +20,14 @@ if(NOT WITH_DISTRIBUTE) LIST(REMOVE_ITEM TEST_OPS test_dist_fleet_ctr) endif(NOT WITH_DISTRIBUTE) +if(NOT WITH_GPU OR WIN32) + LIST(REMOVE_ITEM TEST_OPS test_allgather) + LIST(REMOVE_ITEM TEST_OPS test_allreduce) + LIST(REMOVE_ITEM TEST_OPS test_broadcast) + LIST(REMOVE_ITEM TEST_OPS test_reducescatter) +endif() + + LIST(REMOVE_ITEM TEST_OPS test_launch) if (NOT ${WITH_GPU}) diff --git a/python/paddle/fluid/tests/unittests/collective_allgather_op.py b/python/paddle/fluid/tests/unittests/collective_allgather_op.py new file mode 100644 index 0000000000..3499965476 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/collective_allgather_op.py @@ -0,0 +1,69 @@ +# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import print_function + +import numpy as np +import argparse +import os +import sys +import signal +import time +from contextlib import closing +from six import string_types +import math +import paddle +import paddle.fluid as fluid +import paddle.fluid.profiler as profiler +import paddle.fluid.unique_name as nameGen +from paddle.fluid import core +import unittest +from multiprocessing import Process +import paddle.fluid.layers as layers +from functools import reduce +from test_collective_base import TestCollectiveRunnerBase, runtime_main + + +class TestCollectiveAllGather(TestCollectiveRunnerBase): + def __init__(self): + self.global_ring_id = 0 + + def get_model(self, main_prog, startup_program): + ring_id = 0 + nranks = 2 + with fluid.program_guard(main_prog, startup_program): + tindata = layers.data( + name="tindata", shape=[10, 1000], dtype='float32') + toutdata = main_prog.current_block().create_var( + name="outofgather", + dtype='float32', + type=core.VarDesc.VarType.LOD_TENSOR, + persistable=False, + stop_gradient=False) + main_prog.global_block().append_op( + type="c_allgather", + inputs={'X': tindata}, + attrs={'ring_id': ring_id, + 'nranks': nranks}, + outputs={'Out': toutdata}) + main_prog.global_block().append_op( + type="c_sync_comm_stream", + inputs={'X': toutdata}, + outputs={'Out': toutdata}, + attrs={'ring_id': ring_id}) + return toutdata + + +if __name__ == "__main__": + runtime_main(TestCollectiveAllGather, "allgather", 0) diff --git a/python/paddle/fluid/tests/unittests/collective_allreduce_op.py b/python/paddle/fluid/tests/unittests/collective_allreduce_op.py new file mode 100644 index 0000000000..69bd6f9904 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/collective_allreduce_op.py @@ -0,0 +1,70 @@ +# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import print_function + +import numpy as np +import argparse +import os +import sys +import signal +import time +import socket +from contextlib import closing +from six import string_types +import math +import paddle +import paddle.fluid as fluid +import paddle.fluid.profiler as profiler +import paddle.fluid.unique_name as nameGen +from paddle.fluid import core +import unittest +from multiprocessing import Process +import paddle.fluid.layers as layers +from functools import reduce +from test_collective_base import TestCollectiveRunnerBase, runtime_main + + +class TestCollectiveAllreduce(TestCollectiveRunnerBase): + def __init__(self): + self.global_ring_id = 0 + + def get_model(self, main_prog, startup_program): + ring_id = 0 + reduce_type = 0 + with fluid.program_guard(main_prog, startup_program): + tindata = layers.data( + name="tindata", shape=[10, 1000], dtype='float32') + toutdata = main_prog.current_block().create_var( + name="outofallreduce", + dtype='float32', + type=core.VarDesc.VarType.LOD_TENSOR, + persistable=False, + stop_gradient=False) + main_prog.global_block().append_op( + type="c_allreduce", + inputs={'X': tindata}, + attrs={'ring_id': ring_id, + 'reduce_type': reduce_type}, + outputs={'Out': toutdata}) + main_prog.global_block().append_op( + type="c_sync_comm_stream", + inputs={'X': toutdata}, + outputs={'Out': toutdata}, + attrs={'ring_id': ring_id}) + return toutdata + + +if __name__ == "__main__": + runtime_main(TestCollectiveAllreduce, "allreduce", 0) diff --git a/python/paddle/fluid/tests/unittests/collective_broadcast_op.py b/python/paddle/fluid/tests/unittests/collective_broadcast_op.py new file mode 100644 index 0000000000..18f0485f92 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/collective_broadcast_op.py @@ -0,0 +1,70 @@ +# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import print_function + +import numpy as np +import argparse +import os +import sys +import signal +import time +import socket +from contextlib import closing +from six import string_types +import math +import paddle +import paddle.fluid as fluid +import paddle.fluid.profiler as profiler +import paddle.fluid.unique_name as nameGen +from paddle.fluid import core +import unittest +from multiprocessing import Process +import paddle.fluid.layers as layers +from functools import reduce +from test_collective_base import TestCollectiveRunnerBase, runtime_main + + +class TestCollectiveBroadcast(TestCollectiveRunnerBase): + def __init__(self): + self.global_ring_id = 0 + + def get_model(self, main_prog, startup_program): + ring_id = 0 + rootid = 1 + with fluid.program_guard(main_prog, startup_program): + tindata = layers.data( + name="tindata", shape=[10, 1000], dtype='float32') + toutdata = main_prog.current_block().create_var( + name="outofbroadcast", + dtype='float32', + type=core.VarDesc.VarType.LOD_TENSOR, + persistable=False, + stop_gradient=False) + main_prog.global_block().append_op( + type="c_broadcast", + inputs={'X': tindata}, + attrs={'ring_id': ring_id, + 'root': rootid}, + outputs={'Out': toutdata}) + main_prog.global_block().append_op( + type="c_sync_comm_stream", + inputs={'X': toutdata}, + outputs={'Out': toutdata}, + attrs={'ring_id': ring_id}) + return toutdata + + +if __name__ == "__main__": + runtime_main(TestCollectiveBroadcast, "broadcast", 0) diff --git a/python/paddle/fluid/tests/unittests/collective_reducescatter_op.py b/python/paddle/fluid/tests/unittests/collective_reducescatter_op.py new file mode 100644 index 0000000000..3e286d7f43 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/collective_reducescatter_op.py @@ -0,0 +1,70 @@ +# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import print_function + +import numpy as np +import argparse +import os +import sys +import signal +import time +import socket +from contextlib import closing +from six import string_types +import math +import paddle +import paddle.fluid as fluid +import paddle.fluid.profiler as profiler +import paddle.fluid.unique_name as nameGen +from paddle.fluid import core +import unittest +from multiprocessing import Process +import paddle.fluid.layers as layers +from functools import reduce +from test_collective_base import TestCollectiveRunnerBase, runtime_main + + +class TestCollectiveReduceScatter(TestCollectiveRunnerBase): + def __init__(self): + self.global_ring_id = 0 + + def get_model(self, main_prog, startup_program): + ring_id = 0 + nranks = 2 + with fluid.program_guard(main_prog, startup_program): + tindata = layers.data( + name="tindata", shape=[10, 1000], dtype='float32') + toutdata = main_prog.current_block().create_var( + name="outofrs", + dtype='float32', + type=core.VarDesc.VarType.LOD_TENSOR, + persistable=False, + stop_gradient=False) + main_prog.global_block().append_op( + type="c_reducescatter", + inputs={'X': tindata}, + attrs={'ring_id': ring_id, + 'nranks': nranks}, + outputs={'Out': toutdata}) + main_prog.global_block().append_op( + type="c_sync_comm_stream", + inputs={'X': toutdata}, + outputs={'Out': toutdata}, + attrs={'ring_id': ring_id}) + return toutdata + + +if __name__ == "__main__": + runtime_main(TestCollectiveReduceScatter, "reduce_scatter", 0) diff --git a/python/paddle/fluid/tests/unittests/test_allgather.py b/python/paddle/fluid/tests/unittests/test_allgather.py new file mode 100644 index 0000000000..877ae6f6e1 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/test_allgather.py @@ -0,0 +1,31 @@ +# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import print_function +import unittest +import numpy as np + +from test_collective_base import TestDistBase + + +class TestAllGatherOp(TestDistBase): + def _setup_config(self): + pass + + def test_allgather(self, col_type="allgather"): + self.check_with_place("collective_allgather_op.py", col_type) + + +if __name__ == '__main__': + unittest.main() diff --git a/python/paddle/fluid/tests/unittests/test_allreduce.py b/python/paddle/fluid/tests/unittests/test_allreduce.py new file mode 100644 index 0000000000..e0b6422a67 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/test_allreduce.py @@ -0,0 +1,31 @@ +# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import print_function +import unittest +import numpy as np + +from test_collective_base import TestDistBase + + +class TestAllReduceOp(TestDistBase): + def _setup_config(self): + pass + + def test_allreduce(self, col_type="allreduce"): + self.check_with_place("collective_allreduce_op.py", col_type) + + +if __name__ == '__main__': + unittest.main() diff --git a/python/paddle/fluid/tests/unittests/test_broadcast.py b/python/paddle/fluid/tests/unittests/test_broadcast.py new file mode 100644 index 0000000000..029e881d6f --- /dev/null +++ b/python/paddle/fluid/tests/unittests/test_broadcast.py @@ -0,0 +1,31 @@ +# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import print_function +import unittest +import numpy as np + +from test_collective_base import TestDistBase + + +class TestCBroadcastOp(TestDistBase): + def _setup_config(self): + pass + + def test_broadcast(self): + self.check_with_place("collective_broadcast_op.py", "broadcast") + + +if __name__ == '__main__': + unittest.main() diff --git a/python/paddle/fluid/tests/unittests/test_collective_base.py b/python/paddle/fluid/tests/unittests/test_collective_base.py new file mode 100644 index 0000000000..e0789178b3 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/test_collective_base.py @@ -0,0 +1,273 @@ +# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import print_function +import numpy as np +import unittest +import time +import argparse +import os +import six +import sys +import subprocess +import traceback +import functools +import pickle +from contextlib import closing +from six import string_types +import paddle.fluid as fluid +import paddle.fluid.unique_name as nameGen +from paddle.fluid import core + + +class TestCollectiveRunnerBase(object): + def get_model(self, train_prog, startup_prog): + raise NotImplementedError( + "get model should be implemented by child class.") + + def wait_server_ready(self, endpoints): + assert not isinstance(endpoints, string_types) + while True: + all_ok = True + not_ready_endpoints = [] + for ep in endpoints: + ip_port = ep.split(":") + with closing( + socket.socket(socket.AF_INET, + socket.SOCK_STREAM)) as sock: + sock.settimeout(2) + result = sock.connect_ex((ip_port[0], int(ip_port[1]))) + if result != 0: + all_ok = False + not_ready_endpoints.append(ep) + if not all_ok: + sys.stderr.write("server not ready, wait 3 sec to retry...\n") + sys.stderr.write("not ready endpoints:" + str( + not_ready_endpoints) + "\n") + sys.stderr.flush() + time.sleep(3) + else: + break + +#endpoints should be ["ip1:port1","ip2:port2"] + + def initCommunicator(self, program, rank, nranks, wait_port, + current_endpoint, endpoints): + other_endpoints = endpoints[:] + other_endpoints.remove(current_endpoint) + if rank == 0 and wait_port: + self.wait_server_ready(other_endpoints) + block = program.global_block() + nccl_id_var = block.create_var( + name=nameGen.generate('nccl_id'), + persistable=True, + type=core.VarDesc.VarType.RAW) + + block.append_op( + type='c_gen_nccl_id', + inputs={}, + outputs={'Out': nccl_id_var}, + attrs={ + 'rank': rank, + 'endpoint': current_endpoint, + 'other_endpoints': other_endpoints + }) + + block.append_op( + type='c_comm_init', + inputs={'X': nccl_id_var}, + outputs={}, + attrs={ + 'nranks': nranks, + 'rank': rank, + 'ring_id': self.global_ring_id + }) + + def run_trainer(self, args): + train_prog = fluid.Program() + startup_prog = fluid.Program() + endpoints = args["endpoints"].split(",") + rank = args["trainerid"] + current_endpoint = args["currentendpoint"] + nranks = 2 + self.initCommunicator(startup_prog, rank, nranks, True, + current_endpoint, endpoints) + result = self.get_model(train_prog, startup_prog) + device_id = int(os.getenv("FLAGS_selected_gpus", "0")) + place = fluid.CUDAPlace( + device_id) #if args.use_gpu else fluid.CPUPlace() + exe = fluid.Executor(place) + exe.run(startup_prog) + np.random.seed(os.getpid()) + indata = np.random.random((10, 1000)) + out = exe.run(train_prog, + feed={'tindata': indata}, + fetch_list=[result.name]) + if six.PY2: + print(pickle.dumps(out)) + else: + sys.stdout.buffer.write(pickle.dumps(out)) + + +def runtime_main(test_class, col_type, sub_type): + args = {} + model = test_class() + args["deviceid"] = os.getenv("FLAGS_selected_gpus") + args["trainerid"] = int(os.getenv("PADDLE_TRAINER_ID")) + args["trainernum"] = int(os.getenv("PADDLE_TRAINERS_NUM")) + args["endpoints"] = os.getenv('PADDLE_TRAINER_ENDPOINTS') + args["currentendpoint"] = os.getenv("PADDLE_CURRENT_ENDPOINT") + args["col_type"] = col_type + model.run_trainer(args) + + +import paddle.compat as cpt +import socket +from contextlib import closing + + +class TestDistBase(unittest.TestCase): + def setUp(self): + self._port_set = set() + self._trainers = 2 + self._ps_endpoints = "127.0.0.1:%s,127.0.0.1:%s" % ( + self._find_free_port(), self._find_free_port()) + self._python_interp = sys.executable + + def _find_free_port(self): + def __free_port(): + with closing(socket.socket(socket.AF_INET, + socket.SOCK_STREAM)) as s: + s.bind(('', 0)) + return s.getsockname()[1] + + while True: + port = __free_port() + if port not in self._port_set: + self._port_set.add(port) + return port + + def _run_cluster(self, model_file, envs): + worker_endpoints = self._ps_endpoints.split(",") + w0_ep, w1_ep = worker_endpoints + #print("w0_ep:",w0_ep," w1_ep:",w1_ep) + env0 = { + "FLAGS_selected_gpus": "2", + "PADDLE_TRAINER_ID": "0", + "PADDLE_TRAINERS_NUM": "2", + "PADDLE_TRAINER_ENDPOINTS": self._ps_endpoints, + "PADDLE_CURRENT_ENDPOINT": w0_ep + } + + env1 = { + "FLAGS_selected_gpus": "3", + "PADDLE_TRAINER_ID": "1", + "PADDLE_TRAINERS_NUM": "2", + "PADDLE_TRAINER_ENDPOINTS": self._ps_endpoints, + "PADDLE_CURRENT_ENDPOINT": w1_ep + } + #update environment + env0.update(envs) + env1.update(envs) + tr_cmd = "%s %s" + tr0_cmd = tr_cmd % (self._python_interp, model_file) + tr1_cmd = tr_cmd % (self._python_interp, model_file) + tr0_pipe = open("/tmp/tr0_err.log", "wb") + tr1_pipe = open("/tmp/tr1_err.log", "wb") + #print(tr0_cmd) + tr0_proc = subprocess.Popen( + tr0_cmd.strip().split(), + stdout=subprocess.PIPE, + stderr=tr0_pipe, + env=env0) + + tr1_proc = subprocess.Popen( + tr0_cmd.strip().split(), + stdout=subprocess.PIPE, + stderr=tr1_pipe, + env=env1) + + tr0_out, tr0_err = tr0_proc.communicate() + tr1_out, tr1_err = tr1_proc.communicate() + sys.stderr.write('trainer 0 stderr: %s\n' % tr0_err) + sys.stderr.write('trainer 1 stderr: %s\n' % tr1_err) + # close trainer file + tr0_pipe.close() + tr1_pipe.close() + return pickle.loads(tr0_out), pickle.loads( + tr1_out), tr0_proc.pid, tr1_proc.pid + + def check_with_place(self, + model_file, + col_type, + check_error_log=False, + need_envs={}): + required_envs = { + "FLAGS_fraction_of_gpu_memory_to_use": "0.15", + "FLAGS_eager_delete_tensor_gb": "0.0", + "PATH": os.getenv("PATH"), + "PYTHONPATH": os.getenv("PYTHONPATH", ""), + "LD_LIBRARY_PATH": os.getenv("LD_LIBRARY_PATH", ""), + "LD_PRELOAD": os.getenv("LD_PRELOAD", ""), + "GLOG_v": "0", + "NCCL_P2P_DISABLE": "1" + } + required_envs.update(need_envs) + if check_error_log: + required_envs["GLOG_v"] = "3" + required_envs["GLOG_logtostderr"] = "1" + tr0_out, tr1_out, pid0, pid1 = self._run_cluster(model_file, + required_envs) + np.random.seed(pid0) + input1 = np.random.random((10, 1000)) + np.random.seed(pid1) + input2 = np.random.random((10, 1000)) + if col_type == "allgather": + need_result = np.vstack((input1, input2)) + self.assertTrue(np.allclose(tr0_out, need_result)) + self.assertTrue(np.allclose(tr1_out, need_result)) + elif col_type == "broadcast": + need_result = input2 + self.assertTrue(np.allclose(tr0_out, need_result)) + self.assertTrue(np.allclose(tr1_out, need_result)) + elif col_type == "allreduce": + need_result = input1 + input2 + self.assertTrue( + np.allclose( + tr0_out, need_result, rtol=1e-05, atol=1e-05)) + self.assertTrue( + np.allclose( + tr1_out, need_result, rtol=1e-05, atol=1e-05)) + elif col_type == "reduce_scatter": + tmp = input1 + input2 + need_result1 = tmp[0:tmp.shape[0] // 2] + need_result2 = tmp[tmp.shape[0] // 2:] + self.assertTrue( + np.allclose( + tr0_out, need_result1, rtol=1e-05, atol=1e-05)) + self.assertTrue( + np.allclose( + tr1_out, need_result2, rtol=1e-05, atol=1e-05)) + elif col_type == "reduce_slicegather": + slicesize = input1.shape[0] // 2 + tmp10 = input1[0:slicesize] + tmp11 = input2[0:slicesize] + need_result1 = np.concatenate((tmp10, tmp11), axis=1) + tmp20 = input1[slicesize:] + tmp21 = input2[slicesize:] + need_result2 = np.concatenate((tmp20, tmp21), axis=1) + self.assertTrue(np.allclose(tr0_out, need_result1)) + self.assertTrue(np.allclose(tr1_out, need_result2)) + else: + pass diff --git a/python/paddle/fluid/tests/unittests/test_reducescatter.py b/python/paddle/fluid/tests/unittests/test_reducescatter.py new file mode 100644 index 0000000000..58bcc11cd8 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/test_reducescatter.py @@ -0,0 +1,32 @@ +# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import print_function +import unittest +import numpy as np + +from test_collective_base import TestDistBase + + +class TestReduceScatterOp(TestDistBase): + def _setup_config(self): + pass + + def test_reducescatter(self): + self.check_with_place("collective_reducescatter_op.py", + "reduce_scatter") + + +if __name__ == '__main__': + unittest.main() diff --git a/python/paddle/fluid/transpiler/collective.py b/python/paddle/fluid/transpiler/collective.py new file mode 100644 index 0000000000..df5cdbc104 --- /dev/null +++ b/python/paddle/fluid/transpiler/collective.py @@ -0,0 +1,343 @@ +# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the 'License'); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an 'AS IS' BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import print_function + +import sys +import math +from functools import reduce + +import collections +import six +import logging + +import numpy as np + +from .. import core, unique_name +from ..framework import Program, default_main_program, default_startup_program +from .details import wait_server_ready + +__all__ = ['GradAllReduce', 'LocalSGD'] + +OpRole = core.op_proto_and_checker_maker.OpRole + + +class Collective(object): + ''' + ''' + + def __init__(self): + self.global_ring_id = 0 + self.endpoints = None + self.current_endpoint = None + self.nranks = None + self.rank = None + self.startup_program = None + self.main_program = None + op_maker = core.op_proto_and_checker_maker + self.op_role_key = op_maker.kOpRoleAttrName() + self.op_role_var_key = op_maker.kOpRoleVarAttrName() + + def transpile(self, startup_program, main_program, rank, endpoints, + current_endpoint, wait_port): + # in case of '127.0.0.1:6700,127.0.0.1:6701,...' + if isinstance(endpoints, str): + endpoints = endpoints.split(',') + + self.startup_program = startup_program + if startup_program is None: + self.startup_program = default_startup_program() + + self.main_program = main_program + if main_program is None: + self.main_program = default_main_program() + + self.nranks = len(endpoints) + if self.nranks == 1: + raise ValueError('the number of endpoints must > 1') + + if rank < 0: + raise ValueError('rank must >= 0') + self.rank = rank + + if current_endpoint not in endpoints: + raise ValueError('current endpoint %s is not in %s', + current_endpoint, str(endpoints)) + + self.endpoints = endpoints + self.current_endpoint = current_endpoint + + self.wait_port = wait_port + + self.startup_program._origin_program = self.startup_program.clone() + self._transpile_startup_program() + + self.main_program._origin_program = self.main_program.clone() + self._transpile_main_program() + + def _transpile_main_program(self): + raise NotImplementedError('call the inherited method of subclasses') + + def _transpile_startup_program(self): + self._init_communicator(self.startup_program, self.current_endpoint, + self.endpoints, self.rank, self.global_ring_id, + self.wait_port) + self._broadcast_params() + + def _init_communicator(self, program, current_endpoint, endpoints, rank, + ring_id, wait_port): + nranks = len(endpoints) + other_endpoints = endpoints[:] + other_endpoints.remove(current_endpoint) + if rank == 0 and wait_port: + wait_server_ready(other_endpoints) + + block = program.global_block() + nccl_id_var = block.create_var( + name=unique_name.generate('nccl_id'), + persistable=True, + type=core.VarDesc.VarType.RAW) + block.append_op( + type='c_gen_nccl_id', + inputs={}, + outputs={'Out': nccl_id_var}, + attrs={ + 'rank': rank, + 'endpoint': current_endpoint, + 'other_endpoints': other_endpoints, + self.op_role_key: OpRole.Collective + }) + block.append_op( + type='c_comm_init', + inputs={'X': nccl_id_var}, + outputs={}, + attrs={ + 'nranks': nranks, + 'rank': rank, + 'ring_id': ring_id, + self.op_role_key: OpRole.Collective + }) + + def _broadcast_params(self): + block = self.startup_program.global_block() + for var in block.iter_parameters(): + block.append_op( + type='c_broadcast', + inputs={'X': var}, + outputs={'Out': var}, + attrs={ + 'ring_id': self.global_ring_id, + 'root': 0, + self.op_role_key: OpRole.Collective + }) + block.append_op( + type='c_sync_comm_stream', + inputs={'X': var}, + outputs={'Out': var}, + attrs={ + 'ring_id': self.global_ring_id, + self.op_role_key: OpRole.Collective + }) + + def _is_loss_grad_op(self, op): + if self.op_role_key not in op.attr_names: + return False + op_role = int(op.all_attrs()[self.op_role_key]) + return op_role & int(OpRole.Backward) and op_role & int(OpRole.Loss) + + def _is_backward_op(self, op): + return self.op_role_key in op.attr_names and \ + int(op.all_attrs()[self.op_role_key]) & int(OpRole.Backward) + + def _is_update_op(self, op): + return 'Param' in op.input_names and 'Grad' in op.input_names and \ + "LearningRate" in op.input_names + + def _is_optimizer_op(self, op): + return self.op_role_key in op.attr_names and \ + int(op.all_attrs()[self.op_role_key]) & int(OpRole.Optimize) + + +class GradAllReduce(Collective): + ''' + ''' + + def __init__(self): + Collective.__init__(self) + + def _transpile_main_program(self): + self._insert_scale_loss_grad_ops() + self._insert_allreduce_ops() + + def _insert_scale_loss_grad_ops(self): + ''' + In order to keep the learning rate consistent in different numbers of + training workers, we scale the loss grad by the number of workers + ''' + block = self.main_program.global_block() + for idx, op in reversed(list(enumerate(block.ops))): + if self._is_loss_grad_op(op): + loss_grad_var = block.vars[op.output_arg_names[0]] + block._insert_op( + idx + 1, + type='scale', + inputs={'X': loss_grad_var}, + outputs={'Out': loss_grad_var}, + attrs={ + 'scale': 1.0 / self.nranks, + self.op_role_key: OpRole.Collective + }) + + def _insert_allreduce_ops(self): + block = self.main_program.global_block() + for idx, op in reversed(list(enumerate(block.ops))): + if self._is_backward_op(op) and \ + self.op_role_var_key in op.attr_names: + op_role_var = op.all_attrs()[self.op_role_var_key] + + if len(op_role_var) == 0: + continue + + assert len(op_role_var) % 2 == 0 + + block._insert_op( + idx + 1, + type='c_sync_calc_stream', + inputs={'X': block.vars[grad]}, + outputs={'Out': block.vars[grad]}, + attrs={self.op_role_key: OpRole.Collective}) + + offset = 2 + for i in range(0, len(op_role_var), 2): + grad = op_role_var[i + 1] + block._insert_op( + idx + offset, + type='c_allreduce', + inputs={'X': [block.vars[grad]]}, + outputs={'Out': [block.vars[grad]]}, + attrs={ + 'reduce_type': 0, + self.op_role_key: OpRole.Collective + }) + offset += 1 + + for idx, op in enumerate(block.ops): + if self._is_optimizer_op(op): + block._insert_op( + idx, + type='c_sync_comm_stream', + inputs={'X': block.vars[grad]}, + outputs={'Out': block.vars[grad]}, + attrs={ + 'ring_id': self.global_ring_id, + self.op_role_key: OpRole.Collective + }) + break + + +class LocalSGD(Collective): + ''' + ''' + + def __init__(self): + Collective.__init__(self) + self.snapshot_key = '@SNAPSHOT' + + def _transpile_startup_program(self): + Collective._transpile_startup_program(self) + + block = self.startup_program.global_block() + for param in block.iter_parameters(): + snapshot = block.create_var( + name=self.snapshot_name(param.name), + shape=param.shape, + persistable=True, + stop_gradient=True) + block.append_op( + type='assign', + inputs={'X': [param]}, + outputs={'Out': [snapshot]}, + attrs={self.op_role_key: OpRole.Collective}) + + def snapshot_name(self, param_name): + return param_name + self.snapshot_key + + def _transpile_main_program(self): + block = self.main_program.global_block() + ordered_param_snapshot = [] + for idx, op in reversed(list(enumerate(block.ops))): + if self._is_update_op(op): + param = block.vars[op.input('Param')[0]] + snapshot = block.create_var( + name=self.snapshot_name(param.name), + shape=param.shape, + persistable=True, + stop_gradient=True) + + block._insert_op( + idx + 1, + type='elementwise_sub', + inputs={'X': [snapshot], + 'Y': [param]}, + outputs={'Out': [param]}, + attrs={self.op_role_key: OpRole.Collective}) + block._insert_op( + idx + 2, + type='c_sync_calc_stream', + inputs={'X': param}, + outputs={'Out': param}, + attrs={self.op_role_key: OpRole.Collective}) + block._insert_op( + idx + 3, + type='c_allreduce', + inputs={'X': [param]}, + outputs={'Out': [param]}, + attrs={ + 'reduce_type': 0, + self.op_role_key: OpRole.Collective + }) + + ordered_param_snapshot.append((param, snapshot)) + + block.append_op( + type='c_sync_comm_stream', + inputs={'X': param}, + outputs={'Out': param}, + attrs={ + 'ring_id': self.global_ring_id, + self.op_role_key: OpRole.Collective + }) + + for param_snapshot in reversed(ordered_param_snapshot): + param = param_snapshot[0] + snapshot = param_snapshot[1] + block.append_op( + type='scale', + inputs={'X': [param]}, + outputs={'Out': [param]}, + attrs={ + 'scale': 1.0 / self.nranks, + self.op_role_key: OpRole.Collective + }) + block.append_op( + type='elementwise_sub', + inputs={'X': [snapshot], + 'Y': [param]}, + outputs={'Out': [param]}, + attrs={self.op_role_key: OpRole.Collective}) + block.append_op( + type='assign', + inputs={'X': [param]}, + outputs={'Out': [snapshot]}, + attrs={self.op_role_key: OpRole.Collective}) diff --git a/python/paddle/fluid/transpiler/details/program_utils.py b/python/paddle/fluid/transpiler/details/program_utils.py index 391d6aa12b..65807d1b9f 100644 --- a/python/paddle/fluid/transpiler/details/program_utils.py +++ b/python/paddle/fluid/transpiler/details/program_utils.py @@ -90,7 +90,7 @@ def variable_to_code(var): return var_str -def op_to_code(op): +def op_to_code(op, skip_op_callstack=False): """ Get readable codes of fluid operator. @@ -124,6 +124,8 @@ def op_to_code(op): attrs_str = "" for i in range(0, len(attr_names)): name = attr_names[i] + if skip_op_callstack and name == "op_callstack": + continue attr_type = op.desc.attr_type(name) if attr_type == core.AttrType.BLOCK: @@ -157,29 +159,35 @@ def op_to_code(op): return op_str -def block_to_code(block, block_idx): +def block_to_code(block, block_idx, fout=None, skip_op_callstack=False): indent = 0 - print("{0}{1} // block {2}".format( - get_indent_space(indent), '{', block_idx)) + print( + "{0}{1} // block {2}".format(get_indent_space(indent), '{', block_idx), + file=fout) indent += 1 # sort all vars all_vars = sorted(six.iteritems(block.vars), key=lambda x: x[0]) for var in all_vars: - print("{}{}".format(get_indent_space(indent), variable_to_code(var[1]))) + print( + "{}{}".format(get_indent_space(indent), variable_to_code(var[1])), + file=fout) if len(all_vars) > 0: - print("") + print("", file=fout) for op in block.ops: - print("{}{}".format(get_indent_space(indent), op_to_code(op))) + print( + "{}{}".format( + get_indent_space(indent), op_to_code(op, skip_op_callstack)), + file=fout) indent -= 1 - print("{0}{1}".format(get_indent_space(indent), '}')) + print("{0}{1}".format(get_indent_space(indent), '}'), file=fout) -def program_to_code(prog): +def program_to_code(prog, fout=None, skip_op_callstack=False): """ Print readable codes of fluid program. @@ -191,5 +199,5 @@ def program_to_code(prog): """ block_idx = 0 for block in prog.blocks: - block_to_code(block, block_idx) + block_to_code(block, block_idx, fout, skip_op_callstack) block_idx += 1 diff --git a/python/paddle/fluid/transpiler/distribute_transpiler.py b/python/paddle/fluid/transpiler/distribute_transpiler.py index feb3277382..ece7ade693 100644 --- a/python/paddle/fluid/transpiler/distribute_transpiler.py +++ b/python/paddle/fluid/transpiler/distribute_transpiler.py @@ -47,6 +47,7 @@ from ..framework import Program, default_main_program, \ from .details import wait_server_ready, UnionFind, VarStruct, VarsDistributed from .details import delete_ops, find_op_by_output_arg from ..distribute_lookup_table import find_distributed_lookup_table +from . import collective LOOKUP_TABLE_TYPE = "lookup_table" LOOKUP_TABLE_GRAD_TYPE = "lookup_table_grad" @@ -157,7 +158,7 @@ class DistributeTranspilerConfig(object): split_method = None min_block_size = 8192 enable_dc_asgd = False - # supported modes: pserver, nccl2 + # supported modes: pserver, nccl2, collective mode = "pserver" print_log = False wait_port = True @@ -174,6 +175,10 @@ class DistributeTranspilerConfig(object): #Nccl ranks bewteen nodes when use hierarchical allreduce, it's setted to nodes number. hierarchical_allreduce_exter_nranks = 0 + # if mode is collective + # supported modes: sgd, local_sgd + collective_mode = None + class DistributeTranspiler(object): """ @@ -305,6 +310,46 @@ class DistributeTranspiler(object): else: raise ValueError("must set trainer_id > 0") + def _transpile_collective(self, + collective_mode, + trainer_id, + trainers, + current_endpoint, + startup_program=None, + main_program=None, + wait_port=True): + if isinstance(trainers, str): + endpoints = trainers.split(",") + elif isinstance(trainers, list): + endpoints = trainers + else: + raise ValueError('invalid trainers config: ' + str(trainers)) + + if len(endpoints) == 1: + raise ValueError('invalid trainer number in distributed: 1') + + if startup_program is None: + startup_program = default_startup_program() + + if main_program is None: + main_program = default_main_program() + + transpiler = None + if collective_mode == 'grad_allreduce': + transpiler = collective.GradAllReduce() + elif collective_mode == 'local_sgd': + transpiler = collective.LocalSGD() + else: + raise ValueError('invalid collective_mode: %s' % collective_mode) + + transpiler.transpile( + startup_program=startup_program, + main_program=main_program, + rank=trainer_id, + endpoints=endpoints, + current_endpoint=current_endpoint, + wait_port=wait_port) + def _get_all_remote_sparse_update_op(self, main_program): sparse_update_ops = [] sparse_update_op_types = ["lookup_table", "nce", "hierarchical_sigmoid"] @@ -395,6 +440,17 @@ class DistributeTranspiler(object): wait_port=self.config.wait_port) return + if self.config.mode == "collective": + self._transpile_collective( + collective_mode=self.config.collective_mode, + trainer_id=trainer_id, + trainers=trainers, + current_endpoint=current_endpoint, + startup_program=startup_program, + main_program=program, + wait_port=self.config.wait_port) + return + self.trainer_num = trainers self.sync_mode = sync_mode self.trainer_id = trainer_id -- GitLab