collective_comm.cpp 28.1 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16
/**
 * \file src/opr-mm/impl/collective_comm.cpp
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 */

#include "megbrain/opr/collective_comm.h"

#include "megbrain/comp_node_env.h"
#include "megbrain/graph/event.h"
#include "megbrain/graph/grad_impl.h"
17 18
#include "megbrain/opr/io.h"
#include "megbrain/opr/tensor_manip.h"
19 20 21 22 23 24 25 26 27 28 29
#include "megbrain/opr/basic_arith.h"
#include "megbrain/opr/megray_helper.h"
#include "megbrain/opr/group_manager.h"
#include "megbrain/serialization/sereg.h"
#include "megbrain/version_symbol.h"

using namespace mgb;
using namespace opr;

MGB_DYN_TYPE_OBJ_FINAL_IMPL(CollectiveComm);

//! X-macro listing every CollectiveComm::Param::Mode handled in this file;
//! get_param_name() and ModeTrait::from_mode() expand it so they stay in sync.
#define FOREACH_MODE(cb)                                                    \
    cb(ALL_REDUCE_SUM) cb(ALL_REDUCE_MAX) cb(ALL_REDUCE_MIN) cb(BROADCAST)  \
            cb(REDUCE_SUM) cb(ALL_GATHER) cb(REDUCE_SCATTER_SUM) cb(GATHER) \
            cb(SCATTER) cb(ALL_TO_ALL)

namespace {

//! Name of the collective mode stored in \p param; used as the operator name
//! and as the prefix of the output var name. Throws on an unknown mode.
const char* get_param_name(CollectiveComm::Param param) {
    using Mode = CollectiveComm::Param::Mode;
    switch (param.mode) {
#define MODE_NAME(_mode) \
    case Mode::_mode:    \
        return #_mode;
        FOREACH_MODE(MODE_NAME)
#undef MODE_NAME
        default:
            break;
    }
    mgb_throw(MegBrainError, "bad CollectiveComm mode");
}

//! CUDA stream of the comp node that \p var lives on (the comp node's CUDA
//! env is queried unconditionally).
cudaStream_t get_stream(VarNode* var) {
    auto&& env = CompNodeEnv::from_comp_node(var->comp_node());
    return env.cuda_env().stream;
}
}  // anonymous namespace

55 56
/* ================= ModeTrait ================= */

57 58 59 60 61 62 63 64
class CollectiveComm::ModeTrait {
    class BROADCAST;
    class REDUCE_SUM;
    class REDUCE_SCATTER_SUM;
    class ALL_GATHER;
    class ALL_REDUCE_SUM;
    class ALL_REDUCE_MAX;
    class ALL_REDUCE_MIN;
65 66 67
    class GATHER;
    class SCATTER;
    class ALL_TO_ALL;
68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93

    class ReducedBasedTrait;
    class AllReduceBase;
    class ReduceBase;

protected:
    using Mode = Param::Mode;

    static void chk_shape_equal(const TensorShapeArray& shp) {
        for (size_t i = 1; i < shp.size(); ++i) {
            mgb_throw_if(!shp[0].eq_shape(shp[i]), GraphError,
                         "input shapes should be equal");
        }
    }

public:
    virtual ~ModeTrait() = default;

    /*!
     * \brief the vars on whose comp node the computing should be performed
     * if None, output vars would be used
     */
    virtual Maybe<VarNodeArray> comp_vars(CollectiveComm* opr) {
        return None;
    }

94 95 96
    VarNode* full_grad(VarNode* out_grad, const CollectiveComm* opr) const {
        auto mode = ModeTrait::from_mode(opr->param().mode).grad_mode();
        SymbolVarArray og_syms;
97 98 99 100

        if (out_grad != nullptr) {
            og_syms.push_back(out_grad);
        }
101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132

        auto&& cn = opr->output(0)->comp_node();

        auto gvar = CollectiveComm::make(
                og_syms, opr->owner_graph(), opr->key() + ":grad",
                opr->nr_devices(), opr->is_root(), opr->rank(), false,
                opr->group_client(), mode, opr->dtype(), opr->backend(), {cn});

        return gvar[0].node();
    }

    virtual VarNode* local_grad(VarNode* out_grad, const CollectiveComm* opr) const {
        mgb_throw(MegBrainError,
                  "only all_reduce all_to_all all_gather reduce_scatter "
                  "support local_grad");
    }

    virtual VarNode* grad(VarNode* out_grad, const CollectiveComm* opr) const {
        if (opr->local_grad()){
            return local_grad(out_grad, opr);
        } else {
            return full_grad(out_grad, opr);
        }
    }

