Unverified commit a873fa84, authored by Yi Liu, committed by GitHub

supports collective training with programs (#18392)

1. Since the allreduce op has 4 reduce types, we split it into four ops, one per reduce type (see the usage sketch below)
2. We also refined the collective op code, e.g. we separated each collective op kernel into a CPUKernel and a CUDAKernel, and removed the device-specific DeviceContext template parameter since the target DeviceContext is already known
3. We removed the newly added Collective op role to reduce the complexity of program and graph analysis
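For reference, here is a minimal construction-only sketch of how a program now selects the reduce type by op name (c_allreduce_sum, c_allreduce_prod, c_allreduce_max, c_allreduce_min) instead of the old reduce_type attribute. It is modeled on the updated unit test further down in this diff; the variable names are illustrative, and actually running the program still needs an initialized communicator ring (see c_comm_init in the transpiler changes below).

import paddle.fluid as fluid

main_prog = fluid.Program()
startup_prog = fluid.Program()
with fluid.program_guard(main_prog, startup_prog):
    tindata = fluid.layers.data(name='tindata', shape=[10, 1000], dtype='float32')
    # One op per reduce type; the old attrs={'reduce_type': 0} form is removed.
    main_prog.global_block().append_op(
        type='c_allreduce_sum',
        inputs={'X': tindata},
        outputs={'Out': tindata},  # same variable, so in-place allreduce is used
        attrs={'ring_id': 0})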
Parent 85b49d84
...@@ -74,7 +74,6 @@ void OpProtoAndCheckerMaker::operator()(proto::OpProto* proto, ...@@ -74,7 +74,6 @@ void OpProtoAndCheckerMaker::operator()(proto::OpProto* proto,
static_cast<int>(OpRole::kBackward), static_cast<int>(OpRole::kBackward),
static_cast<int>(OpRole::kOptimize) | static_cast<int>(OpRole::kOptimize) |
static_cast<int>(OpRole::kLRSched), static_cast<int>(OpRole::kLRSched),
static_cast<int>(OpRole::kCollective),
static_cast<int>(OpRole::kNotSpecified)}) static_cast<int>(OpRole::kNotSpecified)})
.SetDefault(static_cast<int>(OpRole::kNotSpecified)); .SetDefault(static_cast<int>(OpRole::kNotSpecified));
AddAttr<std::vector<std::string>>(OpRoleVarAttrName(), AddAttr<std::vector<std::string>>(OpRoleVarAttrName(),
......
...@@ -34,9 +34,6 @@ enum class OpRole { ...@@ -34,9 +34,6 @@ enum class OpRole {
kDist = 0x0008, kDist = 0x0008,
// Tag all learning rate scheduler operators. // Tag all learning rate scheduler operators.
kLRSched = 0x0010, kLRSched = 0x0010,
// Collective role is for all collective operators and other operators used
// for collective training
kCollective = 0x0020,
kLoss = 0x0100, kLoss = 0x0100,
// The default value of op's role. This should be only used for unittests and // The default value of op's role. This should be only used for unittests and
......
...@@ -13,9 +13,8 @@ See the License for the specific language governing permissions and ...@@ -13,9 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#include "paddle/fluid/operators/collective/c_allgather_op.h" #include "paddle/fluid/operators/collective/c_allgather_op.h"
#include <future> // NOLINT
#include <memory> #include <memory>
#include <ostream>
namespace paddle { namespace paddle {
namespace operators { namespace operators {
...@@ -25,8 +24,7 @@ class CAllGatherOp : public framework::OperatorWithKernel { ...@@ -25,8 +24,7 @@ class CAllGatherOp : public framework::OperatorWithKernel {
using framework::OperatorWithKernel::OperatorWithKernel; using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext *ctx) const override { void InferShape(framework::InferShapeContext *ctx) const override {
PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) should not be null"); PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) should not be null");
PADDLE_ENFORCE(ctx->HasOutput("Out"), PADDLE_ENFORCE(ctx->HasOutput("Out"), "Output(Out) should not be null.");
"Output(Out) of SyncFCGather op should not be null.");
int nranks = ctx->Attrs().Get<int>("nranks"); int nranks = ctx->Attrs().Get<int>("nranks");
PADDLE_ENFORCE_GE(nranks, 2, "nranks should be >=2"); PADDLE_ENFORCE_GE(nranks, 2, "nranks should be >=2");
framework::DDim dim = ctx->GetInputDim("X"); framework::DDim dim = ctx->GetInputDim("X");
...@@ -49,10 +47,10 @@ class CAllGatherOpMaker : public framework::OpProtoAndCheckerMaker { ...@@ -49,10 +47,10 @@ class CAllGatherOpMaker : public framework::OpProtoAndCheckerMaker {
AddAttr<int>("nranks", AddAttr<int>("nranks",
"Total trainer count of the distributed training job"); "Total trainer count of the distributed training job");
AddComment(R"DOC( AddComment(R"DOC(
***CAllGather Operator*** CAllGather Operator
each rank receives the aggregation of data from all ranks in the order of the ranks each rank receives the aggregation of data from all ranks in the order of the ranks
Call NCCL collective AllGather internally.https://docs.nvidia.com/deeplearning/sdk/nccl-developer-guide/docs/api/colls.html#c.ncclAllGather reference: https://docs.nvidia.com/deeplearning/sdk/nccl-developer-guide/docs/usage/operations.html#allgather
)DOC"); )DOC");
} }
}; };
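To make the shape contract above concrete: Out stacks the X tensors from all ranks along dimension 0, so Out keeps nranks times as many rows as X (see the InferShape check above and the CUDA kernel below). A hedged, construction-only sketch with illustrative names and an assumed 2-rank ring:

import paddle.fluid as fluid

nranks = 2
main_prog = fluid.Program()
startup_prog = fluid.Program()
with fluid.program_guard(main_prog, startup_prog):
    block = main_prog.global_block()
    tindata = fluid.layers.data(name='tindata', shape=[10, 1000], dtype='float32')
    toutdata = block.create_var(name='toutdata', dtype='float32', persistable=False)
    # Out gathers the contribution of every rank in rank order along dim 0.
    block.append_op(
        type='c_allgather',
        inputs={'X': tindata},
        outputs={'Out': toutdata},
        attrs={'ring_id': 0, 'nranks': nranks})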
...@@ -81,9 +79,8 @@ namespace plat = paddle::platform; ...@@ -81,9 +79,8 @@ namespace plat = paddle::platform;
REGISTER_OPERATOR(c_allgather, ops::CAllGatherOp, ops::CAllGatherOpGradMaker, REGISTER_OPERATOR(c_allgather, ops::CAllGatherOp, ops::CAllGatherOpGradMaker,
ops::CAllGatherOpMaker); ops::CAllGatherOpMaker);
REGISTER_OP_CPU_KERNEL( REGISTER_OP_CPU_KERNEL(c_allgather, ops::CAllGatherOpCPUKernel<float>,
c_allgather, ops::CAllGatherOpKernel<plat::CPUDeviceContext, float>, ops::CAllGatherOpCPUKernel<double>,
ops::CAllGatherOpKernel<plat::CPUDeviceContext, double>, ops::CAllGatherOpCPUKernel<int>,
ops::CAllGatherOpKernel<plat::CPUDeviceContext, int>, ops::CAllGatherOpCPUKernel<int64_t>,
ops::CAllGatherOpKernel<plat::CPUDeviceContext, int64_t>, ops::CAllGatherOpCPUKernel<plat::float16>);
ops::CAllGatherOpKernel<plat::CPUDeviceContext, plat::float16>);
...@@ -14,12 +14,64 @@ limitations under the License. */ ...@@ -14,12 +14,64 @@ limitations under the License. */
#include "paddle/fluid/operators/collective/c_allgather_op.h" #include "paddle/fluid/operators/collective/c_allgather_op.h"
#include <memory>
#if defined(PADDLE_WITH_CUDA) && !defined(_WIN32)
#include "paddle/fluid/platform/collective_helper.h"
#include "paddle/fluid/platform/nccl_helper.h"
#endif
namespace paddle {
namespace operators {
template <typename T>
class CAllGatherOpCUDAKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
#if defined(PADDLE_WITH_CUDA) && !defined(_WIN32)
auto in = ctx.Input<framework::Tensor>("X");
auto out = ctx.Output<framework::Tensor>("Out");
ncclDataType_t dtype = platform::ToNCCLDataType(in->type());
int nranks = ctx.Attr<int>("nranks");
int rid = ctx.Attr<int>("ring_id");
auto comm = platform::NCCLCommContext::Instance().Get(rid);
PADDLE_ENFORCE_EQ(nranks, comm->nranks());
auto place = ctx.GetPlace();
framework::DDim out_dims = in->dims();
out_dims[0] *= nranks;
out->mutable_data<T>(out_dims, place);
int64_t send_numel = in->numel();
const T* send_buff = in->data<T>();
T* recv_buff = out->data<T>();
cudaStream_t stream = nullptr;
if (ctx.Attr<bool>("use_calc_stream")) {
auto dev_ctx = platform::DeviceContextPool::Instance().Get(place);
stream = static_cast<platform::CUDADeviceContext*>(dev_ctx)->stream();
} else {
stream = comm->stream();
}
PADDLE_ENFORCE(platform::dynload::ncclAllGather(
send_buff, recv_buff, send_numel, static_cast<ncclDataType_t>(dtype),
comm->comm(), stream));
#else
PADDLE_THROW("PaddlePaddle should compile with GPU.");
#endif
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators; namespace ops = paddle::operators;
namespace plat = paddle::platform; namespace plat = paddle::platform;
REGISTER_OP_CUDA_KERNEL( REGISTER_OP_CUDA_KERNEL(c_allgather, ops::CAllGatherOpCUDAKernel<float>,
c_allgather, ops::CAllGatherOpKernel<plat::CUDADeviceContext, float>, ops::CAllGatherOpCUDAKernel<double>,
ops::CAllGatherOpKernel<plat::CUDADeviceContext, double>, ops::CAllGatherOpCUDAKernel<int>,
ops::CAllGatherOpKernel<plat::CUDADeviceContext, int>, ops::CAllGatherOpCUDAKernel<int64_t>,
ops::CAllGatherOpKernel<plat::CUDADeviceContext, int64_t>, ops::CAllGatherOpCUDAKernel<plat::float16>);
ops::CAllGatherOpKernel<plat::CUDADeviceContext, plat::float16>);
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. /* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License"); Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License. you may not use this file except in compliance with the License.
...@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and ...@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#pragma once #pragma once
#include <algorithm> #include <algorithm>
#include <utility> #include <utility>
#include <vector> #include <vector>
...@@ -22,52 +23,14 @@ limitations under the License. */ ...@@ -22,52 +23,14 @@ limitations under the License. */
#include "paddle/fluid/framework/lod_tensor.h" #include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/op_registry.h"
#if defined(PADDLE_WITH_CUDA) && !defined(_WIN32)
#include "paddle/fluid/platform/collective_helper.h"
#include "paddle/fluid/platform/nccl_helper.h"
#endif
namespace paddle { namespace paddle {
namespace operators { namespace operators {
template <typename DeviceContext, typename T> template <typename T>
class CAllGatherOpKernel : public framework::OpKernel<T> { class CAllGatherOpCPUKernel : public framework::OpKernel<T> {
public: public:
void Compute(const framework::ExecutionContext& ctx) const override { void Compute(const framework::ExecutionContext& ctx) const override {
auto place = ctx.GetPlace(); PADDLE_THROW("unimplemented cpu kernel for CAllGatherOp.");
PADDLE_ENFORCE(is_gpu_place(place),
"CAllGatherOp can run on gpu place only for now.");
#if defined(PADDLE_WITH_CUDA) && !defined(_WIN32)
auto in = ctx.Input<framework::Tensor>("X");
auto out = ctx.Output<framework::Tensor>("Out");
ncclDataType_t dtype = platform::ToNCCLDataType(in->type());
int rid = ctx.Attr<int>("ring_id");
auto comm = platform::NCCLCommContext::Instance().Get(rid);
int nranks = comm->nranks();
framework::DDim out_dims = in->dims();
out_dims[0] *= nranks;
out->mutable_data<T>(out_dims, place);
int64_t send_numel = in->numel();
const T* send_buff = in->data<T>();
T* recv_buff = out->data<T>();
cudaStream_t stream = nullptr;
if (ctx.Attr<bool>("use_calc_stream")) {
auto dev_ctx = platform::DeviceContextPool::Instance().Get(place);
stream = static_cast<platform::CUDADeviceContext*>(dev_ctx)->stream();
} else {
stream = comm->stream();
}
PADDLE_ENFORCE(platform::dynload::ncclAllGather(
send_buff, recv_buff, send_numel, static_cast<ncclDataType_t>(dtype),
comm->comm(), stream));
#else
PADDLE_THROW("PaddlePaddle should compile with GPU.");
#endif
} }
}; };
......
/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/collective/c_allreduce_op.h"
namespace paddle {
namespace operators {
class CAllReduceMaxOpMaker : public CAllReduceOpMaker {
protected:
std::string GetName() const override { return "Max"; }
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
namespace plat = paddle::platform;
REGISTER_OP_WITHOUT_GRADIENT(c_allreduce_max, ops::CAllReduceOp,
ops::CAllReduceMaxOpMaker);
REGISTER_OP_CPU_KERNEL(c_allreduce_max,
ops::CAllReduceOpCPUKernel<ops::kRedMax, float>,
ops::CAllReduceOpCPUKernel<ops::kRedMax, double>,
ops::CAllReduceOpCPUKernel<ops::kRedMax, int>,
ops::CAllReduceOpCPUKernel<ops::kRedMax, int64_t>,
ops::CAllReduceOpCPUKernel<ops::kRedMax, plat::float16>);
...@@ -18,8 +18,8 @@ namespace ops = paddle::operators; ...@@ -18,8 +18,8 @@ namespace ops = paddle::operators;
namespace plat = paddle::platform; namespace plat = paddle::platform;
REGISTER_OP_CUDA_KERNEL( REGISTER_OP_CUDA_KERNEL(
c_allreduce, ops::CAllReduceOpKernel<plat::CUDADeviceContext, float>, c_allreduce_max, ops::CAllReduceOpCUDAKernel<ops::kRedMax, float>,
ops::CAllReduceOpKernel<plat::CUDADeviceContext, double>, ops::CAllReduceOpCUDAKernel<ops::kRedMax, double>,
ops::CAllReduceOpKernel<plat::CUDADeviceContext, int>, ops::CAllReduceOpCUDAKernel<ops::kRedMax, int>,
ops::CAllReduceOpKernel<plat::CUDADeviceContext, int64_t>, ops::CAllReduceOpCUDAKernel<ops::kRedMax, int64_t>,
ops::CAllReduceOpKernel<plat::CUDADeviceContext, plat::float16>); ops::CAllReduceOpCUDAKernel<ops::kRedMax, plat::float16>)
/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/collective/c_allreduce_op.h"
namespace paddle {
namespace operators {
class CAllReduceMinOpMaker : public CAllReduceOpMaker {
protected:
std::string GetName() const override { return "Min"; }
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
namespace plat = paddle::platform;
REGISTER_OP_WITHOUT_GRADIENT(c_allreduce_min, ops::CAllReduceOp,
ops::CAllReduceMinOpMaker);
REGISTER_OP_CPU_KERNEL(c_allreduce_min,
ops::CAllReduceOpCPUKernel<ops::kRedMin, float>,
ops::CAllReduceOpCPUKernel<ops::kRedMin, double>,
ops::CAllReduceOpCPUKernel<ops::kRedMin, int>,
ops::CAllReduceOpCPUKernel<ops::kRedMin, int64_t>,
ops::CAllReduceOpCPUKernel<ops::kRedMin, plat::float16>);
/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/collective/c_allreduce_op.h"
namespace ops = paddle::operators;
namespace plat = paddle::platform;
REGISTER_OP_CUDA_KERNEL(
c_allreduce_min, ops::CAllReduceOpCUDAKernel<ops::kRedMin, float>,
ops::CAllReduceOpCUDAKernel<ops::kRedMin, double>,
ops::CAllReduceOpCUDAKernel<ops::kRedMin, int>,
ops::CAllReduceOpCUDAKernel<ops::kRedMin, int64_t>,
ops::CAllReduceOpCUDAKernel<ops::kRedMin, plat::float16>)
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. /* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License"); Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License. you may not use this file except in compliance with the License.
...@@ -13,9 +13,8 @@ See the License for the specific language governing permissions and ...@@ -13,9 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#pragma once #pragma once
#include <algorithm>
#include <utility> #include <string>
#include <vector>
#include "paddle/fluid/framework/data_type.h" #include "paddle/fluid/framework/data_type.h"
#include "paddle/fluid/framework/lod_tensor.h" #include "paddle/fluid/framework/lod_tensor.h"
...@@ -29,17 +28,41 @@ limitations under the License. */ ...@@ -29,17 +28,41 @@ limitations under the License. */
namespace paddle { namespace paddle {
namespace operators { namespace operators {
template <typename DeviceContext, typename T> enum ReduceType { kRedSum, kRedMax, kRedMin, kRedProd };
class CAllReduceOpKernel : public framework::OpKernel<T> {
class CAllReduceOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override {
ctx->SetOutputDim("Out", ctx->GetInputDim("X"));
}
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType(ctx.Input<framework::Tensor>("X")->type(),
ctx.GetPlace());
}
};
template <ReduceType red_type, typename T>
class CAllReduceOpCPUKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
PADDLE_THROW("CAllReduce op do not support CPUKernel for now.");
}
};
template <ReduceType red_type, typename T>
class CAllReduceOpCUDAKernel : public framework::OpKernel<T> {
public: public:
void Compute(const framework::ExecutionContext& ctx) const override { void Compute(const framework::ExecutionContext& ctx) const override {
auto place = ctx.GetPlace();
PADDLE_ENFORCE(is_gpu_place(place),
"CAllReduce op can run on gpu place only for now.");
#if defined(PADDLE_WITH_CUDA) && !defined(_WIN32) #if defined(PADDLE_WITH_CUDA) && !defined(_WIN32)
auto in = ctx.Input<framework::Tensor>("X"); auto in = ctx.Input<framework::Tensor>("X");
auto out = ctx.Output<framework::Tensor>("Out"); auto out = ctx.Output<framework::Tensor>("Out");
auto place = ctx.GetPlace();
ncclDataType_t dtype = platform::ToNCCLDataType(in->type()); ncclDataType_t dtype = platform::ToNCCLDataType(in->type());
int64_t numel = in->numel(); int64_t numel = in->numel();
const void* sendbuff = in->data<void>(); const void* sendbuff = in->data<void>();
...@@ -49,23 +72,6 @@ class CAllReduceOpKernel : public framework::OpKernel<T> { ...@@ -49,23 +72,6 @@ class CAllReduceOpKernel : public framework::OpKernel<T> {
int rid = ctx.Attr<int>("ring_id"); int rid = ctx.Attr<int>("ring_id");
auto comm = platform::NCCLCommContext::Instance().Get(rid); auto comm = platform::NCCLCommContext::Instance().Get(rid);
int reduce_type = ctx.Attr<int>("reduce_type");
ncclRedOp_t red_type = ncclSum;
switch (reduce_type) {
case 0:
red_type = ncclSum;
break;
case 1:
red_type = ncclProd;
break;
case 2:
red_type = ncclMax;
break;
case 3:
red_type = ncclMin;
break;
}
cudaStream_t stream = nullptr; cudaStream_t stream = nullptr;
if (ctx.Attr<bool>("use_calc_stream")) { if (ctx.Attr<bool>("use_calc_stream")) {
auto dev_ctx = platform::DeviceContextPool::Instance().Get(place); auto dev_ctx = platform::DeviceContextPool::Instance().Get(place);
...@@ -74,13 +80,60 @@ class CAllReduceOpKernel : public framework::OpKernel<T> { ...@@ -74,13 +80,60 @@ class CAllReduceOpKernel : public framework::OpKernel<T> {
stream = comm->stream(); stream = comm->stream();
} }
ncclRedOp_t nccl_red_type = ncclSum;
switch (red_type) {
case kRedSum:
nccl_red_type = ncclSum;
break;
case kRedMax:
nccl_red_type = ncclMax;
break;
case kRedMin:
nccl_red_type = ncclMin;
break;
case kRedProd:
nccl_red_type = ncclProd;
break;
default:
PADDLE_THROW("Invalid reduce type: %d", red_type);
}
PADDLE_ENFORCE(platform::dynload::ncclAllReduce( PADDLE_ENFORCE(platform::dynload::ncclAllReduce(
sendbuff, recvbuff, numel, dtype, red_type, comm->comm(), stream)); sendbuff, recvbuff, numel, dtype, nccl_red_type, comm->comm(), stream));
#else #else
PADDLE_THROW("PaddlePaddle should compile with GPU."); PADDLE_THROW("PaddlePaddle should compile with GPU.");
#endif #endif
} }
}; };
class CAllReduceOpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() {
AddInput("X", "(Tensor), tensor to be allreduced.");
AddOutput("Out", "(Tensor) the allreduced result.");
AddAttr<int>("ring_id", "(int default 0) communication ring id.")
.SetDefault(0);
AddAttr<bool>(
"use_calc_stream",
"(bool default false) eject CUDA operations to calculation stream.")
.SetDefault(false);
AddComment(string::Sprintf(R"DOC(
CAllReduce %s Operator
Call collective AllReduce with reduce type %s. If input and output are
the same variable, in-place allreduce will be used.
Reference: https://docs.nvidia.com/deeplearning/sdk/nccl-developer-guide/docs/usage/operations.html#allreduce
)DOC",
GetName(), GetName()));
}
protected:
virtual std::string GetName() const = 0;
};
} // namespace operators } // namespace operators
} // namespace paddle } // namespace paddle
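A hedged sketch of the two attributes every c_allreduce_* maker above exposes. With use_calc_stream=True the NCCL call is issued on the calculation stream, so later ops on that stream are ordered after it without an explicit c_sync_comm_stream; the names below are illustrative:

import paddle.fluid as fluid

main_prog = fluid.Program()
with fluid.program_guard(main_prog, fluid.Program()):
    data = fluid.layers.data(name='data', shape=[32], dtype='float32')
    main_prog.global_block().append_op(
        type='c_allreduce_max',           # or c_allreduce_sum / _min / _prod
        inputs={'X': data},
        outputs={'Out': data},            # in-place allreduce
        attrs={'ring_id': 0,              # which communication ring to use
               'use_calc_stream': True})  # run on the calculation stream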
/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/collective/c_allreduce_op.h"
namespace paddle {
namespace operators {
class CAllReduceProdOpMaker : public CAllReduceOpMaker {
protected:
std::string GetName() const override { return "Prod"; }
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
namespace plat = paddle::platform;
REGISTER_OP_WITHOUT_GRADIENT(c_allreduce_prod, ops::CAllReduceOp,
ops::CAllReduceProdOpMaker);
REGISTER_OP_CPU_KERNEL(c_allreduce_prod,
ops::CAllReduceOpCPUKernel<ops::kRedProd, float>,
ops::CAllReduceOpCPUKernel<ops::kRedProd, double>,
ops::CAllReduceOpCPUKernel<ops::kRedProd, int>,
ops::CAllReduceOpCPUKernel<ops::kRedProd, int64_t>,
ops::CAllReduceOpCPUKernel<ops::kRedProd, plat::float16>)
/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/collective/c_allreduce_op.h"
namespace ops = paddle::operators;
namespace plat = paddle::platform;
REGISTER_OP_CUDA_KERNEL(
c_allreduce_prod, ops::CAllReduceOpCUDAKernel<ops::kRedProd, float>,
ops::CAllReduceOpCUDAKernel<ops::kRedProd, double>,
ops::CAllReduceOpCUDAKernel<ops::kRedProd, int>,
ops::CAllReduceOpCUDAKernel<ops::kRedProd, int64_t>,
ops::CAllReduceOpCUDAKernel<ops::kRedProd, plat::float16>)
...@@ -12,58 +12,29 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. ...@@ -12,58 +12,29 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#include <future> // NOLINT
#include <ostream>
#include "paddle/fluid/operators/collective/c_allreduce_op.h" #include "paddle/fluid/operators/collective/c_allreduce_op.h"
namespace paddle { namespace paddle {
namespace operators { namespace operators {
class CAllReduceOp : public framework::OperatorWithKernel { class CAllReduceSumOpGradMaker : public framework::SingleGradOpDescMaker {
public: public:
using framework::OperatorWithKernel::OperatorWithKernel; using framework::SingleGradOpDescMaker::SingleGradOpDescMaker;
void InferShape(framework::InferShapeContext* ctx) const override {
ctx->SetOutputDim("Out", ctx->GetInputDim("X"));
}
protected: protected:
framework::OpKernelType GetExpectedKernelType( std::unique_ptr<framework::OpDesc> Apply() const override {
const framework::ExecutionContext& ctx) const override { std::unique_ptr<framework::OpDesc> retv(new framework::OpDesc());
return framework::OpKernelType(ctx.Input<framework::Tensor>("X")->type(), retv->SetType("c_allreduce_sum");
ctx.GetPlace()); retv->SetInput("X", OutputGrad("Out"));
retv->SetOutput("Out", InputGrad("X"));
retv->SetAttrMap(Attrs());
return retv;
} }
}; };
class CAllReduceOpMaker : public framework::OpProtoAndCheckerMaker { class CAllReduceSumOpMaker : public CAllReduceOpMaker {
public: protected:
void Make() { std::string GetName() const override { return "Sum"; }
AddInput("X", "(Tensor), tensor to be allreduced.");
AddOutput("Out", "(Tensor) the allreduced result.");
AddAttr<int>("reduce_type", "(int default 0) determin the reduce type.")
.SetDefault(0);
AddAttr<int>("ring_id", "(int default 0) communication ring id.")
.SetDefault(0);
AddAttr<bool>(
"use_calc_stream",
"(bool default false) eject CUDA operations to calculation stream.")
.SetDefault(false);
AddComment(R"DOC(
***CAllReduce Operator***
Call NCCL collective AllReduce internally. Note that this op must be used when one
thread is managing one GPU device.
For speed reasons, reduce_type should be an integer:
0: sum
1: prod
2: max
3: min
If input and output are the same variable, in-place allreduce will be used.
)DOC");
}
}; };
} // namespace operators } // namespace operators
...@@ -72,12 +43,12 @@ If input and output are the same variable, in-place allreduce will be used. ...@@ -72,12 +43,12 @@ If input and output are the same variable, in-place allreduce will be used.
namespace ops = paddle::operators; namespace ops = paddle::operators;
namespace plat = paddle::platform; namespace plat = paddle::platform;
REGISTER_OP_WITHOUT_GRADIENT(c_allreduce, ops::CAllReduceOp, REGISTER_OPERATOR(c_allreduce_sum, ops::CAllReduceOp,
ops::CAllReduceOpMaker); ops::CAllReduceSumOpGradMaker, ops::CAllReduceSumOpMaker);
REGISTER_OP_CPU_KERNEL( REGISTER_OP_CPU_KERNEL(c_allreduce_sum,
c_allreduce, ops::CAllReduceOpKernel<plat::CPUDeviceContext, float>, ops::CAllReduceOpCPUKernel<ops::kRedSum, float>,
ops::CAllReduceOpKernel<plat::CPUDeviceContext, double>, ops::CAllReduceOpCPUKernel<ops::kRedSum, double>,
ops::CAllReduceOpKernel<plat::CPUDeviceContext, int>, ops::CAllReduceOpCPUKernel<ops::kRedSum, int>,
ops::CAllReduceOpKernel<plat::CPUDeviceContext, int64_t>, ops::CAllReduceOpCPUKernel<ops::kRedSum, int64_t>,
ops::CAllReduceOpKernel<plat::CPUDeviceContext, plat::float16>); ops::CAllReduceOpCPUKernel<ops::kRedSum, plat::float16>)
/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/collective/c_allreduce_op.h"
namespace ops = paddle::operators;
namespace plat = paddle::platform;
REGISTER_OP_CUDA_KERNEL(
c_allreduce_sum, ops::CAllReduceOpCUDAKernel<ops::kRedSum, float>,
ops::CAllReduceOpCUDAKernel<ops::kRedSum, double>,
ops::CAllReduceOpCUDAKernel<ops::kRedSum, int>,
ops::CAllReduceOpCUDAKernel<ops::kRedSum, int64_t>,
ops::CAllReduceOpCUDAKernel<ops::kRedSum, plat::float16>)
...@@ -12,9 +12,6 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. ...@@ -12,9 +12,6 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#include <future> // NOLINT
#include <ostream>
#include "paddle/fluid/operators/collective/c_broadcast_op.h" #include "paddle/fluid/operators/collective/c_broadcast_op.h"
namespace paddle { namespace paddle {
...@@ -50,9 +47,9 @@ class CBroadcastOpMaker : public framework::OpProtoAndCheckerMaker { ...@@ -50,9 +47,9 @@ class CBroadcastOpMaker : public framework::OpProtoAndCheckerMaker {
"(bool default false) eject CUDA operations to calculation stream.") "(bool default false) eject CUDA operations to calculation stream.")
.SetDefault(false); .SetDefault(false);
AddComment(R"DOC( AddComment(R"DOC(
***CBroadcast Operator*** CBroadcast Operator
Call ncclBcast internally. Reference: https://docs.nvidia.com/deeplearning/sdk/nccl-developer-guide/docs/usage/operations.html#broadcast
)DOC"); )DOC");
} }
}; };
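A hedged sketch of broadcasting a parameter in place from rank 0, mirroring how _broadcast_params in the transpiler changes below uses this op (the parameter name is illustrative):

import paddle.fluid as fluid

main_prog = fluid.Program()
startup_prog = fluid.Program()
with fluid.program_guard(main_prog, startup_prog):
    param = fluid.layers.create_parameter(shape=[1024], dtype='float32', name='fc_w')
    main_prog.global_block().append_op(
        type='c_broadcast',
        inputs={'X': param},
        outputs={'Out': param},  # same variable: broadcast in place
        attrs={'ring_id': 0,
               'root': 0})       # rank 0 sends, the other ranks receive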
...@@ -66,9 +63,8 @@ namespace plat = paddle::platform; ...@@ -66,9 +63,8 @@ namespace plat = paddle::platform;
REGISTER_OP_WITHOUT_GRADIENT(c_broadcast, ops::CBroadcastOp, REGISTER_OP_WITHOUT_GRADIENT(c_broadcast, ops::CBroadcastOp,
ops::CBroadcastOpMaker); ops::CBroadcastOpMaker);
REGISTER_OP_CPU_KERNEL( REGISTER_OP_CPU_KERNEL(c_broadcast, ops::CBroadcastOpCPUKernel<float>,
c_broadcast, ops::CBroadcastOpKernel<plat::CPUDeviceContext, float>, ops::CBroadcastOpCPUKernel<double>,
ops::CBroadcastOpKernel<plat::CPUDeviceContext, double>, ops::CBroadcastOpCPUKernel<int>,
ops::CBroadcastOpKernel<plat::CPUDeviceContext, int>, ops::CBroadcastOpCPUKernel<int64_t>,
ops::CBroadcastOpKernel<plat::CPUDeviceContext, int64_t>, ops::CBroadcastOpCPUKernel<plat::float16>);
ops::CBroadcastOpKernel<plat::CPUDeviceContext, plat::float16>);
...@@ -14,12 +14,74 @@ limitations under the License. */ ...@@ -14,12 +14,74 @@ limitations under the License. */
#include "paddle/fluid/operators/collective/c_broadcast_op.h" #include "paddle/fluid/operators/collective/c_broadcast_op.h"
#if defined(PADDLE_WITH_CUDA) && !defined(_WIN32)
#include "paddle/fluid/platform/collective_helper.h"
#include "paddle/fluid/platform/nccl_helper.h"
#endif
namespace paddle {
namespace operators {
template <typename T>
class CBroadcastOpCUDAKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
#if defined(PADDLE_WITH_CUDA) && !defined(_WIN32)
auto x = ctx.Input<framework::LoDTensor>("X");
auto out = ctx.Output<framework::LoDTensor>("Out");
int numel = x->numel();
ncclDataType_t dtype = platform::ToNCCLDataType(x->type());
int rid = ctx.Attr<int>("ring_id");
auto comm = platform::NCCLCommContext::Instance().Get(rid);
auto place = ctx.GetPlace();
cudaStream_t stream = nullptr;
if (ctx.Attr<bool>("use_calc_stream")) {
auto dev_ctx = platform::DeviceContextPool::Instance().Get(place);
stream = static_cast<platform::CUDADeviceContext*>(dev_ctx)->stream();
} else {
stream = comm->stream();
}
int root = ctx.Attr<int>("root");
if (root == comm->rank()) {
PADDLE_ENFORCE(platform::dynload::ncclBcast(
reinterpret_cast<void*>(const_cast<T*>(x->data<T>())), numel, dtype,
root, comm->comm(), stream));
VLOG(3) << "rank " << comm->rank() << " invoke Bcast. sent "
<< x->numel();
if (out != x) {
framework::TensorCopy(
*static_cast<const framework::Tensor*>(x), place,
*platform::DeviceContextPool::Instance().Get(place),
static_cast<framework::Tensor*>(out));
}
} else {
PADDLE_ENFORCE(platform::dynload::ncclBcast(out->mutable_data<T>(place),
numel, dtype, root,
comm->comm(), stream));
VLOG(3) << "rank " << comm->rank() << " invoke Bcast. recieved "
<< framework::product(out->dims());
}
out->Resize(x->dims());
out->set_lod(x->lod());
#else
PADDLE_THROW("PaddlePaddle should compile with GPU.");
#endif
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators; namespace ops = paddle::operators;
namespace plat = paddle::platform; namespace plat = paddle::platform;
REGISTER_OP_CUDA_KERNEL( REGISTER_OP_CUDA_KERNEL(c_broadcast, ops::CBroadcastOpCUDAKernel<float>,
c_broadcast, ops::CBroadcastOpKernel<plat::CUDADeviceContext, float>, ops::CBroadcastOpCUDAKernel<double>,
ops::CBroadcastOpKernel<plat::CUDADeviceContext, double>, ops::CBroadcastOpCUDAKernel<int>,
ops::CBroadcastOpKernel<plat::CUDADeviceContext, int>, ops::CBroadcastOpCUDAKernel<int64_t>,
ops::CBroadcastOpKernel<plat::CUDADeviceContext, int64_t>, ops::CBroadcastOpCUDAKernel<plat::float16>);
ops::CBroadcastOpKernel<plat::CUDADeviceContext, plat::float16>);
...@@ -22,69 +22,14 @@ limitations under the License. */ ...@@ -22,69 +22,14 @@ limitations under the License. */
#include "paddle/fluid/framework/lod_tensor.h" #include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/op_registry.h"
#if defined(PADDLE_WITH_CUDA) && !defined(_WIN32)
#include "paddle/fluid/platform/collective_helper.h"
#include "paddle/fluid/platform/nccl_helper.h"
#endif
namespace paddle { namespace paddle {
namespace operators { namespace operators {
template <typename DeviceContext, typename T> template <typename T>
class CBroadcastOpKernel : public framework::OpKernel<T> { class CBroadcastOpCPUKernel : public framework::OpKernel<T> {
public: public:
void Compute(const framework::ExecutionContext& ctx) const override { void Compute(const framework::ExecutionContext& ctx) const override {
auto place = ctx.GetPlace(); PADDLE_THROW("Unimplemented cpu kernel for CBroadcastOp.");
PADDLE_ENFORCE(is_gpu_place(place),
"CBroadcastOp can run on gpu place only for now.");
#if defined(PADDLE_WITH_CUDA) && !defined(_WIN32)
auto x = ctx.Input<framework::LoDTensor>("X");
auto out = ctx.Output<framework::LoDTensor>("Out");
int numel = x->numel();
ncclDataType_t dtype = platform::ToNCCLDataType(x->type());
int rid = ctx.Attr<int>("ring_id");
auto comm = platform::NCCLCommContext::Instance().Get(rid);
cudaStream_t stream = nullptr;
if (ctx.Attr<bool>("use_calc_stream")) {
auto dev_ctx = platform::DeviceContextPool::Instance().Get(place);
stream = static_cast<platform::CUDADeviceContext*>(dev_ctx)->stream();
} else {
stream = comm->stream();
}
int root = ctx.Attr<int>("root");
int nranks = comm->nranks();
PADDLE_ENFORCE(root >= 0 && root < nranks,
"Expected root in range of [0,%d),but get %d", nranks, root);
if (root == comm->rank()) {
PADDLE_ENFORCE(platform::dynload::ncclBcast(
reinterpret_cast<void*>(const_cast<T*>(x->data<T>())), numel, dtype,
root, comm->comm(), stream));
VLOG(3) << "rank " << comm->rank() << " invoke Bcast. sent "
<< x->numel();
if (out != x) {
// TODO(liuyi05): check inplace
framework::TensorCopy(
*static_cast<const framework::Tensor*>(x), place,
*platform::DeviceContextPool::Instance().Get(place),
static_cast<framework::Tensor*>(out));
}
} else {
PADDLE_ENFORCE(platform::dynload::ncclBcast(out->mutable_data<T>(place),
numel, dtype, root,
comm->comm(), stream));
VLOG(3) << "rank " << comm->rank() << " invoke Bcast. recieved "
<< framework::product(out->dims());
}
out->Resize(x->dims());
out->set_lod(x->lod());
#else
PADDLE_THROW("PaddlePaddle should compile with GPU.");
#endif
} }
}; };
......
...@@ -28,6 +28,7 @@ limitations under the License. */ ...@@ -28,6 +28,7 @@ limitations under the License. */
#include "paddle/fluid/platform/collective_helper.h" #include "paddle/fluid/platform/collective_helper.h"
#include "paddle/fluid/platform/nccl_helper.h" #include "paddle/fluid/platform/nccl_helper.h"
#endif #endif
namespace paddle { namespace paddle {
namespace operators { namespace operators {
......
...@@ -11,9 +11,11 @@ distributed under the License is distributed on an "AS IS" BASIS, ...@@ -11,9 +11,11 @@ distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#if defined(PADDLE_WITH_CUDA) && !defined(_WIN32) #if defined(PADDLE_WITH_CUDA) && !defined(_WIN32)
#include <nccl.h> #include <nccl.h>
#endif #endif
#include <stdint.h> #include <stdint.h>
#include <ostream> #include <ostream>
#include <string> #include <string>
...@@ -24,9 +26,11 @@ limitations under the License. */ ...@@ -24,9 +26,11 @@ limitations under the License. */
#include "paddle/fluid/framework/threadpool.h" #include "paddle/fluid/framework/threadpool.h"
#include "paddle/fluid/operators/distributed/distributed.h" #include "paddle/fluid/operators/distributed/distributed.h"
#include "paddle/fluid/operators/distributed/request_handler_impl.h" #include "paddle/fluid/operators/distributed/request_handler_impl.h"
#if defined(PADDLE_WITH_CUDA) && !defined(_WIN32) #if defined(PADDLE_WITH_CUDA) && !defined(_WIN32)
#include "paddle/fluid/platform/nccl_helper.h" #include "paddle/fluid/platform/nccl_helper.h"
#endif #endif
namespace paddle { namespace paddle {
namespace operators { namespace operators {
......
...@@ -13,9 +13,8 @@ See the License for the specific language governing permissions and ...@@ -13,9 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#include "paddle/fluid/operators/collective/c_reducescatter_op.h" #include "paddle/fluid/operators/collective/c_reducescatter_op.h"
#include <future> // NOLINT
#include <memory> #include <memory>
#include <ostream>
namespace paddle { namespace paddle {
namespace operators { namespace operators {
...@@ -54,9 +53,9 @@ class CReduceScatterOpMaker : public framework::OpProtoAndCheckerMaker { ...@@ -54,9 +53,9 @@ class CReduceScatterOpMaker : public framework::OpProtoAndCheckerMaker {
"(bool default false) eject CUDA operations to calculation stream.") "(bool default false) eject CUDA operations to calculation stream.")
.SetDefault(false); .SetDefault(false);
AddComment(R"DOC( AddComment(R"DOC(
***CReduceScatter Operator*** CReduceScatter Operator
Call NCCL collective ReduceScatter internally. Reference: https://docs.nvidia.com/deeplearning/sdk/nccl-developer-guide/docs/usage/operations.html#reducescatter
)DOC"); )DOC");
} }
}; };
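The shape relation implied by the kernel below, as a small worked example (the numbers are illustrative): the element-wise sum of X across ranks is split evenly, so each rank keeps numel(X) / nranks elements and dimension 0 of X must be divisible by nranks.

nranks = 4
x_shape = [8, 1000]                                # per-rank input X
out_shape = [x_shape[0] // nranks] + x_shape[1:]   # -> [2, 1000] per rank
recv_numel = x_shape[0] * x_shape[1] // nranks     # -> 2000 elements received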
...@@ -85,9 +84,8 @@ namespace plat = paddle::platform; ...@@ -85,9 +84,8 @@ namespace plat = paddle::platform;
REGISTER_OPERATOR(c_reducescatter, ops::CReduceScatterOp, REGISTER_OPERATOR(c_reducescatter, ops::CReduceScatterOp,
ops::CReduceScatterOpMaker); ops::CReduceScatterOpMaker);
REGISTER_OP_CPU_KERNEL( REGISTER_OP_CPU_KERNEL(c_reducescatter, ops::CReduceScatterOpCPUKernel<float>,
c_reducescatter, ops::CReduceScatterOpKernel<plat::CPUDeviceContext, float>, ops::CReduceScatterOpCPUKernel<double>,
ops::CReduceScatterOpKernel<plat::CPUDeviceContext, double>, ops::CReduceScatterOpCPUKernel<int>,
ops::CReduceScatterOpKernel<plat::CPUDeviceContext, int>, ops::CReduceScatterOpCPUKernel<int64_t>,
ops::CReduceScatterOpKernel<plat::CPUDeviceContext, int64_t>, ops::CReduceScatterOpCPUKernel<plat::float16>);
ops::CReduceScatterOpKernel<plat::CPUDeviceContext, plat::float16>);
...@@ -14,13 +14,61 @@ limitations under the License. */ ...@@ -14,13 +14,61 @@ limitations under the License. */
#include "paddle/fluid/operators/collective/c_reducescatter_op.h" #include "paddle/fluid/operators/collective/c_reducescatter_op.h"
#if defined(PADDLE_WITH_CUDA) && !defined(_WIN32)
#include "paddle/fluid/platform/collective_helper.h"
#include "paddle/fluid/platform/nccl_helper.h"
#endif
namespace paddle {
namespace operators {
template <typename T>
class CReduceScatterOpCUDAKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
#if defined(PADDLE_WITH_CUDA) && !defined(_WIN32)
auto in = ctx.Input<framework::Tensor>("X");
auto out = ctx.Output<framework::Tensor>("Out");
int rid = ctx.Attr<int>("ring_id");
auto comm = platform::NCCLCommContext::Instance().Get(rid);
int nranks = comm->nranks();
auto place = ctx.GetPlace();
auto out_dims = in->dims();
out_dims[0] = out_dims[0] / nranks;
out->mutable_data<T>(out_dims, place);
int64_t recv_numel = in->numel() / nranks;
const T* send_buff = in->data<T>();
T* recv_buff = out->data<T>();
int dtype = platform::ToNCCLDataType(in->type());
cudaStream_t stream = nullptr;
if (ctx.Attr<bool>("use_calc_stream")) {
auto dev_ctx = platform::DeviceContextPool::Instance().Get(place);
stream = static_cast<platform::CUDADeviceContext*>(dev_ctx)->stream();
} else {
stream = comm->stream();
}
PADDLE_ENFORCE(platform::dynload::ncclReduceScatter(
send_buff, recv_buff, recv_numel, static_cast<ncclDataType_t>(dtype),
ncclSum, comm->comm(), stream));
#else
PADDLE_THROW("PaddlePaddle should compile with GPU.");
#endif
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators; namespace ops = paddle::operators;
namespace plat = paddle::platform; namespace plat = paddle::platform;
REGISTER_OP_CUDA_KERNEL( REGISTER_OP_CUDA_KERNEL(c_reducescatter, ops::CReduceScatterOpCUDAKernel<float>,
c_reducescatter, ops::CReduceScatterOpCUDAKernel<double>,
ops::CReduceScatterOpKernel<plat::CUDADeviceContext, float>, ops::CReduceScatterOpCUDAKernel<int>,
ops::CReduceScatterOpKernel<plat::CUDADeviceContext, double>, ops::CReduceScatterOpCUDAKernel<int64_t>,
ops::CReduceScatterOpKernel<plat::CUDADeviceContext, int>, ops::CReduceScatterOpCUDAKernel<plat::float16>);
ops::CReduceScatterOpKernel<plat::CUDADeviceContext, int64_t>,
ops::CReduceScatterOpKernel<plat::CUDADeviceContext, plat::float16>);
...@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and ...@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#pragma once #pragma once
#include <algorithm> #include <algorithm>
#include <utility> #include <utility>
#include <vector> #include <vector>
...@@ -22,52 +23,14 @@ limitations under the License. */ ...@@ -22,52 +23,14 @@ limitations under the License. */
#include "paddle/fluid/framework/lod_tensor.h" #include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/op_registry.h"
#if defined(PADDLE_WITH_CUDA) && !defined(_WIN32)
#include "paddle/fluid/platform/collective_helper.h"
#include "paddle/fluid/platform/nccl_helper.h"
#endif
namespace paddle { namespace paddle {
namespace operators { namespace operators {
template <typename DeviceContext, typename T> template <typename T>
class CReduceScatterOpKernel : public framework::OpKernel<T> { class CReduceScatterOpCPUKernel : public framework::OpKernel<T> {
public: public:
void Compute(const framework::ExecutionContext& ctx) const override { void Compute(const framework::ExecutionContext& ctx) const override {
auto place = ctx.GetPlace(); PADDLE_THROW("Unimplemented cpu kernel for CReduceScatterOp.");
PADDLE_ENFORCE(is_gpu_place(place),
"CAllReduce op can run on gpu place only for now.");
#if defined(PADDLE_WITH_CUDA) && !defined(_WIN32)
auto in = ctx.Input<framework::Tensor>("X");
auto out = ctx.Output<framework::Tensor>("Out");
int rid = ctx.Attr<int>("ring_id");
auto comm = platform::NCCLCommContext::Instance().Get(rid);
int nranks = comm->nranks();
auto out_dims = in->dims();
out_dims[0] = out_dims[0] / nranks;
out->mutable_data<T>(out_dims, place);
int64_t recv_numel = in->numel() / nranks;
const T* send_buff = in->data<T>();
T* recv_buff = out->data<T>();
int dtype = platform::ToNCCLDataType(in->type());
cudaStream_t stream = nullptr;
if (ctx.Attr<bool>("use_calc_stream")) {
auto dev_ctx = platform::DeviceContextPool::Instance().Get(place);
stream = static_cast<platform::CUDADeviceContext*>(dev_ctx)->stream();
} else {
stream = comm->stream();
}
PADDLE_ENFORCE(platform::dynload::ncclReduceScatter(
send_buff, recv_buff, recv_numel, static_cast<ncclDataType_t>(dtype),
ncclSum, comm->comm(), stream));
#else
PADDLE_THROW("PaddlePaddle should compile with GPU.");
#endif
} }
}; };
......
...@@ -15,12 +15,12 @@ limitations under the License. */ ...@@ -15,12 +15,12 @@ limitations under the License. */
#if defined(PADDLE_WITH_CUDA) && !defined(_WIN32) #if defined(PADDLE_WITH_CUDA) && !defined(_WIN32)
#include <nccl.h> #include <nccl.h>
#endif #endif
#include <stdint.h>
#include <ostream>
#include <string> #include <string>
#include "paddle/fluid/framework/lod_tensor.h" #include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/op_registry.h"
#if defined(PADDLE_WITH_CUDA) && !defined(_WIN32) #if defined(PADDLE_WITH_CUDA) && !defined(_WIN32)
#include "paddle/fluid/platform/collective_helper.h" #include "paddle/fluid/platform/collective_helper.h"
#endif #endif
...@@ -40,7 +40,6 @@ class CSyncCalcStreamOp : public framework::OperatorBase { ...@@ -40,7 +40,6 @@ class CSyncCalcStreamOp : public framework::OperatorBase {
const platform::Place& place) const override { const platform::Place& place) const override {
PADDLE_ENFORCE(is_gpu_place(place), PADDLE_ENFORCE(is_gpu_place(place),
"Sync stream op can run on gpu place only for now."); "Sync stream op can run on gpu place only for now.");
#if defined(PADDLE_WITH_CUDA) && !defined(_WIN32) #if defined(PADDLE_WITH_CUDA) && !defined(_WIN32)
auto dev_ctx = static_cast<platform::CUDADeviceContext*>( auto dev_ctx = static_cast<platform::CUDADeviceContext*>(
platform::DeviceContextPool::Instance().Get(place)); platform::DeviceContextPool::Instance().Get(place));
...@@ -57,12 +56,12 @@ class CSyncCalcStreamOp : public framework::OperatorBase { ...@@ -57,12 +56,12 @@ class CSyncCalcStreamOp : public framework::OperatorBase {
class CSyncCalcStreamOpMaker : public framework::OpProtoAndCheckerMaker { class CSyncCalcStreamOpMaker : public framework::OpProtoAndCheckerMaker {
public: public:
void Make() { void Make() {
AddInput("X", "(Tensor) Dependency of last param need to sync"); AddInput("X", "(Tensor) Dependency of the variable need to sync");
AddOutput("Out", "(Tensor) Dependency of last param need to sync"); AddOutput("Out", "(Tensor) Dependency of the variable need to sync");
AddComment(R"DOC( AddComment(R"DOC(
***Sync Operator*** CSyncCalcStream Operator
Call cuda stream synchronize. Call calculation stream synchronization.
)DOC"); )DOC");
} }
}; };
......
...@@ -11,11 +11,11 @@ distributed under the License is distributed on an "AS IS" BASIS, ...@@ -11,11 +11,11 @@ distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#if defined(PADDLE_WITH_CUDA) && !defined(_WIN32) #if defined(PADDLE_WITH_CUDA) && !defined(_WIN32)
#include <nccl.h> #include <nccl.h>
#endif #endif
#include <stdint.h>
#include <ostream>
#include <string> #include <string>
#include "paddle/fluid/framework/lod_tensor.h" #include "paddle/fluid/framework/lod_tensor.h"
...@@ -57,13 +57,13 @@ class CSyncCommStreamOp : public framework::OperatorBase { ...@@ -57,13 +57,13 @@ class CSyncCommStreamOp : public framework::OperatorBase {
class CSyncCommStreamOpMaker : public framework::OpProtoAndCheckerMaker { class CSyncCommStreamOpMaker : public framework::OpProtoAndCheckerMaker {
public: public:
void Make() { void Make() {
AddInput("X", "(Tensor) Dependency of last param need to sync"); AddInput("X", "(Tensor) Dependency of the variable need to sync");
AddOutput("Out", "(Tensor) Dependency of last param need to sync"); AddOutput("Out", "(Tensor) Dependency of the variable need to sync");
AddAttr<int>("ring_id", "(int default 0) ring id.").SetDefault(0); AddAttr<int>("ring_id", "(int default 0) ring id.").SetDefault(0);
AddComment(R"DOC( AddComment(R"DOC(
***Sync Operator*** CSyncCommStream Operator
Call nccl stream synchronize. Call communication stream synchronization.
)DOC"); )DOC");
} }
}; };
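Putting the two sync ops together with a collective, following the pattern used in the updated transpiler and unit test further down in this diff (construction-only, illustrative names):

import paddle.fluid as fluid

main_prog = fluid.Program()
with fluid.program_guard(main_prog, fluid.Program()):
    block = main_prog.global_block()
    grad = fluid.layers.data(name='grad', shape=[32], dtype='float32')
    # Make the value produced on the calculation stream visible to the
    # communication stream before the collective starts ...
    block.append_op(type='c_sync_calc_stream',
                    inputs={'X': grad}, outputs={'Out': grad})
    block.append_op(type='c_allreduce_sum',
                    inputs={'X': grad}, outputs={'Out': grad},
                    attrs={'ring_id': 0})
    # ... and wait on the communication stream before the result is consumed.
    block.append_op(type='c_sync_comm_stream',
                    inputs={'X': grad}, outputs={'Out': grad},
                    attrs={'ring_id': 0})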
......
...@@ -12,11 +12,11 @@ ...@@ -12,11 +12,11 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
// #ifndef _WIN32
#if defined(PADDLE_WITH_CUDA) && !defined(_WIN32) #if defined(PADDLE_WITH_CUDA) && !defined(_WIN32)
#include "paddle/fluid/platform/collective_helper.h" #include "paddle/fluid/platform/collective_helper.h"
#include <functional> #include <memory>
#include <utility>
#include "paddle/fluid/platform/dynload/nccl.h" #include "paddle/fluid/platform/dynload/nccl.h"
...@@ -34,24 +34,23 @@ class NCCLCommImpl : public NCCLComm { ...@@ -34,24 +34,23 @@ class NCCLCommImpl : public NCCLComm {
void set_rank(int rank) { rank_ = rank; } void set_rank(int rank) { rank_ = rank; }
int rank() const override { return rank_; } int rank() const override { return rank_; }
void set_local_rank(int local_rank) { local_rank_ = local_rank; } int device_id() const override {
int local_rank() const override { return local_rank_; } return boost::get<CUDAPlace>(dev_ctx_->GetPlace()).device;
}
void set_comm(ncclComm_t comm) { comm_ = comm; }
ncclComm_t comm() const override { return comm_; }
void set_dev_ctx(CUDADeviceContext* dev_ctx) { dev_ctx_ = dev_ctx; } ncclComm_t comm() const override { return dev_ctx_->nccl_comm(); }
CUDADeviceContext* DevCtx() const override { return dev_ctx_; }
cudaStream_t stream() const override { return dev_ctx_->stream(); } cudaStream_t stream() const override { return dev_ctx_->stream(); }
void set_dev_ctx(std::unique_ptr<CUDADeviceContext>&& dev_ctx) {
dev_ctx_ = std::move(dev_ctx);
}
private: private:
int ring_id_; int ring_id_;
int nranks_; int nranks_;
int rank_; int rank_;
int local_rank_; std::unique_ptr<CUDADeviceContext> dev_ctx_;
ncclComm_t comm_;
CUDADeviceContext* dev_ctx_;
}; };
// NOTE: not thread-safe // NOTE: not thread-safe
...@@ -73,13 +72,15 @@ NCCLComm* NCCLCommContext::CreateNCCLComm(ncclUniqueId* nccl_id, int nranks, ...@@ -73,13 +72,15 @@ NCCLComm* NCCLCommContext::CreateNCCLComm(ncclUniqueId* nccl_id, int nranks,
PADDLE_ENFORCE( PADDLE_ENFORCE(
platform::dynload::ncclCommInitRank(&comm, nranks, *nccl_id, rank)); platform::dynload::ncclCommInitRank(&comm, nranks, *nccl_id, rank));
std::unique_ptr<CUDADeviceContext> dev_ctx(
new CUDADeviceContext(CUDAPlace(dev_id)));
dev_ctx->set_nccl_comm(comm);
NCCLCommImpl* communicator = new NCCLCommImpl; NCCLCommImpl* communicator = new NCCLCommImpl;
communicator->set_ring_id(ring_id); communicator->set_ring_id(ring_id);
communicator->set_nranks(nranks); communicator->set_nranks(nranks);
communicator->set_rank(rank); communicator->set_rank(rank);
communicator->set_local_rank(dev_id); communicator->set_dev_ctx(std::move(dev_ctx));
communicator->set_comm(comm);
communicator->set_dev_ctx(dev_ctx_map_.at(dev_id).get());
comm_map_.emplace(ring_id, std::unique_ptr<NCCLComm>(communicator)); comm_map_.emplace(ring_id, std::unique_ptr<NCCLComm>(communicator));
......
...@@ -12,10 +12,9 @@ ...@@ -12,10 +12,9 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
// #ifndef _WIN32
#if defined(PADDLE_WITH_CUDA) && !defined(_WIN32)
#pragma once #pragma once
#if defined(PADDLE_WITH_CUDA) && !defined(_WIN32)
#include <memory> #include <memory>
#include <string> #include <string>
#include <unordered_map> #include <unordered_map>
...@@ -53,10 +52,9 @@ class NCCLComm { ...@@ -53,10 +52,9 @@ class NCCLComm {
virtual int ring_id() const = 0; virtual int ring_id() const = 0;
virtual int nranks() const = 0; virtual int nranks() const = 0;
virtual int rank() const = 0; virtual int rank() const = 0;
virtual int local_rank() const = 0; virtual int device_id() const = 0;
virtual ncclComm_t comm() const = 0; virtual ncclComm_t comm() const = 0;
virtual cudaStream_t stream() const = 0; virtual cudaStream_t stream() const = 0;
virtual CUDADeviceContext* DevCtx() const = 0;
virtual ~NCCLComm() = default; virtual ~NCCLComm() = default;
}; };
...@@ -73,16 +71,6 @@ class NCCLCommContext { ...@@ -73,16 +71,6 @@ class NCCLCommContext {
NCCLComm* CreateNCCLComm(ncclUniqueId* nccl_id, int nranks, int rank, NCCLComm* CreateNCCLComm(ncclUniqueId* nccl_id, int nranks, int rank,
int dev_id, int ring_id = 0); int dev_id, int ring_id = 0);
CUDADeviceContext* DevCtx(int dev_id) const {
PADDLE_ENFORCE(dev_ctx_map_.count(dev_id),
"CUDADeviceContext at device %d has not been initialized");
return dev_ctx_map_.at(dev_id).get();
}
CUDADeviceContext* DevCtx(platform::Place p) const {
return DevCtx(boost::get<CUDAPlace>(p).device);
}
// retrieve a communicator by the ring id // retrieve a communicator by the ring id
NCCLComm* Get(int ring_id) const { NCCLComm* Get(int ring_id) const {
PADDLE_ENFORCE(comm_map_.count(ring_id), PADDLE_ENFORCE(comm_map_.count(ring_id),
......
...@@ -46,7 +46,6 @@ void BindConstValue(pybind11::module* m) { ...@@ -46,7 +46,6 @@ void BindConstValue(pybind11::module* m) {
.value("Loss", framework::OpRole::kLoss) .value("Loss", framework::OpRole::kLoss)
.value("RPC", framework::OpRole::kRPC) .value("RPC", framework::OpRole::kRPC)
.value("Dist", framework::OpRole::kDist) .value("Dist", framework::OpRole::kDist)
.value("Collective", framework::OpRole::kCollective)
.value("LRSched", framework::OpRole::kLRSched); .value("LRSched", framework::OpRole::kLRSched);
op_proto_and_checker_maker.def( op_proto_and_checker_maker.def(
......
...@@ -42,7 +42,6 @@ class TestCollectiveAllreduce(TestCollectiveRunnerBase): ...@@ -42,7 +42,6 @@ class TestCollectiveAllreduce(TestCollectiveRunnerBase):
def get_model(self, main_prog, startup_program): def get_model(self, main_prog, startup_program):
ring_id = 0 ring_id = 0
reduce_type = 0
with fluid.program_guard(main_prog, startup_program): with fluid.program_guard(main_prog, startup_program):
tindata = layers.data( tindata = layers.data(
name="tindata", shape=[10, 1000], dtype='float32') name="tindata", shape=[10, 1000], dtype='float32')
...@@ -53,10 +52,9 @@ class TestCollectiveAllreduce(TestCollectiveRunnerBase): ...@@ -53,10 +52,9 @@ class TestCollectiveAllreduce(TestCollectiveRunnerBase):
persistable=False, persistable=False,
stop_gradient=False) stop_gradient=False)
main_prog.global_block().append_op( main_prog.global_block().append_op(
type="c_allreduce", type="c_allreduce_sum",
inputs={'X': tindata}, inputs={'X': tindata},
attrs={'ring_id': ring_id, attrs={'ring_id': ring_id},
'reduce_type': reduce_type},
outputs={'Out': toutdata}) outputs={'Out': toutdata})
main_prog.global_block().append_op( main_prog.global_block().append_op(
type="c_sync_comm_stream", type="c_sync_comm_stream",
......
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. # Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
# #
# Licensed under the Apache License, Version 2.0 (the 'License'); # Licensed under the Apache License, Version 2.0 (the 'License');
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
...@@ -37,8 +37,8 @@ class Collective(object): ...@@ -37,8 +37,8 @@ class Collective(object):
''' '''
''' '''
def __init__(self): def __init__(self, nrings):
self.global_ring_id = 0 self.nrings = nrings
self.endpoints = None self.endpoints = None
self.current_endpoint = None self.current_endpoint = None
self.nranks = None self.nranks = None
...@@ -90,9 +90,10 @@ class Collective(object): ...@@ -90,9 +90,10 @@ class Collective(object):
raise NotImplementedError('call the inherited method of subclasses') raise NotImplementedError('call the inherited method of subclasses')
def _transpile_startup_program(self): def _transpile_startup_program(self):
self._init_communicator(self.startup_program, self.current_endpoint, for ring_id in range(self.nrings):
self.endpoints, self.rank, self.global_ring_id, self._init_communicator(self.startup_program, self.current_endpoint,
self.wait_port) self.endpoints, self.rank, ring_id,
self.wait_port)
self._broadcast_params() self._broadcast_params()
def _init_communicator(self, program, current_endpoint, endpoints, rank, def _init_communicator(self, program, current_endpoint, endpoints, rank,
...@@ -116,7 +117,7 @@ class Collective(object): ...@@ -116,7 +117,7 @@ class Collective(object):
'rank': rank, 'rank': rank,
'endpoint': current_endpoint, 'endpoint': current_endpoint,
'other_endpoints': other_endpoints, 'other_endpoints': other_endpoints,
self.op_role_key: OpRole.Collective self.op_role_key: OpRole.Forward
}) })
block.append_op( block.append_op(
type='c_comm_init', type='c_comm_init',
...@@ -126,29 +127,31 @@ class Collective(object): ...@@ -126,29 +127,31 @@ class Collective(object):
'nranks': nranks, 'nranks': nranks,
'rank': rank, 'rank': rank,
'ring_id': ring_id, 'ring_id': ring_id,
self.op_role_key: OpRole.Collective self.op_role_key: OpRole.Forward
}) })
def _broadcast_params(self): def _broadcast_params(self):
block = self.startup_program.global_block() block = self.startup_program.global_block()
for var in block.iter_parameters(): ring_id = -1
for param in block.iter_parameters():
ring_id = (ring_id + 1) % self.nrings
block.append_op( block.append_op(
type='c_broadcast', type='c_broadcast',
inputs={'X': var}, inputs={'X': param},
outputs={'Out': var}, outputs={'Out': param},
attrs={ attrs={
'ring_id': self.global_ring_id, 'ring_id': ring_id,
'root': 0, 'root': 0,
self.op_role_key: OpRole.Collective self.op_role_key: OpRole.Forward
}) })
block.append_op(
type='c_sync_comm_stream', for ring_id in range(self.nrings):
inputs={'X': var}, block.append_op(
outputs={'Out': var}, type='c_sync_comm_stream',
attrs={ inputs={'X': param},
'ring_id': self.global_ring_id, outputs={'Out': param},
self.op_role_key: OpRole.Collective attrs={'ring_id': ring_id,
}) self.op_role_key: OpRole.Forward})
def _is_loss_grad_op(self, op): def _is_loss_grad_op(self, op):
if self.op_role_key not in op.attr_names: if self.op_role_key not in op.attr_names:
...@@ -173,8 +176,8 @@ class GradAllReduce(Collective): ...@@ -173,8 +176,8 @@ class GradAllReduce(Collective):
''' '''
''' '''
def __init__(self): def __init__(self, nrings=2):
Collective.__init__(self) Collective.__init__(self, nrings)
def _transpile_main_program(self): def _transpile_main_program(self):
self._insert_scale_loss_grad_ops() self._insert_scale_loss_grad_ops()
...@@ -196,11 +199,13 @@ class GradAllReduce(Collective): ...@@ -196,11 +199,13 @@ class GradAllReduce(Collective):
outputs={'Out': loss_grad_var}, outputs={'Out': loss_grad_var},
attrs={ attrs={
'scale': 1.0 / self.nranks, 'scale': 1.0 / self.nranks,
self.op_role_key: OpRole.Collective self.op_role_key: OpRole.Backward
}) })
def _insert_allreduce_ops(self): def _insert_allreduce_ops(self):
block = self.main_program.global_block() block = self.main_program.global_block()
ring_id = -1
grad = None
for idx, op in reversed(list(enumerate(block.ops))): for idx, op in reversed(list(enumerate(block.ops))):
if self._is_backward_op(op) and \ if self._is_backward_op(op) and \
self.op_role_var_key in op.attr_names: self.op_role_var_key in op.attr_names:
...@@ -208,41 +213,50 @@ class GradAllReduce(Collective): ...@@ -208,41 +213,50 @@ class GradAllReduce(Collective):
if len(op_role_var) == 0: if len(op_role_var) == 0:
continue continue
assert len(op_role_var) % 2 == 0 assert len(op_role_var) % 2 == 0
block._insert_op( offset = idx
idx + 1,
type='c_sync_calc_stream',
inputs={'X': block.vars[grad]},
outputs={'Out': block.vars[grad]},
attrs={self.op_role_key: OpRole.Collective})
offset = 2
for i in range(0, len(op_role_var), 2): for i in range(0, len(op_role_var), 2):
grad = op_role_var[i + 1] param = block.vars[op_role_var[i]]
grad = block.vars[op_role_var[i + 1]]
if offset == idx:
offset += 1
block._insert_op(
offset,
type='c_sync_calc_stream',
inputs={'X': grad},
outputs={'Out': grad},
attrs={self.op_role_key: OpRole.Backward})
offset += 1
# As we search ops reversedly, we should insert c_allreduce_sum
# op in the same way to keep the ring_id alternate
ring_id = (ring_id + 1) % self.nrings
block._insert_op( block._insert_op(
idx + offset, offset,
type='c_allreduce', type='c_allreduce_sum',
inputs={'X': [block.vars[grad]]}, inputs={'X': grad},
outputs={'Out': [block.vars[grad]]}, outputs={'Out': grad},
attrs={ attrs={
'reduce_type': 0, 'ring_id': ring_id,
self.op_role_key: OpRole.Collective self.op_role_key: OpRole.Backward
}) })
offset += 1
if grad is None:
return
for idx, op in enumerate(block.ops): for idx, op in enumerate(block.ops):
if self._is_optimizer_op(op): if self._is_optimizer_op(op):
block._insert_op( for ring_id in range(self.nrings):
idx, block._insert_op(
type='c_sync_comm_stream', idx + ring_id,
inputs={'X': block.vars[grad]}, type='c_sync_comm_stream',
outputs={'Out': block.vars[grad]}, inputs={'X': grad},
attrs={ outputs={'Out': grad},
'ring_id': self.global_ring_id, attrs={
self.op_role_key: OpRole.Collective 'ring_id': ring_id,
}) self.op_role_key: OpRole.Backward
})
break break
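The rewritten _insert_allreduce_ops walks the block in reverse, inserts a single c_sync_calc_stream after each backward op that produces gradients, then round-robins one c_allreduce_sum per gradient across the nrings communication rings; before the first optimizer op it inserts one c_sync_comm_stream per ring. Below is a standalone sketch of just the round-robin assignment, with a hypothetical grad_names list standing in for the names taken from op_role_var.

# Standalone sketch of the round-robin ring assignment used above: gradients
# are visited in reverse order and each c_allreduce_sum lands on the next ring,
# spreading traffic over nrings communicators. grad_names is hypothetical.
def assign_rings(grad_names, nrings=2):
    ring_id = -1
    plan = []
    for name in reversed(grad_names):
        ring_id = (ring_id + 1) % nrings
        plan.append((name, ring_id))
    return plan

# e.g. assign_rings(['w0@GRAD', 'w1@GRAD', 'w2@GRAD']) ->
#      [('w2@GRAD', 0), ('w1@GRAD', 1), ('w0@GRAD', 0)]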
...@@ -250,8 +264,8 @@ class LocalSGD(Collective): ...@@ -250,8 +264,8 @@ class LocalSGD(Collective):
''' '''
''' '''
def __init__(self): def __init__(self, nrings=2):
Collective.__init__(self) Collective.__init__(self, nrings)
self.snapshot_key = '@SNAPSHOT' self.snapshot_key = '@SNAPSHOT'
def _transpile_startup_program(self): def _transpile_startup_program(self):
...@@ -268,7 +282,7 @@ class LocalSGD(Collective): ...@@ -268,7 +282,7 @@ class LocalSGD(Collective):
type='assign', type='assign',
inputs={'X': [param]}, inputs={'X': [param]},
outputs={'Out': [snapshot]}, outputs={'Out': [snapshot]},
attrs={self.op_role_key: OpRole.Collective}) attrs={self.op_role_key: OpRole.Forward})
def snapshot_name(self, param_name): def snapshot_name(self, param_name):
return param_name + self.snapshot_key return param_name + self.snapshot_key
...@@ -276,6 +290,7 @@ class LocalSGD(Collective): ...@@ -276,6 +290,7 @@ class LocalSGD(Collective):
def _transpile_main_program(self): def _transpile_main_program(self):
block = self.main_program.global_block() block = self.main_program.global_block()
ordered_param_snapshot = [] ordered_param_snapshot = []
ring_id = -1
for idx, op in reversed(list(enumerate(block.ops))): for idx, op in reversed(list(enumerate(block.ops))):
if self._is_update_op(op): if self._is_update_op(op):
param = block.vars[op.input('Param')[0]] param = block.vars[op.input('Param')[0]]
...@@ -291,33 +306,33 @@ class LocalSGD(Collective): ...@@ -291,33 +306,33 @@ class LocalSGD(Collective):
inputs={'X': [snapshot], inputs={'X': [snapshot],
'Y': [param]}, 'Y': [param]},
outputs={'Out': [param]}, outputs={'Out': [param]},
attrs={self.op_role_key: OpRole.Collective}) attrs={self.op_role_key: OpRole.Optimize})
block._insert_op( block._insert_op(
idx + 2, idx + 2,
type='c_sync_calc_stream', type='c_sync_calc_stream',
inputs={'X': param}, inputs={'X': param},
outputs={'Out': param}, outputs={'Out': param},
attrs={self.op_role_key: OpRole.Collective}) attrs={self.op_role_key: OpRole.Optimize})
ring_id = (ring_id + 1) % self.nrings
block._insert_op( block._insert_op(
idx + 3, idx + 3,
type='c_allreduce', type='c_allreduce_sum',
inputs={'X': [param]}, inputs={'X': [param]},
outputs={'Out': [param]}, outputs={'Out': [param]},
attrs={ attrs={
'reduce_type': 0, 'ring_id': ring_id,
self.op_role_key: OpRole.Collective self.op_role_key: OpRole.Optimize
}) })
ordered_param_snapshot.append((param, snapshot)) ordered_param_snapshot.append((param, snapshot))
block.append_op( for ring_id in range(self.nrings):
type='c_sync_comm_stream', block.append_op(
inputs={'X': param}, type='c_sync_comm_stream',
outputs={'Out': param}, inputs={'X': param},
attrs={ outputs={'Out': param},
'ring_id': self.global_ring_id, attrs={'ring_id': ring_id,
self.op_role_key: OpRole.Collective self.op_role_key: OpRole.Optimize})
})
for param_snapshot in reversed(ordered_param_snapshot): for param_snapshot in reversed(ordered_param_snapshot):
param = param_snapshot[0] param = param_snapshot[0]
...@@ -328,16 +343,16 @@ class LocalSGD(Collective): ...@@ -328,16 +343,16 @@ class LocalSGD(Collective):
outputs={'Out': [param]}, outputs={'Out': [param]},
attrs={ attrs={
'scale': 1.0 / self.nranks, 'scale': 1.0 / self.nranks,
self.op_role_key: OpRole.Collective self.op_role_key: OpRole.Optimize
}) })
block.append_op( block.append_op(
type='elementwise_sub', type='elementwise_sub',
inputs={'X': [snapshot], inputs={'X': [snapshot],
'Y': [param]}, 'Y': [param]},
outputs={'Out': [param]}, outputs={'Out': [param]},
attrs={self.op_role_key: OpRole.Collective}) attrs={self.op_role_key: OpRole.Optimize})
block.append_op( block.append_op(
type='assign', type='assign',
inputs={'X': [param]}, inputs={'X': [param]},
outputs={'Out': [snapshot]}, outputs={'Out': [snapshot]},
attrs={self.op_role_key: OpRole.Collective}) attrs={self.op_role_key: OpRole.Optimize})
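Putting the pieces together, a hedged usage sketch of the updated transpiler: both GradAllReduce and LocalSGD now take an nrings argument, and every inserted op is tagged with a standard role. The transpile(...) call shown here is assumed from the unchanged portion of paddle/fluid/transpiler/collective.py and may differ slightly.

# Hedged usage sketch (not taken verbatim from this patch): transpile a
# training program for collective all-reduce over two communication rings.
# The transpile(...) signature is an assumption based on the unchanged part
# of collective.py.
import paddle.fluid as fluid
from paddle.fluid.transpiler.collective import GradAllReduce

main_prog = fluid.default_main_program()
startup_prog = fluid.default_startup_program()

t = GradAllReduce(nrings=2)           # nrings replaces the single global ring
t.transpile(
    startup_program=startup_prog,
    main_program=main_prog,
    rank=0,                            # this trainer's rank
    endpoints=["127.0.0.1:6170", "127.0.0.1:6171"],
    current_endpoint="127.0.0.1:6170",
    wait_port=True)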