/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once
16 17

#include <string>
18

19
#include "paddle/fluid/distributed/collective/ProcessGroup.h"
20 21 22
#include "paddle/fluid/framework/data_type.h"
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/op_registry.h"
23 24
#include "paddle/fluid/memory/memcpy.h"
#include "paddle/fluid/memory/memory.h"
25
#include "paddle/fluid/platform/device/npu/npu_op_runner.h"
26
#include "paddle/phi/api/include/tensor.h"
27

Z
zn 已提交
28 29 30
#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL) ||          \
    defined(PADDLE_WITH_ASCEND_CL) || defined(PADDLE_WITH_XPU_BKCL) || \
    defined(PADDLE_WITH_CNCL)
31
#include "paddle/fluid/platform/collective_helper.h"
32 33 34
#endif

#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
35
#include "paddle/fluid/platform/device/gpu/nccl_helper.h"
36 37
#endif

38
#if defined(PADDLE_WITH_XPU_BKCL)
39
#include "paddle/fluid/platform/device/xpu/bkcl_helper.h"
40 41
#endif

42 43
#if defined(PADDLE_WITH_GLOO)
#include <gloo/allreduce.h>
44

45 46 47
#include "paddle/fluid/framework/fleet/gloo_wrapper.h"
#endif

48
#if defined(PADDLE_WITH_ASCEND_CL)
49
#include "paddle/fluid/platform/device/npu/hccl_helper.h"
50 51
#endif

Z
zn 已提交
52 53 54 55
#if defined(PADDLE_WITH_CNCL)
#include "paddle/fluid/platform/device/mlu/cncl_helper.h"
#endif

56 57 58 59
#if defined(PADDLE_WITH_ASCEND_CL)
DECLARE_bool(hccl_check_nan);
#endif

60 61 62
namespace paddle {
namespace operators {

// Reduction applied by the c_allreduce_* kernels across all ranks of the
// communication ring. Used as a template argument of every kernel below and
// printed with "%d" in their "Invalid reduce type" errors, so the numeric
// values are spelled out and must stay stable.
enum ReduceType {
  kRedSum = 0,   // sum
  kRedMax = 1,   // maximum
  kRedMin = 2,   // minimum
  kRedProd = 3,  // product
};

class CAllReduceOp : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;

  void InferShape(framework::InferShapeContext* ctx) const override {
    ctx->SetOutputDim("Out", ctx->GetInputDim("X"));
  }

 protected:
  framework::OpKernelType GetExpectedKernelType(
      const framework::ExecutionContext& ctx) const override {
76 77
    return framework::OpKernelType(
        OperatorWithKernel::IndicateVarDataType(ctx, "X"), ctx.GetPlace());
78
  }
79 80 81

  framework::OpKernelType GetKernelTypeForVar(
      const std::string& var_name,
82
      const phi::DenseTensor& tensor,
83 84 85 86 87 88 89 90
      const framework::OpKernelType& expected_kernel_type) const {
    if (var_name == "Cond") {
      return expected_kernel_type;
    } else {
      return framework::OpKernelType(
          expected_kernel_type.data_type_, tensor.place(), tensor.layout());
    }
  }
91 92 93 94 95 96
};

// CPU implementation backed by Gloo.
//
// All-reduces "X" over the Gloo context with the reduction selected by
// `red_type` and writes the result into "Out" (allocated with X's dims).
// Throws PreconditionNotMet if Gloo was not initialized, and Unavailable when
// Paddle was built without Gloo.
//
// NOTE(review): unlike the other backends in this file, this kernel does not
// consult the optional "Cond" input.
template <ReduceType red_type, typename T>
class CAllReduceOpCPUKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& ctx) const override {
#if defined(PADDLE_WITH_GLOO)
    auto in = ctx.Input<phi::DenseTensor>("X");
    auto out = ctx.Output<phi::DenseTensor>("Out");

    auto place = ctx.GetPlace();
    int64_t send_numel = in->numel();
    const T* send_buff = in->data<T>();
    T* recv_buff = out->mutable_data<T>(in->dims(), place);
    auto gloo = paddle::framework::GlooWrapper::GetInstance();
    PADDLE_ENFORCE_EQ(
        gloo->IsInitialized(),
        true,
        platform::errors::PreconditionNotMet(
            "You must initialize the gloo environment first to use it."));
    gloo::AllreduceOptions opts(gloo->GetContext());
    // Gloo's option setters take non-const pointers, even for the input.
    opts.setInput(const_cast<T*>(send_buff), send_numel);
    opts.setOutput(recv_buff, send_numel);

