提交 ec47565c 编写于 作者: D Dong Zhihong

"add reduce hash function"

上级 423d7438
......@@ -289,6 +289,15 @@ class ExecutionContext {
return device_context_;
}
//! Get a input which has multiple variables.
const std::vector<std::string>& Inputs(const std::string& name) const {
return op_.Inputs(name);
}
//! Get an output which has multiple variables.
const std::vector<std::string>& Outputs(const std::string& name) const {
return op_.Outputs(name);
}
#ifdef PADDLE_WITH_CUDA
const platform::CUDADeviceContext& cuda_device_context() const {
PADDLE_ENFORCE(platform::is_gpu_place(device_context_.GetPlace()));
......
......@@ -81,9 +81,6 @@ class NCCLReduceOp : public framework::OperatorWithKernel {
" Input(Communicator) of Reduce op input should not be NULL");
PADDLE_ENFORCE(ctx->HasOutput("Out"),
" Input(X) of Reduce op input should not be NULL");
ctx->SetOutputsDim("Out", x_dims);
ctx->ShareLoD("X", /*->*/ "Out");
}
};
......@@ -137,7 +134,7 @@ class NCCLAllReduceOpMaker : public framework::OpProtoAndCheckerMaker {
// BcastSendOp
class NCCLBcastSendOpMaker : public framework::OpProtoAndCheckerMaker {
public:
NCCLAllBcastSendOpMaker(framework::OpProto *proto,
NCCLBcastSendOpMaker(framework::OpProto *proto,
framework::OpAttrChecker *op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X", "The input of BcastSend op");
......@@ -152,7 +149,7 @@ class NCCLBcastSendOpMaker : public framework::OpProtoAndCheckerMaker {
// BcastOp
class NCCLBcastRecvOpMaker : public framework::OpProtoAndCheckerMaker {
public:
NCCLAllBcastRecvOpMaker(framework::OpProto *proto,
NCCLBcastRecvOpMaker(framework::OpProto *proto,
framework::OpAttrChecker *op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("Communicator", "Communicator for communicating between gpus");
......
......@@ -2,8 +2,8 @@
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
http://www.apache.org/licenseshashernless required by applicable law or agreed
to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
......@@ -27,25 +27,12 @@ class NCCLAllReduceKernel : public framework::OpKernel<T> {
auto ins = ctx.MultiInput<Tensor>("X");
auto outs = ctx.MultiOutput<Tensor>("Out");
std::string reduction = ctx.Attr<std::string>("reduction");
ncclRedOp_t op_type;
if (reduction == "ncclSum") {
op_type = ncclSum;
} else if (reduction == "ncclProd") {
op_type = ncclProd;
} else if (reduction == "ncclMin") {
op_type = ncclMin;
} else if (reduction == "ncclMax") {
op_type = ncclMax;
} else {
PADDLE_ENFORCE(false, "reduction error.");
}
auto* comm = ctx.Input<Communicator>("Communicator");
auto stream = reinterpret_cast<const platform::CUDADeviceContext&>(
ctx.device_context())
.stream();
// device id
int device_id =
boost::get<platform::GPUPlace>(ctx.GetPlace()).GetDeviceId();
......@@ -54,7 +41,7 @@ class NCCLAllReduceKernel : public framework::OpKernel<T> {
for (size_t i = 0; i < ins.size(); ++i) {
PADDLE_ENFORCE(ncclAllReduce(
ins[i]->data<T>(), outs[i]->mutable_data<T>(ctx.GetPlace()),
outs[i]->numel() * sizeof(T), NCCLTypeWrapper<T>::type, op_type,
outs[i]->numel() * sizeof(T), NCCLTypeWrapper<T>::type, ncclSum,
comm->comms_[idx], stream));
PADDLE_ENFORCE(cudaStreamSynchronize(stream));
}
......@@ -68,7 +55,7 @@ class NCCLReduceKernel : public framework::OpKernel<T> {
PADDLE_ENFORCE(platform::is_gpu_place(ctx.GetPlace()),
"This kernel only runs on GPU device.");
auto ins = ctx.MultiInput<Tensor>("X");
auto ins = ctx.MultiInput<Tensor>("X"); // x0, x1, x2
auto outs = ctx.MultiOutput<Tensor>("Out");
auto* comm = ctx.Input<Communicator>("Communicator");
......@@ -81,14 +68,16 @@ class NCCLReduceKernel : public framework::OpKernel<T> {
boost::get<platform::GPUPlace>(ctx.GetPlace()).GetDeviceId();
int idx = comm->GetCommId(device_id);
auto ins_names = ctx.Inputs("X");
std::hash<std::string> hasher;
for (size_t i = 0; i < ins.size(); ++i) {
int root = std::hash() % comm->comms_.size();
int root = hasher(ins_names[i]) % comm->comms_.size();
T* recvbuffer = nullptr;
if (root == device_id) {
recvbuffer = outs[i]->mutable_data<T>(ctx.GetPlace());
}
PADDLE_ENFORCE(ncclReduce(ins[i]->data<T>(), recvbuffer, ins[i]->numel(),
NCCLTypeWrapper<T>::type, root, ncclSum,
NCCLTypeWrapper<T>::type, ncclSum, root,
comm->comms_[idx], stream));
PADDLE_ENFORCE(cudaStreamSynchronize(stream));
}
......@@ -124,7 +113,7 @@ class NCCLBcastKernel : public framework::OpKernel<T> {
} else {
auto outs = ctx.MultiOutput<Tensor>("Out");
for (size_t i = 0; i < outs.size(); ++i) {
PADDLE_ENFORCE(ncclBcast((void*)outs[i]->mutable_data<T>(),
PADDLE_ENFORCE(ncclBcast(outs[i]->mutable_data<T>(ctx.GetPlace()),
outs[i]->numel(), NCCLTypeWrapper<T>::type,
root, comm->comms_[idx], stream));
PADDLE_ENFORCE(cudaStreamSynchronize(stream));
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册