提交 9389a805 编写于 作者: M Megvii Engine Team

fix(mge/oprmm): fix grad for collective comm

GitOrigin-RevId: e55dc47a6eb8f70955c94743ba274feb2c141624
上级 6f581906
......@@ -14,6 +14,7 @@ import pytest
import megengine as mge
import megengine.distributed as dist
from megengine.core.ops.builtin import CollectiveComm, ParamPackConcat, ParamPackSplit
from megengine.distributed.helper import get_device_count_by_fork
......@@ -187,3 +188,10 @@ def test_synchronized():
for p in procs:
p.join(20)
assert p.exitcode == 0
def test_oprmm_hashable():
lhs = (CollectiveComm(), ParamPackConcat(), ParamPackSplit())
rhs = (CollectiveComm(), ParamPackConcat(), ParamPackSplit())
assert lhs == rhs
assert hash(lhs) == hash(rhs)
......@@ -71,7 +71,7 @@ OP_TRAIT_REG(CollectiveComm, CollectiveComm, opr::CollectiveComm)
.make_from_op_node(make_from_op_node)
.fallback();
} // anonymous namespace
#endif // MGB_ENABLE_OPR_MM
bool CollectiveComm::is_same_st(const Hashable& another) const{
auto* comm_opr = another.try_cast_final<CollectiveComm>();
......@@ -100,18 +100,6 @@ size_t CollectiveComm::hash() const{
return xxhash.digest();
}
#else
bool CollectiveComm::is_same_st(const Hashable& another) const{
return OpDef::is_same_st(another);
}
size_t CollectiveComm::hash() const{
return OpDef::hash();
}
#endif // MGB_ENABLE_OPR_MM
MGB_DYN_TYPE_OBJ_FINAL_IMPL(CollectiveComm);
} // namespace imperative
......
......@@ -64,11 +64,8 @@ public:
}
bool is_same_st(const Hashable& rhs) const override {
auto* pps = rhs.try_cast_final<ParamPackSplit>();
if(pps == nullptr){
return false;
}
return offsets == pps->offsets && shapes == pps->shapes;
auto&& pps = rhs.cast_final_safe<ParamPackSplit>();
return offsets == pps.offsets && shapes == pps.shapes;
}
};
......@@ -94,11 +91,8 @@ public:
}
bool is_same_st(const Hashable& rhs) const override {
auto* ppc = rhs.try_cast_final<ParamPackConcat>();
if(ppc == nullptr){
return false;
}
return offsets == ppc->offsets;
auto&& ppc = rhs.cast_final_safe<ParamPackConcat>();
return offsets == ppc.offsets;
}
};
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册