    // The explicit cast picks the untyped (void*) overload of gloo's
    // reduce-function templates, which is what setReduceFunction accepts.
    using ReduceFn = void (*)(void*, const void*, const void*, size_t);
    switch (red_type) {
      case kRedSum:
        opts.setReduceFunction(static_cast<ReduceFn>(&gloo::sum<T>));
        break;
      case kRedMax:
        opts.setReduceFunction(static_cast<ReduceFn>(&gloo::max<T>));
        break;
      case kRedMin:
        opts.setReduceFunction(static_cast<ReduceFn>(&gloo::min<T>));
        break;
      case kRedProd:
        opts.setReduceFunction(static_cast<ReduceFn>(&gloo::product<T>));
        break;
      default:
        // Same error style as the other backends (was a
        // PADDLE_ENFORCE_EQ(true, false, ...) used as a throw).
        PADDLE_THROW(platform::errors::InvalidArgument(
            "Invalid reduce type: %d.", red_type));
    }
    gloo::allreduce(opts);
#else
    PADDLE_THROW(platform::errors::Unavailable(
        "PaddlePaddle should compile with GLOO by setting WITH_GLOO=ON"));
#endif
  }
};

#if defined(PADDLE_WITH_ASCEND_CL)
// Probes `in` for NaN by reducing it to its mean with the NPU "ReduceMeanD"
// op and inspecting that scalar on the host.
//
// Parameters:
//   dev_ctx: device context that owns the output place and is used for the
//            device-to-host copy of the result.
//   stream:  stream on which the ReduceMeanD op is launched.
//   in:      tensor to probe; every axis is reduced.
//
// Returns true when the mean is NaN. Also returns true when the probe itself
// cannot be completed: an exception from the runner/copy, or an empty result.
// An infinite mean is logged but does NOT make this return true.
inline bool ContainsNan(const paddle::platform::NPUDeviceContext& dev_ctx,
                        aclrtStream stream,
                        const phi::DenseTensor* in) {
  using Tensor = phi::DenseTensor;

  Tensor mean(in->type());
  mean.Resize({1});
  mean.mutable_data<float>(dev_ctx.GetPlace());

  // Reduce over every axis so the result is a single scalar.
  std::vector<int> axes;
  axes.reserve(in->dims().size());
  for (int i = 0; i < in->dims().size(); ++i) {
    axes.push_back(i);
  }

  std::vector<float> vec;
  try {
    const auto& runner_mean = paddle::operators::NpuOpRunner(
        "ReduceMeanD", {*in}, {mean}, {{"axes", axes}, {"keep_dims", false}});
    // Constructing the runner only describes the op; Run() launches it. Without
    // this call `mean` is never computed and `stream` goes unused.
    runner_mean.Run(stream);
    paddle::framework::TensorToVector(mean, dev_ctx, &vec);
  } catch (...) {
    LOG(WARNING) << "ContainsNan catch exception";
    return true;
  }

  // Guard the vec[0] reads below; treat a missing result like a failed probe.
  if (vec.empty()) {
    LOG(WARNING) << "ContainsNan got an empty reduce result";
    return true;
  }

  VLOG(4) << "reducemeand result:" << vec[0];
  if (std::isnan(static_cast<float>(vec[0]))) {
    LOG(WARNING) << "ContainsNan detects nan";
    return true;
  }

  if (std::isinf(static_cast<float>(vec[0]))) {
    // Reported only; inf does not count as NaN for the return value.
    LOG(WARNING) << "ContainsNan detects inf";
  }

  return false;
}
#endif

