提交 b7128bac 编写于 作者: H HaoRen 提交者: Yi Liu

supports collective communicated training (#18175)

* fix prepare context redundant code problem, optimize executor by caching create_varaiables
test=develop

* supports collective training in executor

* make fetch_list runable with variables, add more unittest for use_program_cache
test=develop

* fix comment
test=develop

* use unique name for nccl_id

* supports output to stream in program_to_code

* insert sync_comm_stream before regularization; add skip_op_callstack capability in program_to_code

* set op role in collective training

* add collective op role

* remove orig file

* add build optimizer by strategy

* add collective strategy

* refine collective strategy

* add multi-process role maker

* refine strategy building factory so that we can easily plugin more strategy

* scale loss grad in collective sgd transpiler

* add support for distributed fc

* code format

* revert some features for dist fc

* add support for distributed fc training

* fix prepare context redundant code problem, optimize executor by caching create_varaiables
test=develop

* supports collective training in executor

* make fetch_list runable with variables, add more unittest for use_program_cache
test=develop

* use unique name for nccl_id

* supports output to stream in program_to_code

* insert sync_comm_stream before regularization; add skip_op_callstack capability in program_to_code

* set op role in collective training

* add collective op role

* fix comment
test=develop

* remove orig file

* add build optimizer by strategy

* add collective strategy

* refine collective strategy

* add multi-process role maker

* refine strategy building factory so that we can easily plugin more strategy

* scale loss grad in collective sgd transpiler

* add support for distributed fc

* code format

* revert some features for dist fc

* add support for distributed fc training

* test=develop
add collective op unittest standard

* test=develop
remove the test_collective directory

* test=develop
remove the test_collective directory

* remove slicegather test

* code format for reducescatter

* update attr of shard_index_op

* Modify macro nccl_helper

* remove test without distribute

* macro collective_helper

* marcro update

* test=develop
update support python3.5

* test=develop change gpu memory use to 0.1 when test

* test=develop
update ut equal func

* test=develop
set flags to 1.5

* test=develop fix pickle dumple  py35

* test=develop
fix divide in slice and add sync_comm_stream
update atol and rtol to 1e-05
rm shard_index op and test
modify read input from file to read from memory
remove origin_program in framework and add i/o in c_sync_calc_stream

* test=develop update unittest sync operator I/O
上级 9252e8fa
......@@ -179,7 +179,7 @@ if(WITH_DISTRIBUTE)
data_feed.cc device_worker.cc hogwild_worker.cc downpour_worker.cc
pull_dense_worker.cc section_worker.cc device_worker_factory.cc data_set.cc DEPS op_registry
device_context scope framework_proto trainer_desc_proto glog fs shell fleet_wrapper lodtensor_printer
lod_rank_table feed_fetch_method sendrecvop_rpc ${GLOB_DISTRIBUTE_DEPS}
lod_rank_table feed_fetch_method sendrecvop_rpc collective_helper ${GLOB_DISTRIBUTE_DEPS}
graph_to_program_pass variable_helper data_feed_proto ${NGRAPH_EXE_DEPS} timer)
set(DISTRIBUTE_COMPILE_FLAGS "-Wno-non-virtual-dtor -Wno-error=non-virtual-dtor -Wno-error=delete-non-virtual-dtor")
set_source_files_properties(executor.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
......
......@@ -13,6 +13,7 @@ limitations under the License. */
#include "paddle/fluid/framework/op_proto_maker.h"
#include <string>
#include <unordered_set>
#include <vector>
namespace paddle {
......@@ -73,6 +74,7 @@ void OpProtoAndCheckerMaker::operator()(proto::OpProto* proto,
static_cast<int>(OpRole::kBackward),
static_cast<int>(OpRole::kOptimize) |
static_cast<int>(OpRole::kLRSched),
static_cast<int>(OpRole::kCollective),
static_cast<int>(OpRole::kNotSpecified)})
.SetDefault(static_cast<int>(OpRole::kNotSpecified));
AddAttr<std::vector<std::string>>(OpRoleVarAttrName(),
......
......@@ -34,6 +34,9 @@ enum class OpRole {
kDist = 0x0008,
// Tag all learning rate scheduler operators.
kLRSched = 0x0010,
// Collective role is for all collective operators and other operators used
// for collective training
kCollective = 0x0020,
kLoss = 0x0100,
// The default value of op's role. This should be only used for unittests and
......
......@@ -21,6 +21,7 @@ add_subdirectory(jit)
if(WITH_DISTRIBUTE)
add_subdirectory(distributed)
add_subdirectory(distributed_ops)
add_subdirectory(collective)
endif()
add_subdirectory(reader)
......
include(operators)
set(COLLECTIVE_DEPS "")
if(WITH_GRPC)
set(COLLECTIVE_DEPS sendrecvop_rpc parameter_send parameter_recv communicator async_sparse_param_update_recorder grpc++_unsecure grpc_unsecure gpr cares zlib protobuf node)
else()
set(COLLECTIVE_DEPS sendrecvop_rpc parameter_send parameter_recv communicator async_sparse_param_update_recorder brpc leveldb snappystream snappy protobuf ssl crypto zlib node)
if(WITH_BRPC_RDMA)
find_library(IBVERBS_LIBRARY NAMES ibverbs)
ADD_LIBRARY(ibverbs SHARED IMPORTED GLOBAL)
SET_PROPERTY(TARGET ibverbs PROPERTY IMPORTED_LOCATION ${IBVERBS_LIBRARY})
find_library(RDMACM_LIBRARY NAMES rdmacm)
ADD_LIBRARY(rdmacm SHARED IMPORTED GLOBAL)
SET_PROPERTY(TARGET rdmacm PROPERTY IMPORTED_LOCATION ${RDMACM_LIBRARY})
set(COLLECTIVE_DEPS ${COLLECTIVE_DEPS} ibverbs rdmacm)
endif()
endif()
set(COLLECTIVE_COMPILE_FLAGS "-Wno-non-virtual-dtor -Wno-error=non-virtual-dtor -Wno-error=delete-non-virtual-dtor")
file(GLOB OPS RELATIVE "${CMAKE_CURRENT_SOURCE_DIR}" "*_op.cc")
list(REMOVE_DUPLICATES OPS)
foreach(src ${OPS})
set_source_files_properties(${src} PROPERTIES COMPILE_FLAGS ${COLLECTIVE_COMPILE_FLAGS})
endforeach()
register_operators(EXCLUDES c_gen_nccl_id_op DEPS ${COLLECTIVE_DEPS})
if(WITH_GPU AND NOT WIN32)
set(COLLECTIVE_DEPS ${COLLECTIVE_DEPS} nccl_common collective_helper)
op_library(c_gen_nccl_id_op DEPS ${COLLECTIVE_DEPS} nccl_common)
endif()
set(OPERATOR_DEPS ${OPERATOR_DEPS} ${COLLECTIVE_DEPS} PARENT_SCOPE)
set(GLOB_COLLECTIVE_DEPS ${COLLECTIVE_DEPS} CACHE INTERNAL "collective dependency")
/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/collective/c_allgather_op.h"
#include <future> // NOLINT
#include <memory>
#include <ostream>
namespace paddle {
namespace operators {
class CAllGatherOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext *ctx) const override {
PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) should not be null");
PADDLE_ENFORCE(ctx->HasOutput("Out"),
"Output(Out) of SyncFCGather op should not be null.");
int nranks = ctx->Attrs().Get<int>("nranks");
PADDLE_ENFORCE_GE(nranks, 2, "nranks should be >=2");
framework::DDim dim = ctx->GetInputDim("X");
dim[0] = dim[0] * nranks;
ctx->SetOutputDim("Out", dim);
}
};
class CAllGatherOpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() {
AddInput("X", "(Tensor) tensor to be allgather");
AddOutput("Out", "(Tensor) the allgather result");
AddAttr<int>("ring_id", "(int default 0) communication ring id.")
.SetDefault(0);
AddAttr<bool>(
"use_calc_stream",
"(bool default false) eject CUDA operations to calculation stream.")
.SetDefault(false);
AddAttr<int>("nranks",
"Total trainer count of the distributed training job");
AddComment(R"DOC(
***CAllGather Operator***
each rank receives the aggregation of data from all ranks in the order of the ranks
Call NCCL collective AllGather internally.https://docs.nvidia.com/deeplearning/sdk/nccl-developer-guide/docs/api/colls.html#c.ncclAllGather
)DOC");
}
};
class CAllGatherOpGradMaker : public framework::SingleGradOpDescMaker {
public:
using framework::SingleGradOpDescMaker::SingleGradOpDescMaker;
protected:
std::unique_ptr<framework::OpDesc> Apply() const override {
std::unique_ptr<framework::OpDesc> retv(new framework::OpDesc());
retv->SetType("c_reducescatter");
retv->SetInput("X", OutputGrad("Out"));
retv->SetOutput("Out", InputGrad("X"));
retv->SetAttrMap(Attrs());
return retv;
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
namespace plat = paddle::platform;
REGISTER_OPERATOR(c_allgather, ops::CAllGatherOp, ops::CAllGatherOpGradMaker,
ops::CAllGatherOpMaker);
REGISTER_OP_CPU_KERNEL(
c_allgather, ops::CAllGatherOpKernel<plat::CPUDeviceContext, float>,
ops::CAllGatherOpKernel<plat::CPUDeviceContext, double>,
ops::CAllGatherOpKernel<plat::CPUDeviceContext, int>,
ops::CAllGatherOpKernel<plat::CPUDeviceContext, int64_t>,
ops::CAllGatherOpKernel<plat::CPUDeviceContext, plat::float16>);
/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/collective/c_allgather_op.h"
namespace ops = paddle::operators;
namespace plat = paddle::platform;
REGISTER_OP_CUDA_KERNEL(
c_allgather, ops::CAllGatherOpKernel<plat::CUDADeviceContext, float>,
ops::CAllGatherOpKernel<plat::CUDADeviceContext, double>,
ops::CAllGatherOpKernel<plat::CUDADeviceContext, int>,
ops::CAllGatherOpKernel<plat::CUDADeviceContext, int64_t>,
ops::CAllGatherOpKernel<plat::CUDADeviceContext, plat::float16>);
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <algorithm>
#include <utility>
#include <vector>
#include "paddle/fluid/framework/data_type.h"
#include "paddle/fluid/framework/ddim.h"
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/op_registry.h"
#if defined(PADDLE_WITH_CUDA) && !defined(_WIN32)
#include "paddle/fluid/platform/collective_helper.h"
#include "paddle/fluid/platform/nccl_helper.h"
#endif
namespace paddle {
namespace operators {
template <typename DeviceContext, typename T>
class CAllGatherOpKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
auto place = ctx.GetPlace();
PADDLE_ENFORCE(is_gpu_place(place),
"CAllGatherOp can run on gpu place only for now.");
#if defined(PADDLE_WITH_CUDA) && !defined(_WIN32)
auto in = ctx.Input<framework::Tensor>("X");
auto out = ctx.Output<framework::Tensor>("Out");
ncclDataType_t dtype = platform::ToNCCLDataType(in->type());
int rid = ctx.Attr<int>("ring_id");
auto comm = platform::NCCLCommContext::Instance().Get(rid);
int nranks = comm->nranks();
framework::DDim out_dims = in->dims();
out_dims[0] *= nranks;
out->mutable_data<T>(out_dims, place);
int64_t send_numel = in->numel();
const T* send_buff = in->data<T>();
T* recv_buff = out->data<T>();
cudaStream_t stream = nullptr;
if (ctx.Attr<bool>("use_calc_stream")) {
auto dev_ctx = platform::DeviceContextPool::Instance().Get(place);
stream = static_cast<platform::CUDADeviceContext*>(dev_ctx)->stream();
} else {
stream = comm->stream();
}
PADDLE_ENFORCE(platform::dynload::ncclAllGather(
send_buff, recv_buff, send_numel, static_cast<ncclDataType_t>(dtype),
comm->comm(), stream));
#else
PADDLE_THROW("PaddlePaddle should compile with GPU.");
#endif
}
};
} // namespace operators
} // namespace paddle
/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <future> // NOLINT
#include <ostream>
#include "paddle/fluid/operators/collective/c_allreduce_op.h"
namespace paddle {
namespace operators {
class CAllReduceOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override {
ctx->SetOutputDim("Out", ctx->GetInputDim("X"));
}
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType(ctx.Input<framework::Tensor>("X")->type(),
ctx.GetPlace());
}
};
class CAllReduceOpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() {
AddInput("X", "(Tensor), tensor to be allreduced.");
AddOutput("Out", "(Tensor) the allreduced result.");
AddAttr<int>("reduce_type", "(int default 0) determin the reduce type.")
.SetDefault(0);
AddAttr<int>("ring_id", "(int default 0) communication ring id.")
.SetDefault(0);
AddAttr<bool>(
"use_calc_stream",
"(bool default false) eject CUDA operations to calculation stream.")
.SetDefault(false);
AddComment(R"DOC(
***CAllReduce Operator***
Call NCCL collective AllReduce internally. Note that this op must be used when one
thread is managing one GPU device.
For speed reasons, reduce_type should be an integer:
0: sum
1: prod
2: max
3: min
If input and output are the same variable, in-place allreduce will be used.
)DOC");
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
namespace plat = paddle::platform;
REGISTER_OP_WITHOUT_GRADIENT(c_allreduce, ops::CAllReduceOp,
ops::CAllReduceOpMaker);
REGISTER_OP_CPU_KERNEL(
c_allreduce, ops::CAllReduceOpKernel<plat::CPUDeviceContext, float>,
ops::CAllReduceOpKernel<plat::CPUDeviceContext, double>,
ops::CAllReduceOpKernel<plat::CPUDeviceContext, int>,
ops::CAllReduceOpKernel<plat::CPUDeviceContext, int64_t>,
ops::CAllReduceOpKernel<plat::CPUDeviceContext, plat::float16>);
/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/collective/c_allreduce_op.h"
namespace ops = paddle::operators;
namespace plat = paddle::platform;
REGISTER_OP_CUDA_KERNEL(
c_allreduce, ops::CAllReduceOpKernel<plat::CUDADeviceContext, float>,
ops::CAllReduceOpKernel<plat::CUDADeviceContext, double>,
ops::CAllReduceOpKernel<plat::CUDADeviceContext, int>,
ops::CAllReduceOpKernel<plat::CUDADeviceContext, int64_t>,
ops::CAllReduceOpKernel<plat::CUDADeviceContext, plat::float16>);
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <algorithm>
#include <utility>
#include <vector>
#include "paddle/fluid/framework/data_type.h"
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/op_registry.h"
#if defined(PADDLE_WITH_CUDA) && !defined(_WIN32)
#include "paddle/fluid/platform/collective_helper.h"
#include "paddle/fluid/platform/nccl_helper.h"
#endif
namespace paddle {
namespace operators {
template <typename DeviceContext, typename T>
class CAllReduceOpKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
auto place = ctx.GetPlace();
PADDLE_ENFORCE(is_gpu_place(place),
"CAllReduce op can run on gpu place only for now.");
#if defined(PADDLE_WITH_CUDA) && !defined(_WIN32)
auto in = ctx.Input<framework::Tensor>("X");
auto out = ctx.Output<framework::Tensor>("Out");
ncclDataType_t dtype = platform::ToNCCLDataType(in->type());
int64_t numel = in->numel();
const void* sendbuff = in->data<void>();
out->Resize(in->dims());
void* recvbuff = out->mutable_data<T>(place);
int rid = ctx.Attr<int>("ring_id");
auto comm = platform::NCCLCommContext::Instance().Get(rid);
int reduce_type = ctx.Attr<int>("reduce_type");
ncclRedOp_t red_type = ncclSum;
switch (reduce_type) {
case 0:
red_type = ncclSum;
break;
case 1:
red_type = ncclProd;
break;
case 2:
red_type = ncclMax;
break;
case 3:
red_type = ncclMin;
break;
}
cudaStream_t stream = nullptr;
if (ctx.Attr<bool>("use_calc_stream")) {
auto dev_ctx = platform::DeviceContextPool::Instance().Get(place);
stream = static_cast<platform::CUDADeviceContext*>(dev_ctx)->stream();
} else {
stream = comm->stream();
}
PADDLE_ENFORCE(platform::dynload::ncclAllReduce(
sendbuff, recvbuff, numel, dtype, red_type, comm->comm(), stream));
#else
PADDLE_THROW("PaddlePaddle should compile with GPU.");
#endif
}
};
} // namespace operators
} // namespace paddle
/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <future> // NOLINT
#include <ostream>
#include "paddle/fluid/operators/collective/c_broadcast_op.h"
namespace paddle {
namespace operators {
class CBroadcastOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override {
ctx->SetOutputDim("Out", ctx->GetInputDim("X"));
}
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType(ctx.Input<framework::Tensor>("X")->type(),
ctx.GetPlace());
}
};
class CBroadcastOpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() {
AddInput("X", "(Tensor) tensor to be broadcasted.");
AddOutput("Out", "(Tensor) the result of broadcast.");
AddAttr<int>("ring_id", "(int default 0) nccl communication ring id.")
.SetDefault(0);
AddAttr<int>("root", "(int default 0) root id for broadcasting.")
.SetDefault(0);
AddAttr<bool>(
"use_calc_stream",
"(bool default false) eject CUDA operations to calculation stream.")
.SetDefault(false);
AddComment(R"DOC(
***CBroadcast Operator***
Call ncclBcast internally.
)DOC");
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
namespace plat = paddle::platform;
REGISTER_OP_WITHOUT_GRADIENT(c_broadcast, ops::CBroadcastOp,
ops::CBroadcastOpMaker);
REGISTER_OP_CPU_KERNEL(
c_broadcast, ops::CBroadcastOpKernel<plat::CPUDeviceContext, float>,
ops::CBroadcastOpKernel<plat::CPUDeviceContext, double>,
ops::CBroadcastOpKernel<plat::CPUDeviceContext, int>,
ops::CBroadcastOpKernel<plat::CPUDeviceContext, int64_t>,
ops::CBroadcastOpKernel<plat::CPUDeviceContext, plat::float16>);
/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/collective/c_broadcast_op.h"
namespace ops = paddle::operators;
namespace plat = paddle::platform;
REGISTER_OP_CUDA_KERNEL(
c_broadcast, ops::CBroadcastOpKernel<plat::CUDADeviceContext, float>,
ops::CBroadcastOpKernel<plat::CUDADeviceContext, double>,
ops::CBroadcastOpKernel<plat::CUDADeviceContext, int>,
ops::CBroadcastOpKernel<plat::CUDADeviceContext, int64_t>,
ops::CBroadcastOpKernel<plat::CUDADeviceContext, plat::float16>);
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <algorithm>
#include <utility>
#include <vector>
#include "paddle/fluid/framework/data_type.h"
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/op_registry.h"
#if defined(PADDLE_WITH_CUDA) && !defined(_WIN32)
#include "paddle/fluid/platform/collective_helper.h"
#include "paddle/fluid/platform/nccl_helper.h"
#endif
namespace paddle {
namespace operators {
template <typename DeviceContext, typename T>
class CBroadcastOpKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
auto place = ctx.GetPlace();
PADDLE_ENFORCE(is_gpu_place(place),
"CBroadcastOp can run on gpu place only for now.");
#if defined(PADDLE_WITH_CUDA) && !defined(_WIN32)
auto x = ctx.Input<framework::LoDTensor>("X");
auto out = ctx.Output<framework::LoDTensor>("Out");
int numel = x->numel();
ncclDataType_t dtype = platform::ToNCCLDataType(x->type());
int rid = ctx.Attr<int>("ring_id");
auto comm = platform::NCCLCommContext::Instance().Get(rid);
cudaStream_t stream = nullptr;
if (ctx.Attr<bool>("use_calc_stream")) {
auto dev_ctx = platform::DeviceContextPool::Instance().Get(place);
stream = static_cast<platform::CUDADeviceContext*>(dev_ctx)->stream();
} else {
stream = comm->stream();
}
int root = ctx.Attr<int>("root");
int nranks = comm->nranks();
PADDLE_ENFORCE(root >= 0 && root < nranks,
"Expected root in range of [0,%d),but get %d", nranks, root);
if (root == comm->rank()) {
PADDLE_ENFORCE(platform::dynload::ncclBcast(
reinterpret_cast<void*>(const_cast<T*>(x->data<T>())), numel, dtype,
root, comm->comm(), stream));
VLOG(3) << "rank " << comm->rank() << " invoke Bcast. sent "
<< x->numel();
if (out != x) {
// TODO(liuyi05): check inplace
framework::TensorCopy(
*static_cast<const framework::Tensor*>(x), place,
*platform::DeviceContextPool::Instance().Get(place),
static_cast<framework::Tensor*>(out));
}
} else {
PADDLE_ENFORCE(platform::dynload::ncclBcast(out->mutable_data<T>(place),
numel, dtype, root,
comm->comm(), stream));
VLOG(3) << "rank " << comm->rank() << " invoke Bcast. recieved "
<< framework::product(out->dims());
}
out->Resize(x->dims());
out->set_lod(x->lod());
#else
PADDLE_THROW("PaddlePaddle should compile with GPU.");
#endif
}
};
} // namespace operators
} // namespace paddle
/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#if defined(PADDLE_WITH_CUDA) && !defined(_WIN32)
#include <nccl.h>
#endif
#include <stdint.h>
#include <ostream>
#include <string>
#include "paddle/fluid/framework/executor.h"
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/threadpool.h"
#include "paddle/fluid/operators/distributed/distributed.h"
#include "paddle/fluid/operators/distributed/request_handler_impl.h"
#if defined(PADDLE_WITH_CUDA) && !defined(_WIN32)
#include "paddle/fluid/platform/collective_helper.h"
#include "paddle/fluid/platform/nccl_helper.h"
#endif
namespace paddle {
namespace operators {
class CCommInitOp : public framework::OperatorBase {
public:
CCommInitOp(const std::string& type, const framework::VariableNameMap& inputs,
const framework::VariableNameMap& outputs,
const framework::AttributeMap& attrs)
: OperatorBase(type, inputs, outputs, attrs) {}
void RunImpl(const framework::Scope& scope,
const platform::Place& place) const override {
PADDLE_ENFORCE(is_gpu_place(place),
"CCommInitOp can run on gpu place only.");
auto var = scope.FindVar(Input("X"));
PADDLE_ENFORCE_NOT_NULL(var);
#if defined(PADDLE_WITH_CUDA) && !defined(_WIN32)
ncclUniqueId* nccl_id = var->GetMutable<ncclUniqueId>();
int nranks = Attr<int>("nranks");
int rank_id = Attr<int>("rank");
int rid = Attr<int>("ring_id");
platform::NCCLCommContext::Instance().CreateNCCLComm(
nccl_id, nranks, rank_id, boost::get<platform::CUDAPlace>(place).device,
rid);
#else
PADDLE_THROW("PaddlePaddle should compile with GPU.");
#endif
}
};
class CCommInitOpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() override {
AddInput("X", "Raw variable contains a NCCL UniqueId instaces.");
AddComment(R"DOC(
CCommInit operator
Initialize collective communicatoin context within this trainer
)DOC");
AddAttr<int>("nranks", "(int) The number of ranks of distributed trainers");
AddAttr<int>("rank",
"(int) The rank of the trainer in distributed training.");
AddAttr<int>("ring_id", "(int default 0) user specified ring id")
.SetDefault(0);
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OPERATOR(c_comm_init, ops::CCommInitOp, ops::CCommInitOpMaker);
/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#if defined(PADDLE_WITH_CUDA) && !defined(_WIN32)
#include <nccl.h>
#endif
#include <stdint.h>
#include <ostream>
#include <string>
#include "paddle/fluid/framework/executor.h"
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/threadpool.h"
#include "paddle/fluid/operators/distributed/distributed.h"
#include "paddle/fluid/operators/distributed/request_handler_impl.h"
#if defined(PADDLE_WITH_CUDA) && !defined(_WIN32)
#include "paddle/fluid/platform/nccl_helper.h"
#endif
namespace paddle {
namespace operators {
class CGenNCCLIdOp : public framework::OperatorBase {
public:
CGenNCCLIdOp(const std::string& type,
const framework::VariableNameMap& inputs,
const framework::VariableNameMap& outputs,
const framework::AttributeMap& attrs)
: OperatorBase(type, inputs, outputs, attrs) {}
void RunImpl(const framework::Scope& scope,
const platform::Place& dev_place) const override {
platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance();
// put nccl id in CPUPlace
auto& dev_ctx = *pool.Get(platform::CPUPlace());
int rank = Attr<int>("rank");
framework::Scope& local_scope = scope.NewScope();
if (rank == 0) {
GenerateAndSend(&local_scope, dev_ctx);
} else {
GetIdByServer(&local_scope, dev_ctx);
}
scope.DeleteScope(&local_scope);
}
private:
void GenerateAndSend(framework::Scope* scope,
const platform::DeviceContext& dev_ctx) const {
std::string var_name = Output("Out");
auto var = scope->FindVar(var_name);
PADDLE_ENFORCE_NOT_NULL(var);
auto id = var->GetMutable<ncclUniqueId>();
PADDLE_ENFORCE(platform::dynload::ncclGetUniqueId(id));
std::vector<std::string> endpoint_list =
Attr<std::vector<std::string>>("other_endpoints");
distributed::RPCClient* client =
distributed::RPCClient::GetInstance<RPCCLIENT_T>(0);
for (auto& ep : endpoint_list) {
VLOG(3) << "sending nccl id to " << ep;
client->AsyncSendVar(ep, dev_ctx, *scope, var_name);
}
client->Wait();
for (auto& ep : endpoint_list) {
client->AsyncSendBatchBarrier(ep);
}
client->Wait();
VLOG(3) << "sending completed...";
}
void GetIdByServer(framework::Scope* scope,
const platform::DeviceContext& dev_ctx) const {
std::string endpoint = Attr<std::string>("endpoint");
// NOTE: Can not use unique_ptr here because the default
// deleter will call GRPC Server's base class's dtor and
// that will cause a wired crash.
distributed::RequestSendHandler rpc_h(true);
std::unique_ptr<distributed::RPCServer> rpc_service(
new RPCSERVER_T(endpoint, 1));
rpc_service->RegisterRPC(distributed::kRequestSend, &rpc_h);
rpc_h.SetRPCServer(rpc_service.get());
framework::ProgramDesc empty_program;
framework::Executor executor(dev_ctx.GetPlace());
rpc_h.SetScope(scope);
rpc_h.SetDevCtx(&dev_ctx);
rpc_h.SetProgram(&empty_program);
rpc_h.SetExecutor(&executor);
std::thread server_thread(
std::bind(&distributed::RPCServer::StartServer, rpc_service.get()));
rpc_service->SetCond(distributed::kRequestSend);
VLOG(3) << "start getting nccl id from trainer 0...";
rpc_service->WaitBarrier(distributed::kRequestSend);
VLOG(3) << "got nccl id and stop server...";
rpc_service->ShutDown();
VLOG(3) << "rpc server stopped";
server_thread.join();
}
};
class CGenNCCLIdOpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() override {
AddOutput("Out", "Raw variable contains a NCCL UniqueId instaces.");
AddComment(R"DOC(
CGenNCCLId operator
For trainer 0: generate a new UniqueId and send it to all the other trainers.
For trainer 1~n: start a gRPC server to get the UniqueId, once got, stop the server.
)DOC");
AddAttr<std::string>("endpoint",
"(string), e.g. 127.0.0.1:6175 "
"current listen endpoint");
AddAttr<std::vector<std::string>>(
"other_endpoints",
"['trainer1_ip:port', 'trainer2_ip:port', ...] "
"list of other trainer endpoints")
.SetDefault({});
AddAttr<int>("rank",
"(int default 0) "
"The rank of the trainer in distributed training.")
.SetDefault(0);
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OPERATOR(c_gen_nccl_id, ops::CGenNCCLIdOp, ops::CGenNCCLIdOpMaker);
/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/collective/c_reducescatter_op.h"
#include <future> // NOLINT
#include <memory>
#include <ostream>
namespace paddle {
namespace operators {
class CReduceScatterOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext *ctx) const override {
PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) should not be null");
PADDLE_ENFORCE(ctx->HasOutput("Out"), "Output(Out) should not be null.");
int nranks = ctx->Attrs().Get<int>("nranks");
framework::DDim dim = ctx->GetInputDim("X");
if (dim[0] > 0 || dim[0] < -1) {
PADDLE_ENFORCE(dim[0] % nranks == 0,
"dim[0] (%d) is not divisible by nranks(%d)", dim[0],
nranks);
dim[0] /= nranks;
}
ctx->SetOutputDim("Out", dim);
}
};
class CReduceScatterOpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() {
AddInput("X", "(Tensor) tensor to be allgather");
AddOutput("Out", "(Tensor) the allgather result");
AddAttr<int>("ring_id", "(int default 0) communication ring id.")
.SetDefault(0);
AddAttr<int>("nranks",
"Total trainer count of the distributed training job")
.SetDefault(1);
AddAttr<bool>(
"use_calc_stream",
"(bool default false) eject CUDA operations to calculation stream.")
.SetDefault(false);
AddComment(R"DOC(
***CReduceScatter Operator***
Call NCCL collective ReduceScatter internally.
)DOC");
}
};
class CReduceScatterOpGradMaker : public framework::SingleGradOpDescMaker {
public:
using framework::SingleGradOpDescMaker::SingleGradOpDescMaker;
protected:
std::unique_ptr<framework::OpDesc> Apply() const override {
std::unique_ptr<framework::OpDesc> retv(new framework::OpDesc());
retv->SetType("c_allgather");
retv->SetInput("X", OutputGrad("Out"));
retv->SetOutput("Out", InputGrad("X"));
retv->SetAttrMap(Attrs());
return retv;
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
namespace plat = paddle::platform;
REGISTER_OPERATOR(c_reducescatter, ops::CReduceScatterOp,
ops::CReduceScatterOpMaker);
REGISTER_OP_CPU_KERNEL(
c_reducescatter, ops::CReduceScatterOpKernel<plat::CPUDeviceContext, float>,
ops::CReduceScatterOpKernel<plat::CPUDeviceContext, double>,
ops::CReduceScatterOpKernel<plat::CPUDeviceContext, int>,
ops::CReduceScatterOpKernel<plat::CPUDeviceContext, int64_t>,
ops::CReduceScatterOpKernel<plat::CPUDeviceContext, plat::float16>);
/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/collective/c_reducescatter_op.h"
namespace ops = paddle::operators;
namespace plat = paddle::platform;
REGISTER_OP_CUDA_KERNEL(
c_reducescatter,
ops::CReduceScatterOpKernel<plat::CUDADeviceContext, float>,
ops::CReduceScatterOpKernel<plat::CUDADeviceContext, double>,
ops::CReduceScatterOpKernel<plat::CUDADeviceContext, int>,
ops::CReduceScatterOpKernel<plat::CUDADeviceContext, int64_t>,
ops::CReduceScatterOpKernel<plat::CUDADeviceContext, plat::float16>);
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <algorithm>
#include <utility>
#include <vector>
#include "paddle/fluid/framework/data_type.h"
#include "paddle/fluid/framework/ddim.h"
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/op_registry.h"
#if defined(PADDLE_WITH_CUDA) && !defined(_WIN32)
#include "paddle/fluid/platform/collective_helper.h"
#include "paddle/fluid/platform/nccl_helper.h"
#endif
namespace paddle {
namespace operators {
template <typename DeviceContext, typename T>
class CReduceScatterOpKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
auto place = ctx.GetPlace();
PADDLE_ENFORCE(is_gpu_place(place),
"CAllReduce op can run on gpu place only for now.");
#if defined(PADDLE_WITH_CUDA) && !defined(_WIN32)
auto in = ctx.Input<framework::Tensor>("X");
auto out = ctx.Output<framework::Tensor>("Out");
int rid = ctx.Attr<int>("ring_id");
auto comm = platform::NCCLCommContext::Instance().Get(rid);
int nranks = comm->nranks();
auto out_dims = in->dims();
out_dims[0] = out_dims[0] / nranks;
out->mutable_data<T>(out_dims, place);
int64_t recv_numel = in->numel() / nranks;
const T* send_buff = in->data<T>();
T* recv_buff = out->data<T>();
int dtype = platform::ToNCCLDataType(in->type());
cudaStream_t stream = nullptr;
if (ctx.Attr<bool>("use_calc_stream")) {
auto dev_ctx = platform::DeviceContextPool::Instance().Get(place);
stream = static_cast<platform::CUDADeviceContext*>(dev_ctx)->stream();
} else {
stream = comm->stream();
}
PADDLE_ENFORCE(platform::dynload::ncclReduceScatter(
send_buff, recv_buff, recv_numel, static_cast<ncclDataType_t>(dtype),
ncclSum, comm->comm(), stream));
#else
PADDLE_THROW("PaddlePaddle should compile with GPU.");
#endif
}
};
} // namespace operators
} // namespace paddle
/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#if defined(PADDLE_WITH_CUDA) && !defined(_WIN32)
#include <nccl.h>
#endif
#include <stdint.h>
#include <ostream>
#include <string>
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/op_registry.h"
#if defined(PADDLE_WITH_CUDA) && !defined(_WIN32)
#include "paddle/fluid/platform/collective_helper.h"
#endif
namespace paddle {
namespace operators {
class CSyncCalcStreamOp : public framework::OperatorBase {
public:
CSyncCalcStreamOp(const std::string& type,
const framework::VariableNameMap& inputs,
const framework::VariableNameMap& outputs,
const framework::AttributeMap& attrs)
: OperatorBase(type, inputs, outputs, attrs) {}
void RunImpl(const framework::Scope& scope,
const platform::Place& place) const override {
PADDLE_ENFORCE(is_gpu_place(place),
"Sync stream op can run on gpu place only for now.");
#if defined(PADDLE_WITH_CUDA) && !defined(_WIN32)
auto dev_ctx = static_cast<platform::CUDADeviceContext*>(
platform::DeviceContextPool::Instance().Get(place));
cudaError_t e_sync = cudaStreamSynchronize(dev_ctx->stream());
if (e_sync != 0) {
LOG(FATAL) << "Fail to sync cuda stream: " << cudaGetErrorString(e_sync);
}
#else
PADDLE_THROW("PaddlePaddle should compile with GPU.");
#endif
}
};
class CSyncCalcStreamOpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() {
AddInput("X", "(Tensor) Dependency of last param need to sync");
AddOutput("Out", "(Tensor) Dependency of last param need to sync");
AddComment(R"DOC(
***Sync Operator***
Call cuda stream synchronize.
)DOC");
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OPERATOR(c_sync_calc_stream, ops::CSyncCalcStreamOp,
ops::CSyncCalcStreamOpMaker);
/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#if defined(PADDLE_WITH_CUDA) && !defined(_WIN32)
#include <nccl.h>
#endif
#include <stdint.h>
#include <ostream>
#include <string>
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/op_registry.h"
#if defined(PADDLE_WITH_CUDA) && !defined(_WIN32)
#include "paddle/fluid/platform/collective_helper.h"
#include "paddle/fluid/platform/nccl_helper.h"
#endif
namespace paddle {
namespace operators {
class CSyncCommStreamOp : public framework::OperatorBase {
public:
CSyncCommStreamOp(const std::string& type,
const framework::VariableNameMap& inputs,
const framework::VariableNameMap& outputs,
const framework::AttributeMap& attrs)
: OperatorBase(type, inputs, outputs, attrs) {}
void RunImpl(const framework::Scope& scope,
const platform::Place& place) const override {
PADDLE_ENFORCE(is_gpu_place(place),
"Sync stream op can run on gpu place only for now.");
#if defined(PADDLE_WITH_CUDA) && !defined(_WIN32)
int ring_id = Attr<int>("ring_id");
auto stream = platform::NCCLCommContext::Instance().Get(ring_id)->stream();
cudaError_t e_sync = cudaStreamSynchronize(stream);
if (e_sync != 0) {
LOG(FATAL) << "Fail to sync nccl stream: " << cudaGetErrorString(e_sync);
}
#else
PADDLE_THROW("PaddlePaddle should compile with GPU.");
#endif
}
};
class CSyncCommStreamOpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() {
AddInput("X", "(Tensor) Dependency of last param need to sync");
AddOutput("Out", "(Tensor) Dependency of last param need to sync");
AddAttr<int>("ring_id", "(int default 0) ring id.").SetDefault(0);
AddComment(R"DOC(
***Sync Operator***
Call nccl stream synchronize.
)DOC");
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OPERATOR(c_sync_comm_stream, ops::CSyncCommStreamOp,
ops::CSyncCommStreamOpMaker);
......@@ -21,7 +21,6 @@ endif()
set(DISTRIBUTE_COMPILE_FLAGS "-Wno-non-virtual-dtor -Wno-error=non-virtual-dtor -Wno-error=delete-non-virtual-dtor")
file(GLOB OPS RELATIVE "${CMAKE_CURRENT_SOURCE_DIR}" "*_op.cc")
list(REMOVE_DUPLICATES OPS)
......
......@@ -86,6 +86,11 @@ class OneHotOpMaker : public framework::OpProtoAndCheckerMaker {
"An integer to specify the data type of one-hot "
"vector. The default value is FP32.")
.SetDefault(paddle::framework::proto::VarType::FP32);
AddAttr<bool>("allow_out_of_range",
"If it is set true and the input data is out of range, "
"the output tensor will be filled zeros. The default value "
"is false.")
.SetDefault(false);
AddComment(R"DOC(
One Hot Operator. This operator creates the one-hot representations for input
index values. The following example will help to explain the function of this
......
......@@ -24,7 +24,7 @@ template <typename InT, typename OutT>
__global__ void FillOutputKernel(const InT* p_in_data, OutT* p_out_data,
const int64_t numel, const int depth) {
int idx = blockIdx.x * blockDim.x + threadIdx.x;
if (idx < numel) {
if (idx < numel && p_in_data[idx] >= 0 && p_in_data[idx] < depth) {
*(p_out_data + (idx * depth) + p_in_data[idx]) = 1.0;
}
}
......
......@@ -25,10 +25,16 @@ struct OneHotOpFunctor {
framework::LoDTensor* out_;
int depth_;
const DeviceContext& ctx_;
bool allow_out_of_range_;
OneHotOpFunctor(const framework::LoDTensor* in, framework::LoDTensor* out,
int depth, const DeviceContext& ctx)
: in_(in), out_(out), depth_(depth), ctx_(ctx) {}
int depth, const DeviceContext& ctx,
bool allow_out_of_range = false)
: in_(in),
out_(out),
depth_(depth),
ctx_(ctx),
allow_out_of_range_(allow_out_of_range) {}
template <typename OutT>
void apply() const {
......@@ -37,13 +43,21 @@ struct OneHotOpFunctor {
auto* p_out_data = out_->mutable_data<OutT>(ctx_.GetPlace());
math::set_constant(ctx_, out_, 0.0);
for (int i = 0; i < numel; ++i) {
PADDLE_ENFORCE_GE(p_in_data[i], 0,
"Illegal index value, should be at least 0.");
PADDLE_ENFORCE_LT(p_in_data[i], depth_,
"Illegal index value, should be less than depth (%d).",
depth_);
*(p_out_data + i * depth_ + p_in_data[i]) = 1.0;
if (allow_out_of_range_) {
for (int i = 0; i < numel; ++i) {
if (p_in_data[i] >= 0 && p_in_data[i] < depth_) {
*(p_out_data + i * depth_ + p_in_data[i]) = 1.0;
}
}
} else {
for (int i = 0; i < numel; ++i) {
PADDLE_ENFORCE_GE(p_in_data[i], 0,
"Illegal index value, should be at least 0.");
PADDLE_ENFORCE_LT(
p_in_data[i], depth_,
"Illegal index value, should be less than depth (%d).", depth_);
*(p_out_data + i * depth_ + p_in_data[i]) = 1.0;
}
}
}
};
......@@ -57,6 +71,7 @@ class OneHotKernel : public framework::OpKernel<T> {
auto* in = context.Input<LoDTensor>("X");
auto* out = context.Output<LoDTensor>("Out");
int depth = context.Attr<int>("depth");
bool allow_out_of_range = context.Attr<bool>("allow_out_of_range");
if (context.HasInput("depth_tensor")) {
auto* depth_tensor = context.Input<Tensor>("depth_tensor");
auto* depth_data = depth_tensor->data<int32_t>();
......@@ -71,7 +86,8 @@ class OneHotKernel : public framework::OpKernel<T> {
static_cast<framework::proto::VarType::Type>(
context.Attr<int>("dtype")),
OneHotOpFunctor<DeviceContext, T>(
in, out, depth, context.template device_context<DeviceContext>()));
in, out, depth, context.template device_context<DeviceContext>(),
allow_out_of_range));
}
};
......
......@@ -74,6 +74,10 @@ cc_library(device_context SRCS device_context.cc init.cc DEPS simple_threadpool
place eigen3 stringpiece cpu_helper cpu_info framework_proto ${GPU_CTX_DEPS} ${MKLDNN_CTX_DEPS}
temp_allocator ${dgc_deps})
if (WITH_DISTRIBUTE)
cc_library(collective_helper SRCS collective_helper.cc DEPS framework_proto)
endif()
if(WIN32)
if(WITH_GPU AND NOT WITH_DSO)
get_property(cuda_modules GLOBAL PROPERTY CUDA_MODULES)
......@@ -97,7 +101,7 @@ cc_test(lodtensor_printer_test SRCS lodtensor_printer_test.cc DEPS lodtensor_pri
cc_library(device_tracer SRCS device_tracer.cc DEPS boost profiler_proto framework_proto ${GPU_CTX_DEPS})
if(WITH_GPU)
nv_library(profiler SRCS profiler.cc profiler.cu DEPS device_tracer gpu_info enforce)
nv_library(profiler SRCS profiler.cc profiler.cu DEPS device_tracer gpu_info enforce)
else()
cc_library(profiler SRCS profiler.cc DEPS device_tracer enforce)
endif()
......
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// #ifndef _WIN32
#if defined(PADDLE_WITH_CUDA) && !defined(_WIN32)
#include "paddle/fluid/platform/collective_helper.h"
#include <functional>
#include "paddle/fluid/platform/dynload/nccl.h"
namespace paddle {
namespace platform {
class NCCLCommImpl : public NCCLComm {
public:
void set_ring_id(int ring_id) { ring_id_ = ring_id; }
int ring_id() const override { return ring_id_; }
void set_nranks(int nranks) { nranks_ = nranks; }
int nranks() const override { return nranks_; }
void set_rank(int rank) { rank_ = rank; }
int rank() const override { return rank_; }
void set_local_rank(int local_rank) { local_rank_ = local_rank; }
int local_rank() const override { return local_rank_; }
void set_comm(ncclComm_t comm) { comm_ = comm; }
ncclComm_t comm() const override { return comm_; }
void set_dev_ctx(CUDADeviceContext* dev_ctx) { dev_ctx_ = dev_ctx; }
CUDADeviceContext* DevCtx() const override { return dev_ctx_; }
cudaStream_t stream() const override { return dev_ctx_->stream(); }
private:
int ring_id_;
int nranks_;
int rank_;
int local_rank_;
ncclComm_t comm_;
CUDADeviceContext* dev_ctx_;
};
// NOTE: not thread-safe
NCCLComm* NCCLCommContext::CreateNCCLComm(ncclUniqueId* nccl_id, int nranks,
int rank, int dev_id, int ring_id) {
PADDLE_ENFORCE_NOT_NULL(nccl_id);
PADDLE_ENFORCE_GT(nranks, 1);
PADDLE_ENFORCE(rank >= 0 && rank < nranks,
"Expected rank id range [0, %d), but get %d", nranks, rank);
PADDLE_ENFORCE_GE(dev_id, 0);
if (dev_ctx_map_.count(dev_id) == 0) {
dev_ctx_map_.emplace(dev_id, std::unique_ptr<CUDADeviceContext>(
new CUDADeviceContext(CUDAPlace(dev_id))));
}
ncclComm_t comm = nullptr;
PADDLE_ENFORCE(cudaSetDevice(dev_id));
PADDLE_ENFORCE(
platform::dynload::ncclCommInitRank(&comm, nranks, *nccl_id, rank));
NCCLCommImpl* communicator = new NCCLCommImpl;
communicator->set_ring_id(ring_id);
communicator->set_nranks(nranks);
communicator->set_rank(rank);
communicator->set_local_rank(dev_id);
communicator->set_comm(comm);
communicator->set_dev_ctx(dev_ctx_map_.at(dev_id).get());
comm_map_.emplace(ring_id, std::unique_ptr<NCCLComm>(communicator));
VLOG(0) << "nccl communicator of rank " << rank << " in ring " << ring_id
<< " has been created";
return comm_map_.at(ring_id).get();
}
NCCLCommContext::~NCCLCommContext() {
for (auto& p : comm_map_) {
PADDLE_ENFORCE(platform::dynload::ncclCommDestroy(p.second->comm()));
}
}
} // namespace platform
} // namespace paddle
#endif
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// #ifndef _WIN32
#if defined(PADDLE_WITH_CUDA) && !defined(_WIN32)
#pragma once
#include <memory>
#include <string>
#include <unordered_map>
#include <vector>
#include "boost/variant.hpp"
#include "paddle/fluid/framework/data_type.h"
#include "paddle/fluid/platform/device_context.h"
#include "paddle/fluid/platform/enforce.h"
namespace paddle {
namespace platform {
// In order to apply hierarchical communication with NCCL, we need
// a communication ring contains NCCL communicators associated to a global
// ncclUniqueId. E.g. for a hierarchical case,
//
// 11 - 12 21 - 22
// | | | |
// 13 - 14 - 23 - 24
// | |
// 31 - 32 - 41 - 42
// | | | |
// 33 - 34 43 - 44
//
// we group (14,23,32,41) as the top, and (11,12,13,14), (21,22,23,24),
// (31,32,33,34), (41,42,43,44) as bottoms respectively.
//
// We could also use a single communication ring for the flatten case
//
// The NCCLComm instance is created and reversed in the NCCLCommContext
// singleton with a global user specified group id.
class NCCLComm {
public:
virtual int ring_id() const = 0;
virtual int nranks() const = 0;
virtual int rank() const = 0;
virtual int local_rank() const = 0;
virtual ncclComm_t comm() const = 0;
virtual cudaStream_t stream() const = 0;
virtual CUDADeviceContext* DevCtx() const = 0;
virtual ~NCCLComm() = default;
};
// a singleton NCCL communicator context reserves communication ring ids
// Assume multiprocessing mode
class NCCLCommContext {
public:
static NCCLCommContext& Instance() {
static NCCLCommContext comm_ctx;
return comm_ctx;
}
~NCCLCommContext();
NCCLComm* CreateNCCLComm(ncclUniqueId* nccl_id, int nranks, int rank,
int dev_id, int ring_id = 0);
CUDADeviceContext* DevCtx(int dev_id) const {
PADDLE_ENFORCE(dev_ctx_map_.count(dev_id),
"CUDADeviceContext at device %d has not been initialized");
return dev_ctx_map_.at(dev_id).get();
}
CUDADeviceContext* DevCtx(platform::Place p) const {
return DevCtx(boost::get<CUDAPlace>(p).device);
}
// retrieve a communicator by the ring id
NCCLComm* Get(int ring_id) const {
PADDLE_ENFORCE(comm_map_.count(ring_id),
"comunicator in ring id %d has not been initialized",
ring_id);
return comm_map_.at(ring_id).get();
}
private:
// ring id to NCCLComm
std::unordered_map<int, std::unique_ptr<NCCLComm>> comm_map_;
// device id to CUDADeviceContext
std::unordered_map<int, std::unique_ptr<CUDADeviceContext>> dev_ctx_map_;
NCCLCommContext() = default;
NCCLCommContext(const NCCLCommContext& other) = delete;
NCCLCommContext& operator=(const NCCLCommContext& other) = delete;
};
} // namespace platform
} // namespace paddle
#endif
......@@ -66,6 +66,7 @@ extern void* nccl_dso_handle;
__macro(ncclGroupStart); \
__macro(ncclGroupEnd); \
__macro(ncclReduce); \
__macro(ncclReduceScatter); \
__macro(ncclGetErrorString);
NCCL_RAND_ROUTINE_EACH(DECLARE_DYNAMIC_LOAD_NCCL_WRAP)
......
......@@ -46,6 +46,7 @@ void BindConstValue(pybind11::module* m) {
.value("Loss", framework::OpRole::kLoss)
.value("RPC", framework::OpRole::kRPC)
.value("Dist", framework::OpRole::kDist)
.value("Collective", framework::OpRole::kCollective)
.value("LRSched", framework::OpRole::kLRSched);
op_proto_and_checker_maker.def(
......
......@@ -1005,7 +1005,9 @@ class Operator(object):
OP_WITHOUT_KERNEL_SET = {
'feed', 'fetch', 'recurrent', 'go', 'rnn_memory_helper_grad',
'conditional_block', 'while', 'send', 'recv', 'listen_and_serv',
'ncclInit', 'select', 'checkpoint_notify', 'gen_nccl_id'
'ncclInit', 'select', 'checkpoint_notify', 'gen_nccl_id',
'c_gen_nccl_id', 'c_comm_init', 'c_sync_calc_stream',
'c_sync_comm_stream'
}
def __init__(self,
......
......@@ -190,7 +190,6 @@ class Fleet(object):
self._role_maker = role_maker
self._role_maker.generate_role()
self._is_initialized = True
@abc.abstractmethod
......
......@@ -103,6 +103,46 @@ class RoleMakerBase(object):
return self._server_endpoints
class MultiProcessRoleMaker(RoleMakerBase):
"""
MultiProcessRoleMaker is a default role maker for multi-process
GPU training. It works with paddle.distributed.lanuch.py by-design
"""
def __init__(self):
super(MultiProcessRoleMaker, self).__init__()
self._role_is_generated = False
def generate_role(self):
import os
if not self._role_is_generated:
self._current_id = int(os.getenv("PADDLE_TRAINER_ID", "0"))
self._num_trainers = 1
self._training_role = os.getenv("PADDLE_TRAINING_ROLE", "TRAINER")
assert (self._training_role == "TRAINER")
self._worker_endpoints = os.getenv("PADDLE_TRAINER_ENDPOINTS")
self._current_endpoint = os.getenv("PADDLE_CURRENT_ENDPOINT")
if self._worker_endpoints:
self._worker_endpoints = self._worker_endpoints.split(",")
self._num_trainers = len(self._worker_endpoints)
self._role_is_generated = True
def is_worker(self):
return True
def is_server(self):
return False
def is_first_worker(self):
return self._current_id == 0
def worker_index(self):
return self._current_id
def worker_num(self):
return self._worker_num
class MPIRoleMaker(RoleMakerBase):
"""
MPIRoleMaker is a MPI-API based role maker which is a counter-part of K8SRoleMaker
......
......@@ -22,6 +22,116 @@ from paddle.fluid.incubate.fleet.base.fleet_base import Mode
from paddle.fluid.incubate.fleet.base.fleet_base import DistributedOptimizer
class DistributedStrategy(object):
def __init__(self):
# precision configs
self.use_fp16 = False
self.use_fp32 = True
# algorithmic communication
self.local_sgd = False
self.dgc = False
# communication topology configs
self.h_allreduce = False
def build(self):
# make sure we set single precision config True
if self.use_fp32 and self.use_fp16:
self.use_fp16 = False
# make sure we set single algorithmic communication True
if self.local_sgd and self.dgc:
self.local_sgd = False
self.strategy_map["fp16"] = self.use_fp16
self.strategy_map["fp32"] = self.use_fp32
self.strategy_map["localsgd"] = self.local_sgd
self.strategy_map["dgc"] = self.dgc
self.strategy_map["h_allreduce"] = self.h_allreduce
class DistributedOptimizerFactory(object):
def strategy_to_optimizer_map(self):
pattern = {}
pattern["fp16"] = [
"MixedPrecisionOptimizer", "MixedPrecisionLocalSGDOptimizer"
]
pattern["fp32"] = ["FullPrecisionOptimizer", "LocalSGDOptimizer"]
pattern["localsgd"] = [
"MixedPrecisionLocalSGDOptimizer", "LocalSGDOptimizer"
]
pattern["h_allreduce"] = [
"FullPrecisionOptimizer",
"LocalSGDOptimizer",
"MixedPrecisionOptimizer",
"MixedPrecisionLocalSGDOptimizer",
]
self.pattern = pattern
def create_by_strategy(self, optimizer, strategy):
if strategy == None:
strategy = DistributedStrategy()
strategy.build()
strategy_list = []
for key in strategy.strategy_map:
if strategy.strategy_map[key]:
strategy_list.append(self.pattern[key])
classname = list(set.intersection(*map(set, strategy_list)))[0]
return globals()[classname](optimizer, strategy)
class DistributedStrategy(object):
def __init__(self):
# precision configs
self.use_fp16 = False
self.use_fp32 = True
# algorithmic communication
self.local_sgd = False
self.dgc = False
# communication topology configs
self.h_allreduce = False
def build(self):
# make sure we set single precision config True
if self.use_fp32 and self.use_fp16:
self.use_fp16 = False
# make sure we set single algorithmic communication True
if self.local_sgd and self.dgc:
self.local_sgd = False
self.strategy_map["fp16"] = self.use_fp16
self.strategy_map["fp32"] = self.use_fp32
self.strategy_map["localsgd"] = self.local_sgd
self.strategy_map["dgc"] = self.dgc
self.strategy_map["h_allreduce"] = self.h_allreduce
class DistributedOptimizerFactory(object):
def strategy_to_optimizer_map(self):
pattern = {}
pattern["fp16"] = [
"MixedPrecisionOptimizer", "MixedPrecisionLocalSGDOptimizer"
]
pattern["fp32"] = ["FullPrecisionOptimizer", "LocalSGDOptimizer"]
pattern["localsgd"] = [
"MixedPrecisionLocalSGDOptimizer", "LocalSGDOptimizer"
]
pattern["h_allreduce"] = [
"FullPrecisionOptimizer",
"LocalSGDOptimizer",
"MixedPrecisionOptimizer",
"MixedPrecisionLocalSGDOptimizer",
]
self.pattern = pattern
def create_by_strategy(self, optimizer, strategy):
if strategy == None:
strategy = DistributedStrategy()
strategy.build()
strategy_list = []
for key in strategy.strategy_map:
if strategy.strategy_map[key]:
strategy_list.append(self.pattern[key])
classname = list(set.intersection(*map(set, strategy_list)))[0]
return globals()[classname](optimizer, strategy)
class Collective(Fleet):
def __init__(self):
super(Collective, self).__init__(Mode.COLLECTIVE)
......@@ -48,7 +158,8 @@ class Collective(Fleet):
"You should not call 'stop_worker' method for collective mode.")
def distributed_optimizer(self, optimizer, strategy=None):
self._optimizer = CollectiveOptimizer(optimizer, strategy)
self._optimizer = \
DistributedOptimizerFactory.create_by_strategy(optimizer, strategy)
return self._optimizer
def save_inference_model(self,
......@@ -69,6 +180,85 @@ class Collective(Fleet):
fleet = Collective()
class CollectiveOpBasedOptimizer(DistributedOptimizer):
"""
TBA
"""
def __init__(self, optimizer, strategy=None):
super(CollectiveOpBasedOptimizer, self).__init__(optimizer, strategy)
def _transpile_program(self, startup_program=None):
startup_program = startup_program if startup_program else \
fluid.framework.default_startup_program()
worker_endpoints = fleet.worker_endpoints()
trainer_id = fleet.worker_index()
current_endpoint = fleet.worker_endpoints()[trainer_id]
# call transpiler
config = dist_transpiler.DistributeTranspilerConfig()
config.mode = "collective"
config.collective_mode = "sgd"
t = dist_transpiler.DistributeTranspiler(config=config)
t.transpile(
trainer_id,
trainers=','.join(worker_endpoints),
startup_program=startup_program,
current_endpoint=current_endpoint)
def backward(self,
loss,
startup_program=None,
parameter_list=None,
no_grad_set=None,
callbacks=None):
return self._optimizer.backward(loss, startup_program, parameter_list,
no_grad_set, callbacks)
def apply_gradients(self, params_grads):
return self._optimizer.apply_gradients(params_grads)
class MixedPrecisionOptimizer(CollectiveOpBasedOptimizer):
"""
TBA
"""
def minimize(self,
loss,
startup_program=None,
parameter_list=None,
no_grad_set=None):
pass
class FullPrecisionOptimizer(CollectiveOpBasedOptimizer):
"""
TBA
"""
def __init__(self, optimizer, strategy=None):
super(FullPrecisionOptimizer, self).__init__(optimizer, strategy)
def minimize(self,
loss,
startup_program=None,
parameter_list=None,
no_grad_set=None):
self._transpile_program(startup_program)
train_program = loss.block.program
param_grads = self.backward(loss)
train_program.global_block().append_op(type='c_sync_compute_stream')
data_parallel_param_grads = []
for p, g in param_grads:
# NOTE: scale will be done on loss scale
# in multi_devices_graph_pass using nranks.
reduced_g = fluid.layers.collective._allreduce(g, g)
data_parallel_param_grads.append([p, reduced_g])
train_program.global_block().append_op(type='c_sync_comm_stream')
self.apply_gradients(data_parallel_param_grads)
class CollectiveOptimizer(DistributedOptimizer):
"""
DistributedOptimizer is a wrapper for paddle.fluid.optimizer
......@@ -82,7 +272,7 @@ class CollectiveOptimizer(DistributedOptimizer):
def __init__(self, optimizer, strategy=None):
super(CollectiveOptimizer, self).__init__(optimizer, strategy)
assert strategy is None, "You cannot set 'strategy' for collective."
self.strategy = strategy
def backward(self,
loss,
......
......@@ -20,6 +20,14 @@ if(NOT WITH_DISTRIBUTE)
LIST(REMOVE_ITEM TEST_OPS test_dist_fleet_ctr)
endif(NOT WITH_DISTRIBUTE)
if(NOT WITH_GPU OR WIN32)
LIST(REMOVE_ITEM TEST_OPS test_allgather)
LIST(REMOVE_ITEM TEST_OPS test_allreduce)
LIST(REMOVE_ITEM TEST_OPS test_broadcast)
LIST(REMOVE_ITEM TEST_OPS test_reducescatter)
endif()
LIST(REMOVE_ITEM TEST_OPS test_launch)
if (NOT ${WITH_GPU})
......
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import numpy as np
import argparse
import os
import sys
import signal
import time
from contextlib import closing
from six import string_types
import math
import paddle
import paddle.fluid as fluid
import paddle.fluid.profiler as profiler
import paddle.fluid.unique_name as nameGen
from paddle.fluid import core
import unittest
from multiprocessing import Process
import paddle.fluid.layers as layers
from functools import reduce
from test_collective_base import TestCollectiveRunnerBase, runtime_main
class TestCollectiveAllGather(TestCollectiveRunnerBase):
def __init__(self):
self.global_ring_id = 0
def get_model(self, main_prog, startup_program):
ring_id = 0
nranks = 2
with fluid.program_guard(main_prog, startup_program):
tindata = layers.data(
name="tindata", shape=[10, 1000], dtype='float32')
toutdata = main_prog.current_block().create_var(
name="outofgather",
dtype='float32',
type=core.VarDesc.VarType.LOD_TENSOR,
persistable=False,
stop_gradient=False)
main_prog.global_block().append_op(
type="c_allgather",
inputs={'X': tindata},
attrs={'ring_id': ring_id,
'nranks': nranks},
outputs={'Out': toutdata})
main_prog.global_block().append_op(
type="c_sync_comm_stream",
inputs={'X': toutdata},
outputs={'Out': toutdata},
attrs={'ring_id': ring_id})
return toutdata
if __name__ == "__main__":
runtime_main(TestCollectiveAllGather, "allgather", 0)
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import numpy as np
import argparse
import os
import sys
import signal
import time
import socket
from contextlib import closing
from six import string_types
import math
import paddle
import paddle.fluid as fluid
import paddle.fluid.profiler as profiler
import paddle.fluid.unique_name as nameGen
from paddle.fluid import core
import unittest
from multiprocessing import Process
import paddle.fluid.layers as layers
from functools import reduce
from test_collective_base import TestCollectiveRunnerBase, runtime_main
class TestCollectiveAllreduce(TestCollectiveRunnerBase):
def __init__(self):
self.global_ring_id = 0
def get_model(self, main_prog, startup_program):
ring_id = 0
reduce_type = 0
with fluid.program_guard(main_prog, startup_program):
tindata = layers.data(
name="tindata", shape=[10, 1000], dtype='float32')
toutdata = main_prog.current_block().create_var(
name="outofallreduce",
dtype='float32',
type=core.VarDesc.VarType.LOD_TENSOR,
persistable=False,
stop_gradient=False)
main_prog.global_block().append_op(
type="c_allreduce",
inputs={'X': tindata},
attrs={'ring_id': ring_id,
'reduce_type': reduce_type},
outputs={'Out': toutdata})
main_prog.global_block().append_op(
type="c_sync_comm_stream",
inputs={'X': toutdata},
outputs={'Out': toutdata},
attrs={'ring_id': ring_id})
return toutdata
if __name__ == "__main__":
runtime_main(TestCollectiveAllreduce, "allreduce", 0)
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import numpy as np
import argparse
import os
import sys
import signal
import time
import socket
from contextlib import closing
from six import string_types
import math
import paddle
import paddle.fluid as fluid
import paddle.fluid.profiler as profiler
import paddle.fluid.unique_name as nameGen
from paddle.fluid import core
import unittest
from multiprocessing import Process
import paddle.fluid.layers as layers
from functools import reduce
from test_collective_base import TestCollectiveRunnerBase, runtime_main
class TestCollectiveBroadcast(TestCollectiveRunnerBase):
def __init__(self):
self.global_ring_id = 0
def get_model(self, main_prog, startup_program):
ring_id = 0
rootid = 1
with fluid.program_guard(main_prog, startup_program):
tindata = layers.data(
name="tindata", shape=[10, 1000], dtype='float32')
toutdata = main_prog.current_block().create_var(
name="outofbroadcast",
dtype='float32',
type=core.VarDesc.VarType.LOD_TENSOR,
persistable=False,
stop_gradient=False)
main_prog.global_block().append_op(
type="c_broadcast",
inputs={'X': tindata},
attrs={'ring_id': ring_id,
'root': rootid},
outputs={'Out': toutdata})
main_prog.global_block().append_op(
type="c_sync_comm_stream",
inputs={'X': toutdata},
outputs={'Out': toutdata},
attrs={'ring_id': ring_id})
return toutdata
if __name__ == "__main__":
runtime_main(TestCollectiveBroadcast, "broadcast", 0)
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import numpy as np
import argparse
import os
import sys
import signal
import time
import socket
from contextlib import closing
from six import string_types
import math
import paddle
import paddle.fluid as fluid
import paddle.fluid.profiler as profiler
import paddle.fluid.unique_name as nameGen
from paddle.fluid import core
import unittest
from multiprocessing import Process
import paddle.fluid.layers as layers
from functools import reduce
from test_collective_base import TestCollectiveRunnerBase, runtime_main
class TestCollectiveReduceScatter(TestCollectiveRunnerBase):
def __init__(self):
self.global_ring_id = 0
def get_model(self, main_prog, startup_program):
ring_id = 0
nranks = 2
with fluid.program_guard(main_prog, startup_program):
tindata = layers.data(
name="tindata", shape=[10, 1000], dtype='float32')
toutdata = main_prog.current_block().create_var(
name="outofrs",
dtype='float32',
type=core.VarDesc.VarType.LOD_TENSOR,
persistable=False,
stop_gradient=False)
main_prog.global_block().append_op(
type="c_reducescatter",
inputs={'X': tindata},
attrs={'ring_id': ring_id,
'nranks': nranks},
outputs={'Out': toutdata})
main_prog.global_block().append_op(
type="c_sync_comm_stream",
inputs={'X': toutdata},
outputs={'Out': toutdata},
attrs={'ring_id': ring_id})
return toutdata
if __name__ == "__main__":
runtime_main(TestCollectiveReduceScatter, "reduce_scatter", 0)
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import unittest
import numpy as np
from test_collective_base import TestDistBase
class TestAllGatherOp(TestDistBase):
def _setup_config(self):
pass
def test_allgather(self, col_type="allgather"):
self.check_with_place("collective_allgather_op.py", col_type)
if __name__ == '__main__':
unittest.main()
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import unittest
import numpy as np
from test_collective_base import TestDistBase
class TestAllReduceOp(TestDistBase):
def _setup_config(self):
pass
def test_allreduce(self, col_type="allreduce"):
self.check_with_place("collective_allreduce_op.py", col_type)
if __name__ == '__main__':
unittest.main()
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import unittest
import numpy as np
from test_collective_base import TestDistBase
class TestCBroadcastOp(TestDistBase):
def _setup_config(self):
pass
def test_broadcast(self):
self.check_with_place("collective_broadcast_op.py", "broadcast")
if __name__ == '__main__':
unittest.main()
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import numpy as np
import unittest
import time
import argparse
import os
import six
import sys
import subprocess
import traceback
import functools
import pickle
from contextlib import closing
from six import string_types
import paddle.fluid as fluid
import paddle.fluid.unique_name as nameGen
from paddle.fluid import core
class TestCollectiveRunnerBase(object):
def get_model(self, train_prog, startup_prog):
raise NotImplementedError(
"get model should be implemented by child class.")
def wait_server_ready(self, endpoints):
assert not isinstance(endpoints, string_types)
while True:
all_ok = True
not_ready_endpoints = []
for ep in endpoints:
ip_port = ep.split(":")
with closing(
socket.socket(socket.AF_INET,
socket.SOCK_STREAM)) as sock:
sock.settimeout(2)
result = sock.connect_ex((ip_port[0], int(ip_port[1])))
if result != 0:
all_ok = False
not_ready_endpoints.append(ep)
if not all_ok:
sys.stderr.write("server not ready, wait 3 sec to retry...\n")
sys.stderr.write("not ready endpoints:" + str(
not_ready_endpoints) + "\n")
sys.stderr.flush()
time.sleep(3)
else:
break
#endpoints should be ["ip1:port1","ip2:port2"]
def initCommunicator(self, program, rank, nranks, wait_port,
current_endpoint, endpoints):
other_endpoints = endpoints[:]
other_endpoints.remove(current_endpoint)
if rank == 0 and wait_port:
self.wait_server_ready(other_endpoints)
block = program.global_block()
nccl_id_var = block.create_var(
name=nameGen.generate('nccl_id'),
persistable=True,
type=core.VarDesc.VarType.RAW)
block.append_op(
type='c_gen_nccl_id',
inputs={},
outputs={'Out': nccl_id_var},
attrs={
'rank': rank,
'endpoint': current_endpoint,
'other_endpoints': other_endpoints
})
block.append_op(
type='c_comm_init',
inputs={'X': nccl_id_var},
outputs={},
attrs={
'nranks': nranks,
'rank': rank,
'ring_id': self.global_ring_id
})
def run_trainer(self, args):
train_prog = fluid.Program()
startup_prog = fluid.Program()
endpoints = args["endpoints"].split(",")
rank = args["trainerid"]
current_endpoint = args["currentendpoint"]
nranks = 2
self.initCommunicator(startup_prog, rank, nranks, True,
current_endpoint, endpoints)
result = self.get_model(train_prog, startup_prog)
device_id = int(os.getenv("FLAGS_selected_gpus", "0"))
place = fluid.CUDAPlace(
device_id) #if args.use_gpu else fluid.CPUPlace()
exe = fluid.Executor(place)
exe.run(startup_prog)
np.random.seed(os.getpid())
indata = np.random.random((10, 1000))
out = exe.run(train_prog,
feed={'tindata': indata},
fetch_list=[result.name])
if six.PY2:
print(pickle.dumps(out))
else:
sys.stdout.buffer.write(pickle.dumps(out))
def runtime_main(test_class, col_type, sub_type):
args = {}
model = test_class()
args["deviceid"] = os.getenv("FLAGS_selected_gpus")
args["trainerid"] = int(os.getenv("PADDLE_TRAINER_ID"))
args["trainernum"] = int(os.getenv("PADDLE_TRAINERS_NUM"))
args["endpoints"] = os.getenv('PADDLE_TRAINER_ENDPOINTS')
args["currentendpoint"] = os.getenv("PADDLE_CURRENT_ENDPOINT")
args["col_type"] = col_type
model.run_trainer(args)
import paddle.compat as cpt
import socket
from contextlib import closing
class TestDistBase(unittest.TestCase):
def setUp(self):
self._port_set = set()
self._trainers = 2
self._ps_endpoints = "127.0.0.1:%s,127.0.0.1:%s" % (
self._find_free_port(), self._find_free_port())
self._python_interp = sys.executable
def _find_free_port(self):
def __free_port():
with closing(socket.socket(socket.AF_INET,
socket.SOCK_STREAM)) as s:
s.bind(('', 0))
return s.getsockname()[1]
while True:
port = __free_port()
if port not in self._port_set:
self._port_set.add(port)
return port
def _run_cluster(self, model_file, envs):
worker_endpoints = self._ps_endpoints.split(",")
w0_ep, w1_ep = worker_endpoints
#print("w0_ep:",w0_ep," w1_ep:",w1_ep)
env0 = {
"FLAGS_selected_gpus": "2",
"PADDLE_TRAINER_ID": "0",
"PADDLE_TRAINERS_NUM": "2",
"PADDLE_TRAINER_ENDPOINTS": self._ps_endpoints,
"PADDLE_CURRENT_ENDPOINT": w0_ep
}
env1 = {
"FLAGS_selected_gpus": "3",
"PADDLE_TRAINER_ID": "1",
"PADDLE_TRAINERS_NUM": "2",
"PADDLE_TRAINER_ENDPOINTS": self._ps_endpoints,
"PADDLE_CURRENT_ENDPOINT": w1_ep
}
#update environment
env0.update(envs)
env1.update(envs)
tr_cmd = "%s %s"
tr0_cmd = tr_cmd % (self._python_interp, model_file)
tr1_cmd = tr_cmd % (self._python_interp, model_file)
tr0_pipe = open("/tmp/tr0_err.log", "wb")
tr1_pipe = open("/tmp/tr1_err.log", "wb")
#print(tr0_cmd)
tr0_proc = subprocess.Popen(
tr0_cmd.strip().split(),
stdout=subprocess.PIPE,
stderr=tr0_pipe,
env=env0)
tr1_proc = subprocess.Popen(
tr0_cmd.strip().split(),
stdout=subprocess.PIPE,
stderr=tr1_pipe,
env=env1)
tr0_out, tr0_err = tr0_proc.communicate()
tr1_out, tr1_err = tr1_proc.communicate()
sys.stderr.write('trainer 0 stderr: %s\n' % tr0_err)
sys.stderr.write('trainer 1 stderr: %s\n' % tr1_err)
# close trainer file
tr0_pipe.close()
tr1_pipe.close()
return pickle.loads(tr0_out), pickle.loads(
tr1_out), tr0_proc.pid, tr1_proc.pid
def check_with_place(self,
model_file,
col_type,
check_error_log=False,
need_envs={}):
required_envs = {
"FLAGS_fraction_of_gpu_memory_to_use": "0.15",
"FLAGS_eager_delete_tensor_gb": "0.0",
"PATH": os.getenv("PATH"),
"PYTHONPATH": os.getenv("PYTHONPATH", ""),
"LD_LIBRARY_PATH": os.getenv("LD_LIBRARY_PATH", ""),
"LD_PRELOAD": os.getenv("LD_PRELOAD", ""),
"GLOG_v": "0",
"NCCL_P2P_DISABLE": "1"
}
required_envs.update(need_envs)
if check_error_log:
required_envs["GLOG_v"] = "3"
required_envs["GLOG_logtostderr"] = "1"
tr0_out, tr1_out, pid0, pid1 = self._run_cluster(model_file,
required_envs)
np.random.seed(pid0)
input1 = np.random.random((10, 1000))
np.random.seed(pid1)
input2 = np.random.random((10, 1000))
if col_type == "allgather":
need_result = np.vstack((input1, input2))
self.assertTrue(np.allclose(tr0_out, need_result))
self.assertTrue(np.allclose(tr1_out, need_result))
elif col_type == "broadcast":
need_result = input2
self.assertTrue(np.allclose(tr0_out, need_result))
self.assertTrue(np.allclose(tr1_out, need_result))
elif col_type == "allreduce":
need_result = input1 + input2
self.assertTrue(
np.allclose(
tr0_out, need_result, rtol=1e-05, atol=1e-05))
self.assertTrue(
np.allclose(
tr1_out, need_result, rtol=1e-05, atol=1e-05))
elif col_type == "reduce_scatter":
tmp = input1 + input2
need_result1 = tmp[0:tmp.shape[0] // 2]
need_result2 = tmp[tmp.shape[0] // 2:]
self.assertTrue(
np.allclose(
tr0_out, need_result1, rtol=1e-05, atol=1e-05))
self.assertTrue(
np.allclose(
tr1_out, need_result2, rtol=1e-05, atol=1e-05))
elif col_type == "reduce_slicegather":
slicesize = input1.shape[0] // 2
tmp10 = input1[0:slicesize]
tmp11 = input2[0:slicesize]
need_result1 = np.concatenate((tmp10, tmp11), axis=1)
tmp20 = input1[slicesize:]
tmp21 = input2[slicesize:]
need_result2 = np.concatenate((tmp20, tmp21), axis=1)
self.assertTrue(np.allclose(tr0_out, need_result1))
self.assertTrue(np.allclose(tr1_out, need_result2))
else:
pass
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import unittest
import numpy as np
from test_collective_base import TestDistBase
class TestReduceScatterOp(TestDistBase):
def _setup_config(self):
pass
def test_reducescatter(self):
self.check_with_place("collective_reducescatter_op.py",
"reduce_scatter")
if __name__ == '__main__':
unittest.main()
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the 'License');
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an 'AS IS' BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import sys
import math
from functools import reduce
import collections
import six
import logging
import numpy as np
from .. import core, unique_name
from ..framework import Program, default_main_program, default_startup_program
from .details import wait_server_ready
__all__ = ['GradAllReduce', 'LocalSGD']
OpRole = core.op_proto_and_checker_maker.OpRole
class Collective(object):
'''
'''
def __init__(self):
self.global_ring_id = 0
self.endpoints = None
self.current_endpoint = None
self.nranks = None
self.rank = None
self.startup_program = None
self.main_program = None
op_maker = core.op_proto_and_checker_maker
self.op_role_key = op_maker.kOpRoleAttrName()
self.op_role_var_key = op_maker.kOpRoleVarAttrName()
def transpile(self, startup_program, main_program, rank, endpoints,
current_endpoint, wait_port):
# in case of '127.0.0.1:6700,127.0.0.1:6701,...'
if isinstance(endpoints, str):
endpoints = endpoints.split(',')
self.startup_program = startup_program
if startup_program is None:
self.startup_program = default_startup_program()
self.main_program = main_program
if main_program is None:
self.main_program = default_main_program()
self.nranks = len(endpoints)
if self.nranks == 1:
raise ValueError('the number of endpoints must > 1')
if rank < 0:
raise ValueError('rank must >= 0')
self.rank = rank
if current_endpoint not in endpoints:
raise ValueError('current endpoint %s is not in %s',
current_endpoint, str(endpoints))
self.endpoints = endpoints
self.current_endpoint = current_endpoint
self.wait_port = wait_port
self.startup_program._origin_program = self.startup_program.clone()
self._transpile_startup_program()
self.main_program._origin_program = self.main_program.clone()
self._transpile_main_program()
def _transpile_main_program(self):
raise NotImplementedError('call the inherited method of subclasses')
def _transpile_startup_program(self):
self._init_communicator(self.startup_program, self.current_endpoint,
self.endpoints, self.rank, self.global_ring_id,
self.wait_port)
self._broadcast_params()
def _init_communicator(self, program, current_endpoint, endpoints, rank,
ring_id, wait_port):
nranks = len(endpoints)
other_endpoints = endpoints[:]
other_endpoints.remove(current_endpoint)
if rank == 0 and wait_port:
wait_server_ready(other_endpoints)
block = program.global_block()
nccl_id_var = block.create_var(
name=unique_name.generate('nccl_id'),
persistable=True,
type=core.VarDesc.VarType.RAW)
block.append_op(
type='c_gen_nccl_id',
inputs={},
outputs={'Out': nccl_id_var},
attrs={
'rank': rank,
'endpoint': current_endpoint,
'other_endpoints': other_endpoints,
self.op_role_key: OpRole.Collective
})
block.append_op(
type='c_comm_init',
inputs={'X': nccl_id_var},
outputs={},
attrs={
'nranks': nranks,
'rank': rank,
'ring_id': ring_id,
self.op_role_key: OpRole.Collective
})
def _broadcast_params(self):
block = self.startup_program.global_block()
for var in block.iter_parameters():
block.append_op(
type='c_broadcast',
inputs={'X': var},
outputs={'Out': var},
attrs={
'ring_id': self.global_ring_id,
'root': 0,
self.op_role_key: OpRole.Collective
})
block.append_op(
type='c_sync_comm_stream',
inputs={'X': var},
outputs={'Out': var},
attrs={
'ring_id': self.global_ring_id,
self.op_role_key: OpRole.Collective
})
def _is_loss_grad_op(self, op):
if self.op_role_key not in op.attr_names:
return False
op_role = int(op.all_attrs()[self.op_role_key])
return op_role & int(OpRole.Backward) and op_role & int(OpRole.Loss)
def _is_backward_op(self, op):
return self.op_role_key in op.attr_names and \
int(op.all_attrs()[self.op_role_key]) & int(OpRole.Backward)
def _is_update_op(self, op):
return 'Param' in op.input_names and 'Grad' in op.input_names and \
"LearningRate" in op.input_names
def _is_optimizer_op(self, op):
return self.op_role_key in op.attr_names and \
int(op.all_attrs()[self.op_role_key]) & int(OpRole.Optimize)
class GradAllReduce(Collective):
'''
'''
def __init__(self):
Collective.__init__(self)
def _transpile_main_program(self):
self._insert_scale_loss_grad_ops()
self._insert_allreduce_ops()
def _insert_scale_loss_grad_ops(self):
'''
In order to keep the learning rate consistent in different numbers of
training workers, we scale the loss grad by the number of workers
'''
block = self.main_program.global_block()
for idx, op in reversed(list(enumerate(block.ops))):
if self._is_loss_grad_op(op):
loss_grad_var = block.vars[op.output_arg_names[0]]
block._insert_op(
idx + 1,
type='scale',
inputs={'X': loss_grad_var},
outputs={'Out': loss_grad_var},
attrs={
'scale': 1.0 / self.nranks,
self.op_role_key: OpRole.Collective
})
def _insert_allreduce_ops(self):
block = self.main_program.global_block()
for idx, op in reversed(list(enumerate(block.ops))):
if self._is_backward_op(op) and \
self.op_role_var_key in op.attr_names:
op_role_var = op.all_attrs()[self.op_role_var_key]
if len(op_role_var) == 0:
continue
assert len(op_role_var) % 2 == 0
block._insert_op(
idx + 1,
type='c_sync_calc_stream',
inputs={'X': block.vars[grad]},
outputs={'Out': block.vars[grad]},
attrs={self.op_role_key: OpRole.Collective})
offset = 2
for i in range(0, len(op_role_var), 2):
grad = op_role_var[i + 1]
block._insert_op(
idx + offset,
type='c_allreduce',
inputs={'X': [block.vars[grad]]},
outputs={'Out': [block.vars[grad]]},
attrs={
'reduce_type': 0,
self.op_role_key: OpRole.Collective
})
offset += 1
for idx, op in enumerate(block.ops):
if self._is_optimizer_op(op):
block._insert_op(
idx,
type='c_sync_comm_stream',
inputs={'X': block.vars[grad]},
outputs={'Out': block.vars[grad]},
attrs={
'ring_id': self.global_ring_id,
self.op_role_key: OpRole.Collective
})
break
class LocalSGD(Collective):
'''
'''
def __init__(self):
Collective.__init__(self)
self.snapshot_key = '@SNAPSHOT'
def _transpile_startup_program(self):
Collective._transpile_startup_program(self)
block = self.startup_program.global_block()
for param in block.iter_parameters():
snapshot = block.create_var(
name=self.snapshot_name(param.name),
shape=param.shape,
persistable=True,
stop_gradient=True)
block.append_op(
type='assign',
inputs={'X': [param]},
outputs={'Out': [snapshot]},
attrs={self.op_role_key: OpRole.Collective})
def snapshot_name(self, param_name):
return param_name + self.snapshot_key
def _transpile_main_program(self):
block = self.main_program.global_block()
ordered_param_snapshot = []
for idx, op in reversed(list(enumerate(block.ops))):
if self._is_update_op(op):
param = block.vars[op.input('Param')[0]]
snapshot = block.create_var(
name=self.snapshot_name(param.name),
shape=param.shape,
persistable=True,
stop_gradient=True)
block._insert_op(
idx + 1,
type='elementwise_sub',
inputs={'X': [snapshot],
'Y': [param]},
outputs={'Out': [param]},
attrs={self.op_role_key: OpRole.Collective})
block._insert_op(
idx + 2,
type='c_sync_calc_stream',
inputs={'X': param},
outputs={'Out': param},
attrs={self.op_role_key: OpRole.Collective})
block._insert_op(
idx + 3,
type='c_allreduce',
inputs={'X': [param]},
outputs={'Out': [param]},
attrs={
'reduce_type': 0,
self.op_role_key: OpRole.Collective
})
ordered_param_snapshot.append((param, snapshot))
block.append_op(
type='c_sync_comm_stream',
inputs={'X': param},
outputs={'Out': param},
attrs={
'ring_id': self.global_ring_id,
self.op_role_key: OpRole.Collective
})
for param_snapshot in reversed(ordered_param_snapshot):
param = param_snapshot[0]
snapshot = param_snapshot[1]
block.append_op(
type='scale',
inputs={'X': [param]},
outputs={'Out': [param]},
attrs={
'scale': 1.0 / self.nranks,
self.op_role_key: OpRole.Collective
})
block.append_op(
type='elementwise_sub',
inputs={'X': [snapshot],
'Y': [param]},
outputs={'Out': [param]},
attrs={self.op_role_key: OpRole.Collective})
block.append_op(
type='assign',
inputs={'X': [param]},
outputs={'Out': [snapshot]},
attrs={self.op_role_key: OpRole.Collective})
......@@ -90,7 +90,7 @@ def variable_to_code(var):
return var_str
def op_to_code(op):
def op_to_code(op, skip_op_callstack=False):
"""
Get readable codes of fluid operator.
......@@ -124,6 +124,8 @@ def op_to_code(op):
attrs_str = ""
for i in range(0, len(attr_names)):
name = attr_names[i]
if skip_op_callstack and name == "op_callstack":
continue
attr_type = op.desc.attr_type(name)
if attr_type == core.AttrType.BLOCK:
......@@ -157,29 +159,35 @@ def op_to_code(op):
return op_str
def block_to_code(block, block_idx):
def block_to_code(block, block_idx, fout=None, skip_op_callstack=False):
indent = 0
print("{0}{1} // block {2}".format(
get_indent_space(indent), '{', block_idx))
print(
"{0}{1} // block {2}".format(get_indent_space(indent), '{', block_idx),
file=fout)
indent += 1
# sort all vars
all_vars = sorted(six.iteritems(block.vars), key=lambda x: x[0])
for var in all_vars:
print("{}{}".format(get_indent_space(indent), variable_to_code(var[1])))
print(
"{}{}".format(get_indent_space(indent), variable_to_code(var[1])),
file=fout)
if len(all_vars) > 0:
print("")
print("", file=fout)
for op in block.ops:
print("{}{}".format(get_indent_space(indent), op_to_code(op)))
print(
"{}{}".format(
get_indent_space(indent), op_to_code(op, skip_op_callstack)),
file=fout)
indent -= 1
print("{0}{1}".format(get_indent_space(indent), '}'))
print("{0}{1}".format(get_indent_space(indent), '}'), file=fout)
def program_to_code(prog):
def program_to_code(prog, fout=None, skip_op_callstack=False):
"""
Print readable codes of fluid program.
......@@ -191,5 +199,5 @@ def program_to_code(prog):
"""
block_idx = 0
for block in prog.blocks:
block_to_code(block, block_idx)
block_to_code(block, block_idx, fout, skip_op_callstack)
block_idx += 1
......@@ -47,6 +47,7 @@ from ..framework import Program, default_main_program, \
from .details import wait_server_ready, UnionFind, VarStruct, VarsDistributed
from .details import delete_ops, find_op_by_output_arg
from ..distribute_lookup_table import find_distributed_lookup_table
from . import collective
LOOKUP_TABLE_TYPE = "lookup_table"
LOOKUP_TABLE_GRAD_TYPE = "lookup_table_grad"
......@@ -157,7 +158,7 @@ class DistributeTranspilerConfig(object):
split_method = None
min_block_size = 8192
enable_dc_asgd = False
# supported modes: pserver, nccl2
# supported modes: pserver, nccl2, collective
mode = "pserver"
print_log = False
wait_port = True
......@@ -174,6 +175,10 @@ class DistributeTranspilerConfig(object):
#Nccl ranks bewteen nodes when use hierarchical allreduce, it's setted to nodes number.
hierarchical_allreduce_exter_nranks = 0
# if mode is collective
# supported modes: sgd, local_sgd
collective_mode = None
class DistributeTranspiler(object):
"""
......@@ -305,6 +310,46 @@ class DistributeTranspiler(object):
else:
raise ValueError("must set trainer_id > 0")
def _transpile_collective(self,
collective_mode,
trainer_id,
trainers,
current_endpoint,
startup_program=None,
main_program=None,
wait_port=True):
if isinstance(trainers, str):
endpoints = trainers.split(",")
elif isinstance(trainers, list):
endpoints = trainers
else:
raise ValueError('invalid trainers config: ' + str(trainers))
if len(endpoints) == 1:
raise ValueError('invalid trainer number in distributed: 1')
if startup_program is None:
startup_program = default_startup_program()
if main_program is None:
main_program = default_main_program()
transpiler = None
if collective_mode == 'grad_allreduce':
transpiler = collective.GradAllReduce()
elif collective_mode == 'local_sgd':
transpiler = collective.LocalSGD()
else:
raise ValueError('invalid collective_mode: %s' % collective_mode)
transpiler.transpile(
startup_program=startup_program,
main_program=main_program,
rank=trainer_id,
endpoints=endpoints,
current_endpoint=current_endpoint,
wait_port=wait_port)
def _get_all_remote_sparse_update_op(self, main_program):
sparse_update_ops = []
sparse_update_op_types = ["lookup_table", "nce", "hierarchical_sigmoid"]
......@@ -395,6 +440,17 @@ class DistributeTranspiler(object):
wait_port=self.config.wait_port)
return
if self.config.mode == "collective":
self._transpile_collective(
collective_mode=self.config.collective_mode,
trainer_id=trainer_id,
trainers=trainers,
current_endpoint=current_endpoint,
startup_program=startup_program,
main_program=program,
wait_port=self.config.wait_port)
return
self.trainer_num = trainers
self.sync_mode = sync_mode
self.trainer_id = trainer_id
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册