提交 5e912edd 编写于 作者: M Megvii Engine Team

fix(mgb/opr-mm): fix grad func of reduce and gather

GitOrigin-RevId: 4687faef9943460011737315cf052cc584e6988a
上级 e3e981cc
......@@ -139,7 +139,10 @@ public:
VarNode* full_grad(VarNode* out_grad, const CollectiveComm* opr) const {
auto mode = ModeTrait::from_mode(opr->param().mode).grad_mode();
SymbolVarArray og_syms;
og_syms.push_back(out_grad);
if (out_grad != nullptr) {
og_syms.push_back(out_grad);
}
auto&& cn = opr->output(0)->comp_node();
......@@ -401,6 +404,11 @@ class CollectiveComm::ModeTrait::ReduceBase : public ReducedBasedTrait,
class CollectiveComm::ModeTrait::REDUCE_SUM final : public ReduceBase {
MegRay::ReduceOp op() const override { return MegRay::ReduceOp::MEGRAY_SUM; }
VarNode* grad(VarNode* out_grad, const CollectiveComm* opr) const override {
VarNode* input = opr->is_root() ? out_grad : nullptr;
return full_grad(input, opr);
}
Mode grad_mode() override { return Mode::BROADCAST; }
};
......@@ -484,6 +492,11 @@ class CollectiveComm::ModeTrait::GATHER : public ModeTrait {
mgb_assert(status == MegRay::MEGRAY_OK, "MegRay gather failed");
}
VarNode* grad(VarNode* out_grad, const CollectiveComm* opr) const override {
VarNode* input = opr->is_root() ? out_grad : nullptr;
return full_grad(input, opr);
}
Mode grad_mode() override { return Mode::SCATTER; }
};
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册