diff --git a/src/opr-mm/impl/collective_comm.cpp b/src/opr-mm/impl/collective_comm.cpp index b8ea6fdfc7b9c1ed86cbe8b1ba4b4ce4476454ac..c604f8c54e6745d161f1fefd538bd24498b8995c 100644 --- a/src/opr-mm/impl/collective_comm.cpp +++ b/src/opr-mm/impl/collective_comm.cpp @@ -802,7 +802,7 @@ void CollectiveComm::init_output_static_infer_desc() { if (m_param.mode == Param::Mode::SCATTER) { dest[0] /= nr_devices(); } - if (!m_output_shape.valid()) { + if (is_root() && !m_output_shape.valid()) { m_output_shape = dest; m_group_client->set_output_shape(m_key, dest); } @@ -824,7 +824,7 @@ void CollectiveComm::init_output_static_infer_desc() { mgb_assert(output().size() == 1); - if (is_root()) { + if (is_root() || input().size() > 0) { mgb_assert(input().size() == 1); mgr.register_shape_infer(output(0), {SourceType::DEP, {{input(0), DepType::SHAPE}}, infer_shape_from_input});