    VarNode* zeros(mgb::cg::ComputingGraph &graph, CompNode node, const SymbolVar& shape,
                 DType dtype) const {
        auto zero = SymbolVar::make_scalar(0, graph, node);
        auto zero_tensor = opr::TypeCvt::make(zero, dtype).broadcast(shape);
        return zero_tensor.node();
    }

133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169
    virtual void get_output_var_shape(const CollectiveComm* opr,
                                      const TensorShapeArray& ishp,
                                      TensorShapeArray& oshp) = 0;

    virtual void exec(CollectiveComm* opr) = 0;

    //! gradient mode
    virtual Mode grad_mode() = 0;

    static ModeTrait& from_mode(Mode mode);
};

//! all-gather: output shape is the input shape with axis 0 multiplied by
//! nr_devices; gradient mode is REDUCE_SCATTER_SUM
class CollectiveComm::ModeTrait::ALL_GATHER : public ModeTrait {
    void get_output_var_shape(const CollectiveComm* opr,
                              const TensorShapeArray& ishp,
                              TensorShapeArray& oshp) override {
        chk_shape_equal(ishp);
        auto soshp = ishp[0];
        soshp[0] *= opr->nr_devices();
        for (auto& i : oshp)
            i = soshp;
    }

    void exec(CollectiveComm* opr) override {
        auto ivar = opr->input(0), ovar = opr->output(0);
        auto &&iv = ivar->dev_tensor(), &&ov = ovar->dev_tensor();
        mgb_assert(ivar->comp_node().mem_node() ==
                   ovar->comp_node().mem_node());
        // element count is the per-rank (input) size
        auto status = opr->m_megray_comm->all_gather(
                (void*)iv.raw_ptr(), (void*)ov.raw_ptr(),
                iv.shape().total_nr_elems(),
                get_megray_dtype(iv.dtype()),
                opr->megray_ctx());
        mgb_assert(status == MegRay::MEGRAY_OK, "MegRay all_gather failed");
    }

    Mode grad_mode() override { return Mode::REDUCE_SCATTER_SUM; }

    //! local gradient: the slice of out_grad along axis 0 that corresponds to
    //! this rank, i.e. [shape0*rank/n, shape0*(rank+1)/n)
    VarNode* local_grad(VarNode* out_grad, const CollectiveComm* opr) const override {
        auto nr_devices = opr->nr_devices();
        auto rank = opr->rank();
        opr::Subtensor::IndexDesc axis;
        auto shape0 = opr::GetVarShape::make(out_grad, 0);
        axis.push_back({0, shape0 * rank / (int)nr_devices,
                        shape0 * (rank + 1) / (int)nr_devices});
        auto grad = opr::Subtensor::make(out_grad, axis);
        return grad.node();
    }
};

//! reduce-scatter with summation: output shape is the input shape with axis 0
//! divided by nr_devices; gradient mode is ALL_GATHER
class CollectiveComm::ModeTrait::REDUCE_SCATTER_SUM : public ModeTrait {
    void get_output_var_shape(const CollectiveComm* opr,
                              const TensorShapeArray& ishp,
                              TensorShapeArray& oshp) override {
        chk_shape_equal(ishp);
        auto soshp = ishp[0];
        // report the number of devices as "parts": that is the divisor
        // actually checked, not the number of input shapes
        mgb_throw_if(soshp[0] % opr->nr_devices(), GraphError,
                     "input size can not be divided equally: "
                     "size=%zu parts=%zu",
                     soshp[0], opr->nr_devices());
        soshp[0] /= opr->nr_devices();
        for (auto& i : oshp)
            i = soshp;
    }

    void exec(CollectiveComm* opr) override {
        auto ivar = opr->input(0), ovar = opr->output(0);
        auto &&iv = ivar->dev_tensor(), &&ov = ovar->dev_tensor();
        mgb_assert(ivar->comp_node().mem_node() ==
                   ovar->comp_node().mem_node());

        // element count passed to MegRay is the size of this rank's output
        size_t buff_len = ov.shape().total_nr_elems();
        auto status = opr->m_megray_comm->reduce_scatter(
                (void*)iv.raw_ptr(), (void*)ov.raw_ptr(), buff_len,
                get_megray_dtype(ov.dtype()), MegRay::ReduceOp::MEGRAY_SUM,
                opr->megray_ctx());
        mgb_assert(status == MegRay::MEGRAY_OK, "MegRay reduce_scatter failed");
    }

    Mode grad_mode() override { return Mode::ALL_GATHER; }

