// Patch summary: full_grad() now skips a null out_grad instead of pushing it
// into og_syms, and REDUCE_SUM / GATHER gain grad() overrides that hand
// out_grad to full_grad() only when opr->is_root() is true.
diff --git a/src/opr-mm/impl/collective_comm.cpp b/src/opr-mm/impl/collective_comm.cpp
index c604f8c54e6745d161f1fefd538bd24498b8995c..9ae252226b214a1a10317bae904eb21b98b11402 100644
--- a/src/opr-mm/impl/collective_comm.cpp
+++ b/src/opr-mm/impl/collective_comm.cpp
@@ -139,7 +139,10 @@ public:
     VarNode* full_grad(VarNode* out_grad, const CollectiveComm* opr) const {
         auto mode = ModeTrait::from_mode(opr->param().mode).grad_mode();
         SymbolVarArray og_syms;
-        og_syms.push_back(out_grad);
+
+        if (out_grad != nullptr) {  // grad() overrides pass nullptr when opr is not root; leave og_syms empty then
+            og_syms.push_back(out_grad);
+        }
 
         auto&& cn = opr->output(0)->comp_node();
 
@@ -401,6 +404,11 @@ class CollectiveComm::ModeTrait::ReduceBase : public ReducedBasedTrait,
 class CollectiveComm::ModeTrait::REDUCE_SUM final : public ReduceBase {
     MegRay::ReduceOp op() const override { return MegRay::ReduceOp::MEGRAY_SUM; }
 
+    VarNode* grad(VarNode* out_grad, const CollectiveComm* opr) const override {
+        VarNode* input = opr->is_root() ? out_grad : nullptr;  // NOTE(review): presumably the BROADCAST grad mode only takes an input on the root -- confirm
+        return full_grad(input, opr);
+    }
+
     Mode grad_mode() override { return Mode::BROADCAST; }
 };
 
@@ -484,6 +492,11 @@ class CollectiveComm::ModeTrait::GATHER : public ModeTrait {
         mgb_assert(status == MegRay::MEGRAY_OK, "MegRay gather failed");
     }
 
+    VarNode* grad(VarNode* out_grad, const CollectiveComm* opr) const override {
+        VarNode* input = opr->is_root() ? out_grad : nullptr;  // NOTE(review): presumably the SCATTER grad mode only takes an input on the root -- confirm
+        return full_grad(input, opr);
+    }
+
     Mode grad_mode() override { return Mode::SCATTER; }
 };