提交 e3e981cc 编写于 作者: M Megvii Engine Team

test(mge/collective_comm): fix collective_comm test and add data parallel test

GitOrigin-RevId: 9209e779739f2a8e6c91efee78254d62010dabec
上级 d3d9018f
......@@ -802,7 +802,7 @@ void CollectiveComm::init_output_static_infer_desc() {
if (m_param.mode == Param::Mode::SCATTER) {
dest[0] /= nr_devices();
}
if (!m_output_shape.valid()) {
if (is_root() && !m_output_shape.valid()) {
m_output_shape = dest;
m_group_client->set_output_shape(m_key, dest);
}
......@@ -824,7 +824,7 @@ void CollectiveComm::init_output_static_infer_desc() {
mgb_assert(output().size() == 1);
if (is_root()) {
if (is_root() || input().size() > 0) {
mgb_assert(input().size() == 1);
mgr.register_shape_infer(output(0),
{SourceType::DEP, {{input(0), DepType::SHAPE}}, infer_shape_from_input});
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册