190 191 192 193 194
template <ReduceType red_type, typename T>
class CAllReduceOpASCENDKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& ctx) const override {
#if defined(PADDLE_WITH_ASCEND_CL)
195
    if (ctx.HasInput("Cond")) {
196
      auto cond = ctx.Input<phi::DenseTensor>("Cond");
197 198 199 200 201 202 203 204 205 206 207 208 209 210 211
      auto place = cond->place();
      PADDLE_ENFORCE_EQ(platform::is_cpu_place(place),
                        true,
                        platform::errors::PreconditionNotMet(
                            "The input `cond` tensor should be on cpu place"));
      PADDLE_ENFORCE_EQ(cond->numel(),
                        1,
                        platform::errors::PreconditionNotMet(
                            "The input `cond` should be shape [1]"));
      if (!cond->data<bool>()[0]) {
        VLOG(4) << "Skip all reduce Op since cond is 0";
        return;
      }
    }

212 213 214
    auto in = ctx.Input<phi::DenseTensor>("X");
    auto out = ctx.Output<phi::DenseTensor>("Out");

215
    auto place = ctx.GetPlace();
216 217
    HcclDataType dtype =
        platform::ToHCCLDataType(framework::TransToProtoVarType(in->dtype()));
218 219 220
    int64_t numel = in->numel();

    void* sendbuff = reinterpret_cast<void*>(const_cast<T*>(in->data<T>()));
221
    out->mutable_data<T>(in->dims(), ctx.GetPlace());
222 223 224 225 226 227 228 229 230
    void* recvbuff = reinterpret_cast<void*>(out->data<T>());

    int ring_id = ctx.Attr<int>("ring_id");
    std::string group =
        std::string(HCOM_GROUP_PREFIX) + std::to_string(ring_id);
    auto comm =
        paddle::platform::HCCLCommContext::Instance().Get(ring_id, place);

    aclrtStream stream = nullptr;
231 232
    auto dev_ctx = static_cast<platform::NPUDeviceContext*>(
        platform::DeviceContextPool::Instance().Get(place));
233
    if (ctx.Attr<bool>("use_calc_stream")) {
234
      stream = dev_ctx->stream();
235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261
    } else {
      stream = comm->stream();
    }

    HcclReduceOp hccl_red_type = HCCL_REDUCE_SUM;
    switch (red_type) {
      case kRedSum:
        hccl_red_type = HCCL_REDUCE_SUM;
        break;

      case kRedMax:
        hccl_red_type = HCCL_REDUCE_MAX;
        break;

      case kRedMin:
        hccl_red_type = HCCL_REDUCE_MIN;
        break;

      case kRedProd:
        hccl_red_type = HCCL_REDUCE_PROD;
        break;

      default:
        PADDLE_THROW(platform::errors::InvalidArgument(
            "Invalid reduce type: %d", red_type));
    }

262 263 264 265 266 267 268 269
    VLOG(3) << "hccl allreduce, parameter is: "
            << "input num: " << in->dims() << "dtype: " << dtype
            << "hccl_red_type: " << hccl_red_type << ", group is: " << group
            << ", sendbuff:" << sendbuff << ", recvbuff:" << recvbuff
            << ", out_size:" << out->memory_size()
            << ", use_calc_stream:" << ctx.Attr<bool>("use_calc_stream")
            << ", stream:" << stream;

270
    phi::DenseTensor tmp;
271 272
    tmp.mutable_data<float>({8}, ctx.GetPlace());

273
    bool found_nan = false;
274

275
    auto d_type = framework::TransToProtoVarType(in->dtype());
276
    switch (d_type) {
277 278 279
      case framework::proto::VarType::FP16: {
        break;
      }
280
      case framework::proto::VarType::FP32: {
281 282
        if (FLAGS_hccl_check_nan) {
          VLOG(3) << "prepare to FoundNanInf";
Y
Yuang Liu 已提交
283 284
          // NOTE: performance relating, DO NOT REMOVE!
          ContainsNan(*dev_ctx, dev_ctx->stream(), in);
285
        }
286 287 288 289 290 291
        break;
      }
      default:
        break;
    }

292
    if (found_nan) {
293 294 295
      T inf = static_cast<T>(std::numeric_limits<float>::infinity());
      VLOG(4) << "fill input data constant inf";
      auto dims = in->dims();
296
      auto mutable_in = const_cast<phi::DenseTensor*>(in);
297 298 299 300 301
      FillNpuTensorWithConstant<T>(mutable_in, inf);
      mutable_in->Resize(dims);
    }

    VLOG(3) << "hccl allreduce, parameter is: "
302
            << "input num: " << numel << "dtype: " << dtype
303 304 305
            << "hccl_red_type: " << hccl_red_type << ", group is: " << group
            << ", sendbuff:" << sendbuff << ", recvbuff:" << recvbuff
            << ", out_size:" << out->memory_size();
306

307 308 309 310 311 312 313 314
    PADDLE_ENFORCE_NPU_SUCCESS(
        platform::dynload::HcclAllReduce(sendbuff,
                                         recvbuff,
                                         numel,
                                         dtype,
                                         hccl_red_type,
                                         comm->comm(),
                                         reinterpret_cast<void*>(stream)));
315 316 317 318 319 320 321 322 323

