diff --git a/paddle/framework/operator.h b/paddle/framework/operator.h
index aca663ffc6fc6fe21b56821ce8e6c4616a4f69cf..09989c374c6ca3f4cebb7281fcb49b1a5f744ef5 100644
--- a/paddle/framework/operator.h
+++ b/paddle/framework/operator.h
@@ -125,7 +125,7 @@ class OperatorBase {
  protected:
   std::string type_;
   // NOTE: in case of OpGrad, inputs_ contains:
-  // I (Inputs)opear
+  // I (Inputs)
   // O (Outputs)
   // OG (Output Gradients)
   VariableNameMap inputs_;
diff --git a/paddle/operators/nccl_op.cc b/paddle/operators/nccl_op.cc
index 6213f2361321f287f363fc82a85ef759fa072940..ec7a89d5ff4a66e0527d8365027bc2c5b4e93aaf 100644
--- a/paddle/operators/nccl_op.cc
+++ b/paddle/operators/nccl_op.cc
@@ -9,26 +9,30 @@
 See the License for the specific language governing permissions and
 limitations under the License. */
 
-#include "paddle/operators/nccl_op.h"
+#include "paddle/framework/op_registry.h"
+#include "paddle/operators/nccl/nccl_gpu_common.h"
 
 namespace paddle {
 namespace operators {
 
 // NCCLinitOp
-class NCCLInitOp : public framework::OperatorWithKernel {
+class NCCLInitOp : public framework::OperatorBase {
  public:
-  using framework::OperatorWithKernel::OperatorWithKernel;
-
- protected:
-  void InferShape(framework::InferShapeContext *ctx) const override {
-    PADDLE_ENFORCE(ctx->HasOutput("Communicator"),
-                   " Output(Communicator) of ncclInitOp should not be NULL");
-  }
-
- protected:
-  framework::DataType IndicateDataType(
-      const framework::ExecutionContext &ctx) const override {
-    return static_cast<framework::DataType>(ctx.Attr<int>("data_type"));
+  NCCLInitOp(const std::string &type, const framework::VariableNameMap &inputs,
+             const framework::VariableNameMap &outputs,
+             const framework::AttributeMap &attrs)
+      : OperatorBase(type, inputs, outputs, attrs) {}
+
+  void Run(const framework::Scope &scope,
+           const platform::DeviceContext &dev_ctx) const override {
+    const auto &name = Output("Communicator");
+    PADDLE_ENFORCE_NOT_NULL(scope.FindVar(name),
+                            "Can not find variable '%s' in the scope.", name);
+    std::vector<int> gpus = Attr<std::vector<int>>("gpus");
+    PADDLE_ENFORCE(!gpus.empty(), "Attr(gpus) should not be empty.");
+    platform::Communicator *comm =
+        scope.FindVar(name)->GetMutable<platform::Communicator>();
+    comm->InitAll(gpus);
   }
 };
 
@@ -188,13 +192,14 @@ class NCCLReduceOpMaker : public framework::OpProtoAndCheckerMaker {
 }  // namespace paddle
 
 namespace ops = paddle::operators;
+REGISTER_OPERATOR(ncclInit, ops::NCCLInitOp,
+                  paddle::framework::EmptyGradOpMaker, ops::NCCLInitOpMaker);
+
 REGISTER_OP_WITHOUT_GRADIENT(ncclAllReduce, ops::NCCLAllReduceOp,
                              ops::NCCLAllReduceOpMaker);
-REGISTER_OP_WITHOUT_GRADIENT(ncclInit, ops::NCCLInitOp, ops::NCCLInitOpMaker);
 REGISTER_OP_WITHOUT_GRADIENT(ncclBcastSend, ops::NCCLBcastSendOp,
                              ops::NCCLBcastSendOpMaker);
 REGISTER_OP_WITHOUT_GRADIENT(ncclBcastRecv, ops::NCCLBcastRecvOp,
                              ops::NCCLBcastRecvOpMaker);
 REGISTER_OP_WITHOUT_GRADIENT(ncclReduce, ops::NCCLReduceOp,
                              ops::NCCLReduceOpMaker);
-REGISTER_OP_CPU_KERNEL(ncclInit, ops::NCCLInitKernel<float>);
diff --git a/paddle/operators/nccl_op.cu b/paddle/operators/nccl_op.cu
index 00a115feeba3eae997c9e39d22cf3d1d9b27de8c..4fbdf1ce02dc9db4f430dc5e71998059a7435fc3 100644
--- a/paddle/operators/nccl_op.cu
+++ b/paddle/operators/nccl_op.cu
@@ -12,11 +12,30 @@ limitations under the License.
 */
 #define EIGEN_USE_GPU
 #include <nccl.h>
-#include "paddle/operators/nccl_op.h"
+#include "paddle/framework/op_registry.h"
+#include "paddle/operators/nccl/nccl_gpu_common.h"
 
 namespace paddle {
 namespace operators {
 
+using framework::Tensor;
+using platform::Communicator;
+
+template <typename T>
+class NCCLTypeWrapper;
+
+template <>
+class NCCLTypeWrapper<float> {
+ public:
+  static const ncclDataType_t type = ncclFloat;
+};
+
+template <>
+class NCCLTypeWrapper<double> {
+ public:
+  static const ncclDataType_t type = ncclDouble;
+};
+
 template <typename T>
 class NCCLAllReduceKernel : public framework::OpKernel<T> {
  public:
diff --git a/paddle/operators/nccl_op_test.cu b/paddle/operators/nccl_op_test.cu
index a25e01baa4d539673f4149c72048e53a5613ed04..334884d657acca3917a1c848edbd8769d2a9f15d 100644
--- a/paddle/operators/nccl_op_test.cu
+++ b/paddle/operators/nccl_op_test.cu
@@ -11,7 +11,6 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License. */
-#include "paddle/operators/nccl_op.h"
 
 #include <glog/logging.h>
 #include <gtest/gtest.h>
@@ -65,11 +64,11 @@ void AddOp(const std::string &type, const f::VariableNameMap &inputs,
 TEST(NCCL, ncclInitOp) {
   f::ProgramDescBind program;
   f::BlockDescBind *block = program.Block(0);
-  f::OpDescBind *op1 = block->AppendOp();
+  f::OpDescBind *op_desc = block->AppendOp();
 
-  op1->SetType("ncclInit");
-  op1->SetOutput("Communicator", {"x1"});
-  op1->SetAttr("gpus", {gpu_list});
+  op_desc->SetType("ncclInit");
+  op_desc->SetOutput("Communicator", {"x1"});
+  op_desc->SetAttr("gpus", {gpu_list});
   f::Scope g_scope;
   paddle::platform::DeviceContext *ctx =
       new paddle::platform::CPUDeviceContext(paddle::platform::CPUPlace());
@@ -77,7 +76,30 @@ TEST(NCCL, ncclInitOp) {
   auto *var = g_scope.Var("x1");
   var->GetMutable<paddle::platform::Communicator>();
 
-  auto op = f::OpRegistry::CreateOp(*op1);
+  auto op = f::OpRegistry::CreateOp(*op_desc);
+  VLOG(1) << "invoke NCCLInitOp.";
+  op->Run(g_scope, *ctx);
+  VLOG(1) << "NCCLInitOp finished.";
+}
+
+// ncclAllReduceOp with desc
+TEST(NCCL, ncclAllReduceOp) {
+  f::ProgramDescBind program;
+  f::BlockDescBind *block = program.Block(0);
+  f::OpDescBind *op_desc = block->AppendOp();
+
+  op_desc->SetType("ncclAllReduce");
+
+  op_desc->SetOutput("Communicator", {"x1"});
+  op_desc->SetAttr("gpus", {gpu_list});
+  f::Scope g_scope;
+  paddle::platform::DeviceContext *ctx =
+      new paddle::platform::CPUDeviceContext(paddle::platform::CPUPlace());
+
+  auto *var = g_scope.Var("x1");
+  var->GetMutable<paddle::platform::Communicator>();
+
+  auto op = f::OpRegistry::CreateOp(*op_desc);
   VLOG(1) << "invoke NCCLInitOp.";
   op->Run(g_scope, *ctx);
   VLOG(1) << "NCCLInitOp finished.";
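
For context, a minimal sketch of how the relocated ncclInit operator is driven after this change, mirroring the ncclInitOp test above. It is illustrative only: the helper name InitNCCLCommunicator, the variable name "comm", and the exact framework header paths are assumptions, not part of this patch.

// Illustrative sketch (not part of the patch): builds an ncclInit op desc,
// prepares the Communicator variable in a scope, and runs the op.
#include <vector>

#include "paddle/framework/op_registry.h"
#include "paddle/framework/program_desc.h"
#include "paddle/framework/scope.h"
#include "paddle/operators/nccl/nccl_gpu_common.h"
#include "paddle/platform/device_context.h"

namespace f = paddle::framework;

void InitNCCLCommunicator(const std::vector<int> &gpu_list) {
  f::ProgramDescBind program;
  f::BlockDescBind *block = program.Block(0);
  f::OpDescBind *op_desc = block->AppendOp();

  // ncclInit has no inputs; it only fills its Communicator output.
  op_desc->SetType("ncclInit");
  op_desc->SetOutput("Communicator", {"comm"});
  op_desc->SetAttr("gpus", {gpu_list});

  // NCCLInitOp::Run() enforces that the output variable already exists
  // in the scope, so create it before running the op.
  f::Scope scope;
  scope.Var("comm")->GetMutable<paddle::platform::Communicator>();

  paddle::platform::CPUDeviceContext ctx(paddle::platform::CPUPlace());
  auto op = f::OpRegistry::CreateOp(*op_desc);
  op->Run(scope, ctx);  // ends up calling comm->InitAll(gpu_list)
}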