提交 ec47565c 编写于 作者: D Dong Zhihong

"add reduce hash function"

上级 423d7438
6 合并请求!11636[IMPORTANT] MKLDNN layout: Support for sum operator,!8482Release/0.11.0,!8190Release/0.11.0,!8189Release/0.11.0,!6633给线性回归的get-started代码加上了预测的示例~~,!4838Feature/multigpu
...@@ -289,6 +289,15 @@ class ExecutionContext { ...@@ -289,6 +289,15 @@ class ExecutionContext {
return device_context_; return device_context_;
} }
//! Get a input which has multiple variables.
const std::vector<std::string>& Inputs(const std::string& name) const {
return op_.Inputs(name);
}
//! Get an output which has multiple variables.
const std::vector<std::string>& Outputs(const std::string& name) const {
return op_.Outputs(name);
}
#ifdef PADDLE_WITH_CUDA #ifdef PADDLE_WITH_CUDA
const platform::CUDADeviceContext& cuda_device_context() const { const platform::CUDADeviceContext& cuda_device_context() const {
PADDLE_ENFORCE(platform::is_gpu_place(device_context_.GetPlace())); PADDLE_ENFORCE(platform::is_gpu_place(device_context_.GetPlace()));
......
...@@ -81,9 +81,6 @@ class NCCLReduceOp : public framework::OperatorWithKernel { ...@@ -81,9 +81,6 @@ class NCCLReduceOp : public framework::OperatorWithKernel {
" Input(Communicator) of Reduce op input should not be NULL"); " Input(Communicator) of Reduce op input should not be NULL");
PADDLE_ENFORCE(ctx->HasOutput("Out"), PADDLE_ENFORCE(ctx->HasOutput("Out"),
" Input(X) of Reduce op input should not be NULL"); " Input(X) of Reduce op input should not be NULL");
ctx->SetOutputsDim("Out", x_dims);
ctx->ShareLoD("X", /*->*/ "Out");
} }
}; };
...@@ -137,7 +134,7 @@ class NCCLAllReduceOpMaker : public framework::OpProtoAndCheckerMaker { ...@@ -137,7 +134,7 @@ class NCCLAllReduceOpMaker : public framework::OpProtoAndCheckerMaker {
// BcastSendOp // BcastSendOp
class NCCLBcastSendOpMaker : public framework::OpProtoAndCheckerMaker { class NCCLBcastSendOpMaker : public framework::OpProtoAndCheckerMaker {
public: public:
NCCLAllBcastSendOpMaker(framework::OpProto *proto, NCCLBcastSendOpMaker(framework::OpProto *proto,
framework::OpAttrChecker *op_checker) framework::OpAttrChecker *op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) { : OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X", "The input of BcastSend op"); AddInput("X", "The input of BcastSend op");
...@@ -152,7 +149,7 @@ class NCCLBcastSendOpMaker : public framework::OpProtoAndCheckerMaker { ...@@ -152,7 +149,7 @@ class NCCLBcastSendOpMaker : public framework::OpProtoAndCheckerMaker {
// BcastOp // BcastOp
class NCCLBcastRecvOpMaker : public framework::OpProtoAndCheckerMaker { class NCCLBcastRecvOpMaker : public framework::OpProtoAndCheckerMaker {
public: public:
NCCLAllBcastRecvOpMaker(framework::OpProto *proto, NCCLBcastRecvOpMaker(framework::OpProto *proto,
framework::OpAttrChecker *op_checker) framework::OpAttrChecker *op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) { : OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("Communicator", "Communicator for communicating between gpus"); AddInput("Communicator", "Communicator for communicating between gpus");
......
...@@ -2,8 +2,8 @@ ...@@ -2,8 +2,8 @@
Licensed under the Apache License, Version 2.0 (the "License"); Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License. you may not use this file except in compliance with the License.
You may obtain a copy of the License at You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0 http://www.apache.org/licenseshashernless required by applicable law or agreed
Unless required by applicable law or agreed to in writing, software to in writing, software
distributed under the License is distributed on an "AS IS" BASIS, distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and See the License for the specific language governing permissions and
...@@ -27,25 +27,12 @@ class NCCLAllReduceKernel : public framework::OpKernel<T> { ...@@ -27,25 +27,12 @@ class NCCLAllReduceKernel : public framework::OpKernel<T> {
auto ins = ctx.MultiInput<Tensor>("X"); auto ins = ctx.MultiInput<Tensor>("X");
auto outs = ctx.MultiOutput<Tensor>("Out"); auto outs = ctx.MultiOutput<Tensor>("Out");
std::string reduction = ctx.Attr<std::string>("reduction"); std::string reduction = ctx.Attr<std::string>("reduction");
ncclRedOp_t op_type;
if (reduction == "ncclSum") {
op_type = ncclSum;
} else if (reduction == "ncclProd") {
op_type = ncclProd;
} else if (reduction == "ncclMin") {
op_type = ncclMin;
} else if (reduction == "ncclMax") {
op_type = ncclMax;
} else {
PADDLE_ENFORCE(false, "reduction error.");
}
auto* comm = ctx.Input<Communicator>("Communicator"); auto* comm = ctx.Input<Communicator>("Communicator");
auto stream = reinterpret_cast<const platform::CUDADeviceContext&>( auto stream = reinterpret_cast<const platform::CUDADeviceContext&>(
ctx.device_context()) ctx.device_context())
.stream(); .stream();
// device id // device id
int device_id = int device_id =
boost::get<platform::GPUPlace>(ctx.GetPlace()).GetDeviceId(); boost::get<platform::GPUPlace>(ctx.GetPlace()).GetDeviceId();
...@@ -54,7 +41,7 @@ class NCCLAllReduceKernel : public framework::OpKernel<T> { ...@@ -54,7 +41,7 @@ class NCCLAllReduceKernel : public framework::OpKernel<T> {
for (size_t i = 0; i < ins.size(); ++i) { for (size_t i = 0; i < ins.size(); ++i) {
PADDLE_ENFORCE(ncclAllReduce( PADDLE_ENFORCE(ncclAllReduce(
ins[i]->data<T>(), outs[i]->mutable_data<T>(ctx.GetPlace()), ins[i]->data<T>(), outs[i]->mutable_data<T>(ctx.GetPlace()),
outs[i]->numel() * sizeof(T), NCCLTypeWrapper<T>::type, op_type, outs[i]->numel() * sizeof(T), NCCLTypeWrapper<T>::type, ncclSum,
comm->comms_[idx], stream)); comm->comms_[idx], stream));
PADDLE_ENFORCE(cudaStreamSynchronize(stream)); PADDLE_ENFORCE(cudaStreamSynchronize(stream));
} }
...@@ -68,7 +55,7 @@ class NCCLReduceKernel : public framework::OpKernel<T> { ...@@ -68,7 +55,7 @@ class NCCLReduceKernel : public framework::OpKernel<T> {
PADDLE_ENFORCE(platform::is_gpu_place(ctx.GetPlace()), PADDLE_ENFORCE(platform::is_gpu_place(ctx.GetPlace()),
"This kernel only runs on GPU device."); "This kernel only runs on GPU device.");
auto ins = ctx.MultiInput<Tensor>("X"); auto ins = ctx.MultiInput<Tensor>("X"); // x0, x1, x2
auto outs = ctx.MultiOutput<Tensor>("Out"); auto outs = ctx.MultiOutput<Tensor>("Out");
auto* comm = ctx.Input<Communicator>("Communicator"); auto* comm = ctx.Input<Communicator>("Communicator");
...@@ -81,14 +68,16 @@ class NCCLReduceKernel : public framework::OpKernel<T> { ...@@ -81,14 +68,16 @@ class NCCLReduceKernel : public framework::OpKernel<T> {
boost::get<platform::GPUPlace>(ctx.GetPlace()).GetDeviceId(); boost::get<platform::GPUPlace>(ctx.GetPlace()).GetDeviceId();
int idx = comm->GetCommId(device_id); int idx = comm->GetCommId(device_id);
auto ins_names = ctx.Inputs("X");
std::hash<std::string> hasher;
for (size_t i = 0; i < ins.size(); ++i) { for (size_t i = 0; i < ins.size(); ++i) {
int root = std::hash() % comm->comms_.size(); int root = hasher(ins_names[i]) % comm->comms_.size();
T* recvbuffer = nullptr; T* recvbuffer = nullptr;
if (root == device_id) { if (root == device_id) {
recvbuffer = outs[i]->mutable_data<T>(ctx.GetPlace()); recvbuffer = outs[i]->mutable_data<T>(ctx.GetPlace());
} }
PADDLE_ENFORCE(ncclReduce(ins[i]->data<T>(), recvbuffer, ins[i]->numel(), PADDLE_ENFORCE(ncclReduce(ins[i]->data<T>(), recvbuffer, ins[i]->numel(),
NCCLTypeWrapper<T>::type, root, ncclSum, NCCLTypeWrapper<T>::type, ncclSum, root,
comm->comms_[idx], stream)); comm->comms_[idx], stream));
PADDLE_ENFORCE(cudaStreamSynchronize(stream)); PADDLE_ENFORCE(cudaStreamSynchronize(stream));
} }
...@@ -124,7 +113,7 @@ class NCCLBcastKernel : public framework::OpKernel<T> { ...@@ -124,7 +113,7 @@ class NCCLBcastKernel : public framework::OpKernel<T> {
} else { } else {
auto outs = ctx.MultiOutput<Tensor>("Out"); auto outs = ctx.MultiOutput<Tensor>("Out");
for (size_t i = 0; i < outs.size(); ++i) { for (size_t i = 0; i < outs.size(); ++i) {
PADDLE_ENFORCE(ncclBcast((void*)outs[i]->mutable_data<T>(), PADDLE_ENFORCE(ncclBcast(outs[i]->mutable_data<T>(ctx.GetPlace()),
outs[i]->numel(), NCCLTypeWrapper<T>::type, outs[i]->numel(), NCCLTypeWrapper<T>::type,
root, comm->comms_[idx], stream)); root, comm->comms_[idx], stream));
PADDLE_ENFORCE(cudaStreamSynchronize(stream)); PADDLE_ENFORCE(cudaStreamSynchronize(stream));
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册
新手
引导
客服 返回
顶部