    out->Resize(in->dims());
#else
    PADDLE_THROW(platform::errors::PreconditionNotMet(
        "PaddlePaddle should compile with NPU."));
#endif
  }
};

324 325 326 327 328
template <ReduceType red_type, typename T>
class CAllReduceOpXPUKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& ctx) const override {
#if defined(PADDLE_WITH_XPU_BKCL)
329
    if (ctx.HasInput("Cond")) {
330
      auto cond = ctx.Input<phi::DenseTensor>("Cond");
331 332 333 334 335 336 337 338 339 340 341 342 343 344 345
      auto place = cond->place();
      PADDLE_ENFORCE_EQ(platform::is_cpu_place(place),
                        true,
                        platform::errors::PreconditionNotMet(
                            "The input `cond` tensor should be on cpu place"));
      PADDLE_ENFORCE_EQ(cond->numel(),
                        1,
                        platform::errors::PreconditionNotMet(
                            "The input `cond` should be shape [1]"));
      if (!cond->data<bool>()[0]) {
        VLOG(4) << "Skip all reduce Op since cond is 0";
        return;
      }
    }

346 347
    auto in = ctx.Input<phi::DenseTensor>("X");
    auto out = ctx.Output<phi::DenseTensor>("Out");
348 349

    auto place = ctx.GetPlace();
350 351
    BKCLDataType dtype =
        platform::ToBKCLDataType(framework::TransToProtoVarType(in->dtype()));
352
    int64_t numel = in->numel();
353
    const void* sendbuff = in->data<T>();
354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 378 379 380 381 382 383 384 385 386 387 388 389 390 391 392
    out->Resize(in->dims());
    void* recvbuff = out->mutable_data<T>(place);

    int rid = ctx.Attr<int>("ring_id");
    auto comm = platform::BKCLCommContext::Instance().Get(rid, place);

    XPUStream stream = nullptr;
    if (ctx.Attr<bool>("use_calc_stream")) {
      auto dev_ctx = platform::DeviceContextPool::Instance().Get(place);
      stream = static_cast<platform::XPUDeviceContext*>(dev_ctx)
                   ->x_context()
                   ->xpu_stream;
    } else {
      stream = comm->stream();
    }

    BKCLOp bkcl_red_type = BKCL_ADD;
    switch (red_type) {
      case kRedSum:
        bkcl_red_type = BKCL_ADD;
        break;

      case kRedMax:
        bkcl_red_type = BKCL_MAX;
        break;

      case kRedMin:
        bkcl_red_type = BKCL_MIN;
        break;

      case kRedProd:
        bkcl_red_type = BKCL_PRODUCT;
        break;

      default:
        PADDLE_THROW(platform::errors::InvalidArgument(
            "Invalid reduce type: %d", red_type));
    }

393
    PADDLE_ENFORCE_EQ(
394 395 396 397 398 399 400
        bkcl_all_reduce(comm->comm(),
                        sendbuff,
                        recvbuff,
                        numel,
                        dtype,
                        bkcl_red_type,
                        stream),
401 402
        BKCL_SUCCESS,
        platform::errors::PreconditionNotMet("BKCL all reduce failed"));
403 404 405 406 407 408 409
#else
    PADDLE_THROW(platform::errors::PreconditionNotMet(
        "PaddlePaddle should be compiled with XPU."));
#endif
  }
};