    //! local gradient: concatenate (axis 0) out_grad at this rank's position
    //! with zero tensors of the same shape at every other position
    VarNode* local_grad(VarNode* out_grad, const CollectiveComm* opr) const override {
        VarNodeArray grads;
        auto zeros_tensor =
                zeros(*out_grad->owner_graph(), out_grad->comp_node(),
                      opr::GetVarShape::make(out_grad), out_grad->dtype());
        // explicit conversion: compare device indices as size_t
        const size_t rank = static_cast<size_t>(opr->rank());
        for (size_t i = 0; i < opr->nr_devices(); i++) {
            grads.push_back(i == rank ? out_grad : zeros_tensor);
        }
        auto grad = opr::Concat::make(grads, 0);
        return grad.node();
    }
};

//! mixin providing the MegRay reduction operator for reducing modes
class CollectiveComm::ModeTrait::ReducedBasedTrait {
protected:
    //! reduction that MegRay applies for this mode
    virtual MegRay::ReduceOp op() const = 0;

    ~ReducedBasedTrait() = default;
};

//! all-reduce modes: output shape equals input shape and is reduced with op()
//! across all ranks; gradient mode is always ALL_REDUCE_SUM
class CollectiveComm::ModeTrait::AllReduceBase : public ReducedBasedTrait,
                                                 public ModeTrait {
    void get_output_var_shape(const CollectiveComm*,
                              const TensorShapeArray& ishp,
                              TensorShapeArray& oshp) override {
        chk_shape_equal(ishp);
        oshp = ishp;
    }

    void exec(CollectiveComm* opr) override {
        auto ivar = opr->input(0), ovar = opr->output(0);
        auto &&iv = ivar->dev_tensor(), &&ov = ovar->dev_tensor();
        mgb_assert(ivar->comp_node().mem_node() ==
                   ovar->comp_node().mem_node());
        auto status = opr->m_megray_comm->all_reduce(
                (void*)iv.raw_ptr(), (void*)ov.raw_ptr(),
                iv.shape().total_nr_elems(),
                get_megray_dtype(iv.dtype()), op(),
                opr->megray_ctx());
        mgb_assert(status == MegRay::MEGRAY_OK, "MegRay all_reduce failed");
    }

    Mode grad_mode() override { return Mode::ALL_REDUCE_SUM; }

public:
    //! local gradient: out_grad passed through unchanged
    VarNode* local_grad(VarNode* out_grad,
                        const CollectiveComm* opr) const override {
        return out_grad;
    }
};

//! all-reduce with element-wise summation; everything except the reduction
//! operator is inherited from AllReduceBase
class CollectiveComm::ModeTrait::ALL_REDUCE_SUM final : public AllReduceBase {
    MegRay::ReduceOp op() const override {
        return MegRay::ReduceOp::MEGRAY_SUM;
    }
};

//! all-reduce with element-wise maximum
class CollectiveComm::ModeTrait::ALL_REDUCE_MAX final : public AllReduceBase {
    MegRay::ReduceOp op() const override { return MegRay::ReduceOp::MEGRAY_MAX; }

    /*!
     * gradient masked by COND_LEQ_MOV(output, input, g): g is kept where
     * output <= input (this rank's input reached the maximum) and zeroed
     * elsewhere
     */
    VarNode* grad(VarNode* out_grad, const CollectiveComm* opr) const override {
        // base implementation already chooses between local and full grad
        VarNode* g = ModeTrait::grad(out_grad, opr);
        return opr::Elemwise::make({opr->output(0), opr->input(0), g},
                                   Elemwise::Mode::COND_LEQ_MOV)
                .node();
    }
};

//! all-reduce with element-wise minimum
class CollectiveComm::ModeTrait::ALL_REDUCE_MIN final : public AllReduceBase {
    MegRay::ReduceOp op() const override { return MegRay::ReduceOp::MEGRAY_MIN; }

    /*!
     * gradient masked by COND_LEQ_MOV(input, output, g): g is kept where
     * input <= output (this rank's input reached the minimum) and zeroed
     * elsewhere
     */
    VarNode* grad(VarNode* out_grad, const CollectiveComm* opr) const override {
        // base implementation already chooses between local and full grad
        VarNode* g = ModeTrait::grad(out_grad, opr);
        return opr::Elemwise::make({opr->input(0), opr->output(0), g},
                                   Elemwise::Mode::COND_LEQ_MOV)
                .node();
    }
};

