From 5e912eddbd959f0426182b578e05e1fb570ec274 Mon Sep 17 00:00:00 2001 From: Megvii Engine Team Date: Sat, 18 Jul 2020 14:30:04 +0800 Subject: [PATCH] fix(mgb/opr-mm): fix grad func of reduce and gather GitOrigin-RevId: 4687faef9943460011737315cf052cc584e6988a --- src/opr-mm/impl/collective_comm.cpp | 15 ++++++++++++++- 1 file changed, 14 insertions(+), 1 deletion(-) diff --git a/src/opr-mm/impl/collective_comm.cpp b/src/opr-mm/impl/collective_comm.cpp index c604f8c5..9ae25222 100644 --- a/src/opr-mm/impl/collective_comm.cpp +++ b/src/opr-mm/impl/collective_comm.cpp @@ -139,7 +139,10 @@ public: VarNode* full_grad(VarNode* out_grad, const CollectiveComm* opr) const { auto mode = ModeTrait::from_mode(opr->param().mode).grad_mode(); SymbolVarArray og_syms; - og_syms.push_back(out_grad); + + if (out_grad != nullptr) { + og_syms.push_back(out_grad); + } auto&& cn = opr->output(0)->comp_node(); @@ -401,6 +404,11 @@ class CollectiveComm::ModeTrait::ReduceBase : public ReducedBasedTrait, class CollectiveComm::ModeTrait::REDUCE_SUM final : public ReduceBase { MegRay::ReduceOp op() const override { return MegRay::ReduceOp::MEGRAY_SUM; } + VarNode* grad(VarNode* out_grad, const CollectiveComm* opr) const override { + VarNode* input = opr->is_root() ? out_grad : nullptr; + return full_grad(input, opr); + } + Mode grad_mode() override { return Mode::BROADCAST; } }; @@ -484,6 +492,11 @@ class CollectiveComm::ModeTrait::GATHER : public ModeTrait { mgb_assert(status == MegRay::MEGRAY_OK, "MegRay gather failed"); } + VarNode* grad(VarNode* out_grad, const CollectiveComm* opr) const override { + VarNode* input = opr->is_root() ? out_grad : nullptr; + return full_grad(input, opr); + } + Mode grad_mode() override { return Mode::SCATTER; } }; -- GitLab