410 411
template <ReduceType red_type, typename T>
class CAllReduceOpCUDAKernel : public framework::OpKernel<T> {
412 413
 public:
  void Compute(const framework::ExecutionContext& ctx) const override {
414
    if (ctx.HasInput("Cond")) {
415
      auto cond = ctx.Input<phi::DenseTensor>("Cond");
416 417 418 419 420 421 422 423 424 425 426 427 428 429 430
      auto place = cond->place();
      PADDLE_ENFORCE_EQ(platform::is_cpu_place(place),
                        true,
                        platform::errors::PreconditionNotMet(
                            "The input `cond` tensor should be on cpu place"));
      PADDLE_ENFORCE_EQ(cond->numel(),
                        1,
                        platform::errors::PreconditionNotMet(
                            "The input `cond` should be shape [1]"));
      if (!cond->data<bool>()[0]) {
        VLOG(4) << "Skip all reduce Op since cond is 0";
        return;
      }
    }

431
#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
432 433
    auto in = ctx.Input<phi::DenseTensor>("X");
    auto out = ctx.Output<phi::DenseTensor>("Out");
434
    int rid = ctx.Attr<int>("ring_id");
435

436
    auto place = ctx.GetPlace();
437 438
    ncclDataType_t dtype =
        platform::ToNCCLDataType(framework::TransToProtoVarType(in->dtype()));
439
    int64_t numel = in->numel();
440
    const void* sendbuff = in->data<T>();
441 442 443
    out->Resize(in->dims());
    void* recvbuff = out->mutable_data<T>(place);

444 445 446 447 448 449 450 451 452 453 454 455 456 457 458 459 460 461 462 463 464 465 466 467 468 469 470 471 472 473 474 475 476 477 478 479 480
    auto map = distributed::ProcessGroupMapFromGid::getInstance();
    if (map->has(rid)) {
      // Use ProcessGroup
      distributed::ProcessGroup* pg = map->get(rid);
      std::vector<phi::DenseTensor> in_tensor;
      std::vector<phi::DenseTensor> out_tensor;
      in_tensor.push_back(*in);
      out_tensor.push_back(*out);

      distributed::AllreduceOptions opts;
      switch (red_type) {
        case kRedSum:
          opts.reduce_op = distributed::ReduceOp::SUM;
          break;

        case kRedMax:
          opts.reduce_op = distributed::ReduceOp::MAX;
          break;

        case kRedMin:
          opts.reduce_op = distributed::ReduceOp::MIN;
          break;

        case kRedProd:
          opts.reduce_op = distributed::ReduceOp::PRODUCT;
          break;

        default:
          PADDLE_THROW(platform::errors::InvalidArgument(
              "Invalid reduce type: %d", red_type));
      }

      auto task = pg->AllReduce(in_tensor, out_tensor, opts);
      task->Wait();
      return;
    }

481
    auto comm = platform::NCCLCommContext::Instance().Get(rid, place);
482

483
    gpuStream_t stream = nullptr;
484
    if (ctx.Attr<bool>("use_calc_stream")) {
485 486 487 488
      // should not use global ctx for calc stream.
      // auto dev_ctx = platform::DeviceContextPool::Instance().Get(place);
      // stream = static_cast<phi::GPUContext*>(dev_ctx)->stream();
      stream = ctx.cuda_device_context().stream();
489 490 491
    } else {
      stream = comm->stream();
    }
492 493 494 495
    VLOG(10) << "all reduce buffer:" << sendbuff << ", numel:" << numel
             << ", redtype:" << static_cast<int>(red_type)
             << ", dtype:" << dtype << ", comm:" << comm
             << ", stream:" << stream;
496

497 498 499 500 501 502 503 504 505 506 507 508 509 510 511 512 513 514 515
    ncclRedOp_t nccl_red_type = ncclSum;
    switch (red_type) {
      case kRedSum:
        nccl_red_type = ncclSum;
        break;

      case kRedMax:
        nccl_red_type = ncclMax;
        break;

      case kRedMin:
        nccl_red_type = ncclMin;
        break;

      case kRedProd:
        nccl_red_type = ncclProd;
        break;

      default:
M
MRXLT 已提交
516 517
        PADDLE_THROW(platform::errors::InvalidArgument(
            "Invalid reduce type: %d", red_type));
518 519
    }

520
    PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclAllReduce(
521
        sendbuff, recvbuff, numel, dtype, nccl_red_type, comm->comm(), stream));
522
#else
M
MRXLT 已提交
523 524
    PADDLE_THROW(platform::errors::PreconditionNotMet(
        "PaddlePaddle should compile with GPU."));
525 526 527 528
#endif
  }
};