//! rooted reductions: the root receives the reduced tensor (same shape as the
//! input); non-root ranks get a placeholder output of shape {1}
class CollectiveComm::ModeTrait::ReduceBase : public ReducedBasedTrait,
                                              public ModeTrait {
    void get_output_var_shape(const CollectiveComm* opr,
                              const TensorShapeArray& ishp,
                              TensorShapeArray& oshp) override {
        chk_shape_equal(ishp);
        oshp[0] = opr->is_root() ? ishp[0] : TensorShape{1};
    }

    void exec(CollectiveComm* opr) override {
        auto&& src = opr->input(0)->dev_tensor();
        // only the root provides a receive buffer
        void* recvbuf = opr->is_root()
                                ? static_cast<void*>(
                                          opr->output(0)->dev_tensor().raw_ptr())
                                : nullptr;
        auto status = opr->m_megray_comm->reduce(
                (void*)src.raw_ptr(), recvbuf, src.shape().total_nr_elems(),
                get_megray_dtype(src.dtype()), op(), opr->m_root,
                opr->megray_ctx());
        mgb_assert(status == MegRay::MEGRAY_OK, "MegRay reduce failed");
    }
};

//! sum-reduction onto the root rank; gradient mode is BROADCAST
class CollectiveComm::ModeTrait::REDUCE_SUM final : public ReduceBase {
    MegRay::ReduceOp op() const override { return MegRay::ReduceOp::MEGRAY_SUM; }

    //! only the root feeds out_grad into the gradient opr; other ranks create
    //! it without input
    VarNode* grad(VarNode* out_grad, const CollectiveComm* opr) const override {
        VarNode* input = opr->is_root() ? out_grad : nullptr;
        return full_grad(input, opr);
    }

    Mode grad_mode() override { return Mode::BROADCAST; }
};

//! broadcast from the root rank; the output shape is provided by the custom
//! static infer registered in init_output_static_infer_desc()
class CollectiveComm::ModeTrait::BROADCAST : public ModeTrait {
    void get_output_var_shape(const CollectiveComm*, const TensorShapeArray&,
                              TensorShapeArray&) override {
        mgb_assert(false, "BROADCAST should not use get_output_var_shape");
    }

    void exec(CollectiveComm* opr) override {
        auto&& ov = opr->output(0)->dev_tensor();
        mgb_assert(opr->input().size() < 2,
                   "input size of BROADCAST must be either 0 or 1");
        // the root sends its input; other ranks only describe the receive
        // buffer (their own output) and pass no send buffer
        const DeviceTensorND* desc = &ov;
        void* sendbuf = nullptr;
        if (opr->is_root()) {
            desc = &opr->input(0)->dev_tensor();
            sendbuf = (void*)desc->raw_ptr();
        }
        auto status = opr->m_megray_comm->broadcast(
                sendbuf, (void*)ov.raw_ptr(), desc->shape().total_nr_elems(),
                get_megray_dtype(desc->dtype()), opr->m_root,
                opr->megray_ctx());
        mgb_assert(status == MegRay::MEGRAY_OK, "MegRay broadcast failed");
    }

    Mode grad_mode() override { return Mode::REDUCE_SUM; }
};

386 387 388 389 390 391 392 393 394 395 396 397 398 399 400 401 402 403 404 405 406 407 408 409 410 411
class CollectiveComm::ModeTrait::GATHER : public ModeTrait {
    void get_output_var_shape(const CollectiveComm* opr,
                              const TensorShapeArray& ishp,
                              TensorShapeArray& oshp) override {
        MGB_MARK_USED_VAR(opr);
        chk_shape_equal(ishp);
        if (opr->is_root()) {
            oshp[0] = ishp[0];
            oshp[0][0] *= opr->nr_devices();
        } else {
            oshp[0] = TensorShape{1};
        }
    }

    void exec(CollectiveComm* opr) override {
        auto&& iv = opr->input(0)->dev_tensor();
        void* recvbuf = nullptr;
        if (opr->is_root()) {
            recvbuf = opr->output(0)->dev_tensor().raw_ptr();
        }
        auto status = opr->m_megray_comm->gather(
                (void*)iv.raw_ptr(), recvbuf, iv.shape().total_nr_elems(),
                get_megray_dtype(iv.dtype()), opr->m_root, opr->megray_ctx());
        mgb_assert(status == MegRay::MEGRAY_OK, "MegRay gather failed");
    }

412 413 414 415 416
    VarNode* grad(VarNode* out_grad, const CollectiveComm* opr) const override {
        VarNode* input = opr->is_root() ? out_grad : nullptr;
        return full_grad(input, opr);
    }

417 418 419 420 421 422 423 424 425 426 427 428 429 430 431 432 433 434 435 436 437 438 439 440 441 442 443 444 445 446 447 448 449 450 451 452 453 454 455 456 457 458 459 460 461
    Mode grad_mode() override { return Mode::SCATTER; }
};

