提交 9389a805 编写于 作者: M Megvii Engine Team

fix(mge/oprmm): fix grad for collective comm

GitOrigin-RevId: e55dc47a6eb8f70955c94743ba274feb2c141624
上级 6f581906
...@@ -14,6 +14,7 @@ import pytest ...@@ -14,6 +14,7 @@ import pytest
import megengine as mge import megengine as mge
import megengine.distributed as dist import megengine.distributed as dist
from megengine.core.ops.builtin import CollectiveComm, ParamPackConcat, ParamPackSplit
from megengine.distributed.helper import get_device_count_by_fork from megengine.distributed.helper import get_device_count_by_fork
...@@ -187,3 +188,10 @@ def test_synchronized(): ...@@ -187,3 +188,10 @@ def test_synchronized():
for p in procs: for p in procs:
p.join(20) p.join(20)
assert p.exitcode == 0 assert p.exitcode == 0
def test_oprmm_hashable():
lhs = (CollectiveComm(), ParamPackConcat(), ParamPackSplit())
rhs = (CollectiveComm(), ParamPackConcat(), ParamPackSplit())
assert lhs == rhs
assert hash(lhs) == hash(rhs)
...@@ -71,7 +71,7 @@ OP_TRAIT_REG(CollectiveComm, CollectiveComm, opr::CollectiveComm) ...@@ -71,7 +71,7 @@ OP_TRAIT_REG(CollectiveComm, CollectiveComm, opr::CollectiveComm)
.make_from_op_node(make_from_op_node) .make_from_op_node(make_from_op_node)
.fallback(); .fallback();
} // anonymous namespace } // anonymous namespace
#endif // MGB_ENABLE_OPR_MM
bool CollectiveComm::is_same_st(const Hashable& another) const{ bool CollectiveComm::is_same_st(const Hashable& another) const{
auto* comm_opr = another.try_cast_final<CollectiveComm>(); auto* comm_opr = another.try_cast_final<CollectiveComm>();
...@@ -100,18 +100,6 @@ size_t CollectiveComm::hash() const{ ...@@ -100,18 +100,6 @@ size_t CollectiveComm::hash() const{
return xxhash.digest(); return xxhash.digest();
} }
#else
bool CollectiveComm::is_same_st(const Hashable& another) const{
return OpDef::is_same_st(another);
}
size_t CollectiveComm::hash() const{
return OpDef::hash();
}
#endif // MGB_ENABLE_OPR_MM
MGB_DYN_TYPE_OBJ_FINAL_IMPL(CollectiveComm); MGB_DYN_TYPE_OBJ_FINAL_IMPL(CollectiveComm);
} // namespace imperative } // namespace imperative
......
...@@ -64,11 +64,8 @@ public: ...@@ -64,11 +64,8 @@ public:
} }
bool is_same_st(const Hashable& rhs) const override { bool is_same_st(const Hashable& rhs) const override {
auto* pps = rhs.try_cast_final<ParamPackSplit>(); auto&& pps = rhs.cast_final_safe<ParamPackSplit>();
if(pps == nullptr){ return offsets == pps.offsets && shapes == pps.shapes;
return false;
}
return offsets == pps->offsets && shapes == pps->shapes;
} }
}; };
...@@ -94,11 +91,8 @@ public: ...@@ -94,11 +91,8 @@ public:
} }
bool is_same_st(const Hashable& rhs) const override { bool is_same_st(const Hashable& rhs) const override {
auto* ppc = rhs.try_cast_final<ParamPackConcat>(); auto&& ppc = rhs.cast_final_safe<ParamPackConcat>();
if(ppc == nullptr){ return offsets == ppc.offsets;
return false;
}
return offsets == ppc->offsets;
} }
}; };
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册