Z
zn 已提交
529 530 531 532 533
template <ReduceType red_type, typename T>
class CAllReduceOpMLUKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& ctx) const override {
#if defined(PADDLE_WITH_CNCL)
534 535
    auto in = ctx.Input<phi::DenseTensor>("X");
    auto out = ctx.Output<phi::DenseTensor>("Out");
Z
zn 已提交
536

537
    if (ctx.HasInput("Cond")) {
538
      auto cond = ctx.Input<phi::DenseTensor>("Cond");
539 540 541 542 543 544 545 546 547 548 549 550 551 552 553
      auto place = cond->place();
      PADDLE_ENFORCE_EQ(platform::is_cpu_place(place),
                        true,
                        platform::errors::PreconditionNotMet(
                            "The input `cond` tensor should be on cpu place"));
      PADDLE_ENFORCE_EQ(cond->numel(),
                        1,
                        platform::errors::PreconditionNotMet(
                            "The input `cond` should be shape [1]"));
      if (!cond->data<bool>()[0]) {
        VLOG(4) << "Skip all reduce Op since cond is 0";
        return;
      }
    }

Z
zn 已提交
554 555
    auto place = ctx.GetPlace();
    cnclDataType_t dtype =
Z
zn 已提交
556
        platform::ToCNCLDataType(framework::TransToProtoVarType(in->dtype()));
Z
zn 已提交
557 558 559 560 561 562 563 564 565 566 567 568 569 570 571 572 573 574 575 576 577 578 579 580 581 582 583 584 585 586 587 588 589 590 591 592 593 594 595 596 597 598 599 600 601 602 603 604
    int64_t numel = in->numel();
    const void* sendbuff = in->data<T>();
    out->Resize(in->dims());
    void* recvbuff = out->mutable_data<T>(place);

    int rid = ctx.Attr<int>("ring_id");
    auto comm = platform::CNCLCommContext::Instance().Get(rid, place);

    mluStream stream = nullptr;
    if (ctx.Attr<bool>("use_calc_stream")) {
      auto dev_ctx = platform::DeviceContextPool::Instance().Get(place);
      stream = static_cast<platform::MLUDeviceContext*>(dev_ctx)->stream();
    } else {
      stream = comm->stream();
    }

    cnclReduceOp_t cncl_red_type = cnclSum;
    switch (red_type) {
      case kRedSum:
        cncl_red_type = cnclSum;
        break;

      case kRedMax:
        cncl_red_type = cnclMax;
        break;

      case kRedMin:
        cncl_red_type = cnclMin;
        break;

      case kRedProd:
        cncl_red_type = cnclProd;
        break;

      default:
        PADDLE_THROW(platform::errors::InvalidArgument(
            "Invalid reduce type: %d", red_type));
    }

    PADDLE_ENFORCE_MLU_SUCCESS(cnclAllReduce(
        sendbuff, recvbuff, numel, dtype, cncl_red_type, comm->comm(), stream));
#else
    PADDLE_THROW(platform::errors::PreconditionNotMet(
        "PaddlePaddle should compile with MLU."));
#endif
  }
};

605 606 607 608 609 610 611
class CAllReduceOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  void Make() {
    AddInput("X", "(Tensor), tensor to be allreduced.");
    AddOutput("Out", "(Tensor) the allreduced result.");
    AddAttr<int>("ring_id", "(int default 0) communication ring id.")
        .SetDefault(0);
612 613 614 615
#if defined(PADDLE_WITH_ASCEND_CL)
    AddAttr<std::string>("tag", "(string default tag) tag for all reduce.")
        .SetDefault("tag");
#endif
616 617 618 619
    AddAttr<bool>(
        "use_calc_stream",
        "(bool default false) eject CUDA operations to calculation stream.")
        .SetDefault(false);
L
lilong12 已提交
620 621 622 623 624 625
    AddAttr<bool>(
        "use_model_parallel",
        "(bool default false) use this op with model parallel mode. In model "
        "parallel mode, the backward is c_identity which returns itself for "
        "c_allreduce_sum.")
        .SetDefault(false);
626 627 628 629 630 631 632
    AddComment(string::Sprintf(R"DOC(
CAllReduce %s Operator

Call collective AllReduce with reduce type %s. If input and output are
the same variable, in-place allreduce will be used.
Reference: https://docs.nvidia.com/deeplearning/sdk/nccl-developer-guide/docs/usage/operations.html#allreduce
)DOC",
633 634
                               GetName(),
                               GetName()));
635
    ExtraMake();
636 637 638 639
  }

 protected:
  virtual std::string GetName() const = 0;
640
  virtual void ExtraMake() {}
641 642
};

643 644
}  // namespace operators
}  // namespace paddle