//! scatter from the root; the output shape is provided by the custom static
//! infer registered in init_output_static_infer_desc()
class CollectiveComm::ModeTrait::SCATTER : public ModeTrait {
    void get_output_var_shape(const CollectiveComm*, const TensorShapeArray&,
                              TensorShapeArray&) override {
        mgb_throw(MegBrainError, "SCATTER should not use get_output_var_shape");
    }

    void exec(CollectiveComm* opr) override {
        auto&& dst = opr->output(0)->dev_tensor();
        // only the root owns the data to be scattered
        void* sendbuf = opr->is_root()
                                ? static_cast<void*>(
                                          opr->input(0)->dev_tensor().raw_ptr())
                                : nullptr;
        auto status = opr->m_megray_comm->scatter(
                sendbuf, static_cast<void*>(dst.raw_ptr()),
                dst.shape().total_nr_elems(), get_megray_dtype(dst.dtype()),
                opr->m_root, opr->megray_ctx());
        mgb_assert(status == MegRay::MEGRAY_OK, "MegRay scatter failed");
    }

    Mode grad_mode() override { return Mode::GATHER; }
};

//! all-to-all exchange of equally sized chunks; output shape equals input
//! shape; gradient mode is ALL_TO_ALL
class CollectiveComm::ModeTrait::ALL_TO_ALL : public ModeTrait {
    void get_output_var_shape(const CollectiveComm* opr,
                              const TensorShapeArray& ishp,
                              TensorShapeArray& oshp) override {
        chk_shape_equal(ishp);
        // exec() sends total_nr_elems / nr_devices elements to every peer;
        // a non-divisible size would be silently truncated there
        for (auto&& shp : ishp) {
            size_t nr_elems = shp.total_nr_elems();
            mgb_throw_if(nr_elems % opr->nr_devices(), GraphError,
                         "input size can not be divided equally: "
                         "size=%zu parts=%zu",
                         nr_elems, opr->nr_devices());
        }
        oshp = ishp;
    }

    void exec(CollectiveComm* opr) override {
        auto&& iv = opr->input(0)->dev_tensor();
        auto&& ov = opr->output(0)->dev_tensor();
        // element count passed to MegRay is the per-peer chunk size
        auto status = opr->m_megray_comm->all_to_all(
                (void*)iv.raw_ptr(), (void*)ov.raw_ptr(),
                iv.shape().total_nr_elems() / opr->nr_devices(),
                get_megray_dtype(iv.dtype()), opr->megray_ctx());
        mgb_assert(status == MegRay::MEGRAY_OK, "MegRay all_to_all failed");
    }

    Mode grad_mode() override { return Mode::ALL_TO_ALL; }

    //! local gradient: zeros everywhere except this rank's slice along axis 0
    //! ([shape0*rank/n, shape0*(rank+1)/n)), which is copied from out_grad
    VarNode* local_grad(VarNode* out_grad, const CollectiveComm* opr) const override {
        auto grad_shape = opr::GetVarShape::make(out_grad);
        auto zeros_tensor =
                zeros(*out_grad->owner_graph(), out_grad->comp_node(),
                      grad_shape, out_grad->dtype());

        auto nr_devices = opr->nr_devices();
        auto rank = opr->rank();
        opr::Subtensor::IndexDesc axis;
        auto shape0 = opr::GetVarShape::make(out_grad, 0);
        axis.push_back({0, shape0 * rank / (int)nr_devices,
                        shape0 * (rank + 1) / (int)nr_devices});
        auto sub_grad = opr::Subtensor::make(out_grad, axis);

        return opr::SetSubtensor::make(zeros_tensor, sub_grad, axis).node();
    }
};

482 483 484 485 486 487 488 489 490 491 492 493 494 495 496 497 498 499
CollectiveComm::ModeTrait& CollectiveComm::ModeTrait::from_mode(Mode mode) {
    switch (mode) {
#define c(_m)          \
    case Mode::_m: {   \
        static _m ins; \
        return ins;    \
    }
        FOREACH_MODE(c)
        default:
            mgb_assert(0);
#undef c
    }
}

/* ================= CollectiveComm ================= */

/*!
 * Accepts at most one input and at most one comp node in \p config, both of
 * which must be CUDA. The output comp node comes from \p config when given,
 * otherwise from the input. Setting the environment variable
 * MGE_MM_OPR_DEBUG=1 enables debug logging in do_execute().
 */
