diff --git a/paddle/operators/nccl/nccl_ops.h b/paddle/operators/nccl/nccl_ops.h
index 894859f6f0eac52a92d25d413eded1e6ccc6d625..f56b89d2ad87e88c2ef3e37e22dbd4ebab3afe0d 100644
--- a/paddle/operators/nccl/nccl_ops.h
+++ b/paddle/operators/nccl/nccl_ops.h
@@ -7,6 +7,8 @@
 namespace paddle {
 namespace operators {
 
+using framework::Tensor;
+
 template <typename T>
 class NCCLTypeWrapper;
 
@@ -21,7 +23,7 @@
 };
 
 template <typename T>
-class NCCLAllReduceKernel : public framework::OpKernel {
+class NCCLAllReduceKernel : public framework::OpKernel<T> {
  public:
   void Compute(const framework::ExecutionContext& ctx) const override {
     auto ins = ctx.MultiInput<Tensor>("X");
@@ -35,13 +37,14 @@
       op_type = ncclProd;
     } else if (reduction == "ncclMin") {
       op_type = ncclMin;
-    } else
-      (reduction == "ncclMax") { op_type = ncclMax; }
+    } else if (reduction == "ncclMax") {
+      op_type = ncclMax;
+    }
 
     auto dev_ctx = static_cast<const platform::CUDADeviceContext*>(ctx.device_context());
 
-    NCCLManager* m = NCCLManager::Get();
+    platform::NCCLManager* m = platform::NCCLManager::Get();
 
     auto* comm = m->GetCommunicator(gpus);
     comm->wg_.Add(1);