CollectiveComm::CollectiveComm(
        VarNodeArray inputs, ComputingGraph* const graph,
        const std::string& key, const size_t nr_devices, const bool is_root,
        const int rank, const bool local_grad,
        std::shared_ptr<GroupClient> group_client, const Param& param,
        const DType& dtype, const std::string& backend,
        const SmallVector<std::shared_ptr<DeviceTensorND>>& dev_buffer_arr,
        const OperatorNodeConfig& config,
        const std::shared_ptr<DTypeScalar>& disable)
        : Super{graph, config, get_param_name(param), inputs},
          m_param{param},
          m_dtype(dtype),
          m_backend(backend),
          m_group_client{std::move(group_client)},
          m_nr_devices(nr_devices),
          m_is_root(is_root),
          m_rank(rank),
          m_local_grad(local_grad),
          m_key(key),
          m_dev_buffers(dev_buffer_arr),
          m_disable{disable} {
    // add input
    mgb_assert(inputs.size() <= 1, "one or zero input expected, got %zu", inputs.size());
    if (inputs.size() > 0) {
        mgb_assert(inputs[0]->comp_node().device_type() == CompNode::DeviceType::CUDA,
                   "CollectiveComm currectly only supports CUDA");
        add_input({inputs[0]});
    }

    // add output
    add_output(ssprintf("%s:%s", get_param_name(param), key.c_str()));

    // set comp node: an explicit config comp node wins, otherwise follow input
    const auto& cns = config.comp_node();
    mgb_assert(cns.size() <= 1, "one or zero comp node expected, got %zu", cns.size());
    if (cns.size() > 0) {
        mgb_assert(cns[0].device_type() == CompNode::DeviceType::CUDA,
                   "CollectiveComm currectly only supports CUDA");
        output(0)->comp_node(cns[0]);
    } else {
        // with neither a config comp node nor an input there is nothing to
        // place the output on; fail loudly instead of indexing inputs[0]
        mgb_assert(inputs.size() > 0,
                   "CollectiveComm without input requires a comp node in "
                   "config");
        output(0)->comp_node(inputs[0]->comp_node());
    }

    // set debug flag
    const char* c_debug = MGB_GETENV("MGE_MM_OPR_DEBUG");
    if (c_debug != nullptr and strcmp(c_debug, "1") == 0) {
        m_debug_mode = true;
    }

    // deduplication: oprs with equal param, device count and key hash are
    // considered equivalent
    add_equivalence_component<PODHash<Param>>(&m_param);
    add_equivalence_component<PODHash<size_t>>(&m_nr_devices);
    m_hash = XXHash{}.update(key.data(), key.size() * sizeof(char)).digest();
    add_equivalence_component<PODHash<size_t>>(&m_hash);
}

/*!
 * Convenience overload without explicit device buffers. It creates one null
 * buffer slot per device (null slots fall back to regular memory planning) and
 * forwards every argument, including \p disable, to the buffer-taking overload.
 */
SymbolVarArray CollectiveComm::make(
        const SymbolVarArray& inputs, ComputingGraph* const graph,
        const std::string& key, const size_t nr_devices, const bool is_root,
        const int rank, const bool local_grad,
        std::shared_ptr<GroupClient> group_client, const Param& param,
        const DType& dtype, const std::string& backend,
        const OperatorNodeConfig& config,
        const std::shared_ptr<DTypeScalar>& disable) {
    SmallVector<std::shared_ptr<DeviceTensorND>> dev_buffer_arr(nr_devices,
                                                                nullptr);
    // `disable` must be forwarded: omitting it would silently replace the
    // caller's flag with the other overload's default argument
    return make(inputs, graph, key, nr_devices, is_root, rank, local_grad,
                std::move(group_client), dev_buffer_arr, param, dtype, backend,
                config, disable);
}

/*!
 * Inserts a CollectiveComm opr into \p graph and returns its output vars.
 * \p dev_buffer_arr is indexed by output index; a non-null entry is used as
 * that output's storage (see init_output_mem_plan()).
 */
SymbolVarArray CollectiveComm::make(
        const SymbolVarArray& inputs, ComputingGraph* const graph,
        const std::string& key, const size_t nr_devices, const bool is_root,
        const int rank, const bool local_grad,
        std::shared_ptr<GroupClient> group_client,
        const SmallVector<std::shared_ptr<DeviceTensorND>>& dev_buffer_arr,
        const Param& param, const DType& dtype, const std::string& backend,
        const OperatorNodeConfig& config,
        const std::shared_ptr<DTypeScalar>& disable) {
    auto inpvars = cg::to_var_node_array(inputs);
    auto opr = graph->insert_opr(std::make_unique<CollectiveComm>(
            inpvars, graph, key, nr_devices, is_root, rank, local_grad,
            std::move(group_client), param, dtype, backend, dev_buffer_arr,
            config, disable));
    mgb_assert(!opr->output().empty());
    return cg::to_symbol_var_array(opr->output());
}

void CollectiveComm::opr_register() {
    if (m_init)
        return;
589

590
    auto&& comp_node = output(0)->comp_node();
591 592
    bool use_cache = output(0)->owner_graph()->options().imperative_proxy_graph;
    struct GroupManager::RegisterInfo reg_info;
593

594 595 596 597 598 599 600 601 602 603
    if (use_cache and RegInfoCache::has_info(m_key)) {
        reg_info = RegInfoCache::get_info(m_key);
    } else {
        reg_info = m_group_client->opr_register(
                m_key, m_nr_devices, m_is_root, m_rank,
                comp_node.get_uid());
        if (use_cache) {
            RegInfoCache::set_info(m_key, reg_info);
        }
    }
604

605 606
    m_rank = reg_info.rank;
    m_root = reg_info.root_rank;
607

608
    m_megray_comm = MegRayCommBuilder::get_megray_comm(
609
            reg_info.hash, m_key, m_nr_devices, m_rank,
610 611
            get_megray_backend(m_backend), m_group_client);

612 613
    m_megray_ctx = MegRay::CudaContext::make(get_stream(output(0)));

614 615 616 617 618 619 620 621 622 623 624 625 626 627 628 629 630 631 632
    m_init = true;
}

void CollectiveComm::add_input_layout_constraint() {
    // Enable shape infer *after* static infer phase. This is only used by
    // BROADCAST operation.
    m_enable_shape_infer = true;
    for (auto i : input()) {
        i->add_layout_constraint_contiguous();
    }
}

void CollectiveComm::get_output_var_shape(const TensorShapeArray& inp_shape,
                                          TensorShapeArray& out_shape) const {
    // mode-specific shape rule; the trait hook takes a const opr pointer
    auto&& trait = ModeTrait::from_mode(m_param.mode);
    trait.get_output_var_shape(this, inp_shape, out_shape);
}

633
void CollectiveComm::init_output_comp_node() {}
634 635 636 637 638 639 640 641 642 643 644 645 646 647 648 649 650 651 652 653 654 655 656 657 658 659 660 661 662 663 664 665 666 667 668 669 670 671 672 673 674 675 676 677 678 679 680 681

void CollectiveComm::init_output_mem_plan(bool dynamic) {
    for (size_t i = 0; i < output().size(); i++) {
        if (m_dev_buffers[i]) {
            output(i)->init_mem_plan(m_dev_buffers[i].get());
        } else {
            if (is_static_var_storage(output(i)) == !dynamic &&
                !output(i)->contain_flag(VarNode::Flag::NO_SYS_MEM_ALLOC))
                output(i)->init_mem_plan();
        }
    }
}

void CollectiveComm::mem_plan_fwd_in2out_writable() {
    // only ALL_REDUCE_SUM lets outputs reuse input memory; presumably
    // MAX/MIN are excluded because their grad compares input with output
    if (m_param.mode != Param::Mode::ALL_REDUCE_SUM)
        return;
    for (size_t idx = 0; idx < output().size(); ++idx)
        output(idx)->set_fwd_in2out_writable(input(idx));
}

//! base-class node properties plus CROSS_COMP_NODE_MEMORY and NO_AUTOMATIC_DUP
cg::OperatorNodeBase::NodeProp* CollectiveComm::do_make_node_prop() const {
    auto* const prop = OperatorNodeBase::do_make_node_prop();
    for (auto flag : {NodeProp::Flag::CROSS_COMP_NODE_MEMORY,
                      NodeProp::Flag::NO_AUTOMATIC_DUP}) {
        prop->add_flag(flag);
    }
    return prop;
}

/*!
 * Dispatches the collective onto the output comp node. Registration with the
 * group server happens lazily inside the dispatched task. The value of
 * m_disable selects behaviour: 1 skips execution, 0 runs it, and any other
 * value is an assertion failure.
 */
void CollectiveComm::do_execute(ExecEnv& env) {
    auto&& trait = ModeTrait::from_mode(m_param.mode);
    mgb_assert(owner_graph()->options().async_exec_level,
               "collective comm must be used with async dispatch");
    mgb_assert(output().size() == 1,
               "collective comm only support exactly one output");

    auto disable = m_disable->get_cast<int>();
    if (disable == 1)
        return;
    mgb_assert(disable == 0,
               "disable flag on CollectiveComm can only be 0 or 1,"
               " got %d actually.",
               disable);

    auto cn = output(0)->comp_node();
    // capturing `trait` by reference is safe: it names the function-local
    // static returned by ModeTrait::from_mode()
    auto runner = [this, cn, &trait] {
        opr_register();
        cn.activate();

        if (m_debug_mode) {
            mgb_log_debug("collective comm: executing %s, rank = %d, key = %s",
                    cname(), rank(), key().c_str());
        }

        owner_graph()->event().signal_inplace<cg::event::BeforeKernel>(this, cn);
        trait.exec(this);
        owner_graph()->event().signal_inplace<cg::event::AfterKernel>(this, cn);
    };
    env.dispatch_on_comp_node(cn, runner);
}

void CollectiveComm::on_output_comp_node_stream_changed() {
    // intentionally empty: overrides the base-class hook with a no-op
}

void CollectiveComm::init_output_dtype() {
    if (m_dtype.valid()) {
        for (size_t i = 0; i < input().size(); ++i) {
            mgb_assert(m_dtype == input(i)->dtype(),
                       "any given input's dtype should be identical to that "
                       "specified from opr's argument");
        }
        for (auto i : output()) {
            if (!i->dtype().valid())
                i->dtype(m_dtype);
        }
    } else {
        Super::init_output_dtype();
    }
}

void CollectiveComm::init_output_static_infer_desc() {
713 714
    if (m_param.mode == Param::Mode::BROADCAST ||
        m_param.mode == Param::Mode::SCATTER) {
715 716 717 718 719
        using namespace cg::static_infer;
        auto&& mgr = owner_graph()->static_infer_manager();

        auto infer_shape_from_input = [this](TensorShape& dest, const InpVal& inp_val) {
            dest = inp_val.val[0].shape();
720 721 722
            if (m_param.mode == Param::Mode::SCATTER) {
                dest[0] /= nr_devices();
            }
723
            if (is_root() && !m_output_shape.valid()) {
724 725 726
                m_output_shape = dest;
                m_group_client->set_output_shape(m_key, dest);
            }
727 728 729 730
            return true;
        };

        auto get_shape_from_server = [this](TensorShape& dest, const InpVal&) {
731
            if (!m_enable_shape_infer && !owner_graph()->options().imperative_proxy_graph) {
732 733 734
                return false;
            }

735 736
            if (!m_output_shape.valid()) {
                m_output_shape = m_group_client->get_output_shape(m_key);
737
            }
738 739

            dest = m_output_shape.val();
740 741 742 743 744
            return true;
        };

        mgb_assert(output().size() == 1);

745
        if (is_root() || input().size() > 0) {
746 747 748 749 750 751 752 753 754 755 756 757 758
            mgb_assert(input().size() == 1);
            mgr.register_shape_infer(output(0),
                {SourceType::DEP, {{input(0), DepType::SHAPE}}, infer_shape_from_input});
        } else {
            mgr.register_shape_infer(output(0),
                {SourceType::MUTABLE, {}, get_shape_from_server});
        }

    } else {
        Super::init_output_static_infer_desc();
    }
}

759 760 761 762
VarNode* CollectiveComm::grad(VarNode* out_grad) const {
    return ModeTrait::from_mode(m_param.mode).grad(out_grad, this);
}

#if MGB_ENABLE_GRAD
//! autodiff hook: forwards the single output gradient to CollectiveComm::grad
MGB_IMPL_OPR_GRAD(CollectiveComm) {
    mgb_assert(out_grad.size() == 1, "CollectiveComm should only have one grad");
    return opr.grad(out_grad[0]);
}
#endif

770 771 772 773 774 775 776 777 778 779
/* ===================== shallow copy ===================== */

namespace mgb {
namespace opr {

cg::OperatorNodeBase* opr_shallow_copy_collective_mm(
        const serialization::OprShallowCopyContext& ctx,
        const cg::OperatorNodeBase& opr_, const VarNodeArray& inputs,
        const OperatorNodeConfig& config) {
    auto&& opr = opr_.cast_final_safe<opr::CollectiveComm>();
780 781 782 783 784 785 786 787
    auto new_opr =
            CollectiveComm::make(
                    to_symbol_var_array(inputs), ctx.owner_graph(opr_, inputs),
                    opr.key(), opr.nr_devices(), opr.is_root(), opr.rank(),
                    opr.local_grad(), opr.group_client(), opr.dev_buffers(),
                    opr.param(), opr.dtype(), opr.backend(), config)[0]
                    .node()
                    ->owner_opr();
788 789
    new_opr->cast_final_safe<opr::CollectiveComm>().set_pack_hash(opr.pack_hash());
    return new_opr;
790 791 792 793 794 795 796
}
MGB_REG_OPR_SHALLOW_COPY(CollectiveComm, opr_shallow_copy_collective_mm);

}  // namespace opr
}  // namespace mgb

// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}