From e3e981ccf092102c7dcbacfddc7255dbeda92a42 Mon Sep 17 00:00:00 2001 From: Megvii Engine Team Date: Wed, 15 Jul 2020 17:05:44 +0800 Subject: [PATCH] test(mge/collective_comm): fix collective_comm test and add data parallel test GitOrigin-RevId: 9209e779739f2a8e6c91efee78254d62010dabec --- src/opr-mm/impl/collective_comm.cpp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/opr-mm/impl/collective_comm.cpp b/src/opr-mm/impl/collective_comm.cpp index b8ea6fdf..c604f8c5 100644 --- a/src/opr-mm/impl/collective_comm.cpp +++ b/src/opr-mm/impl/collective_comm.cpp @@ -802,7 +802,7 @@ void CollectiveComm::init_output_static_infer_desc() { if (m_param.mode == Param::Mode::SCATTER) { dest[0] /= nr_devices(); } - if (!m_output_shape.valid()) { + if (is_root() && !m_output_shape.valid()) { m_output_shape = dest; m_group_client->set_output_shape(m_key, dest); } @@ -824,7 +824,7 @@ void CollectiveComm::init_output_static_infer_desc() { mgb_assert(output().size() == 1); - if (is_root()) { + if (is_root() || input().size() > 0) { mgb_assert(input().size() == 1); mgr.register_shape_infer(output(0), {SourceType::DEP, {{input(0), DepType::SHAPE}}, infer_shape_from_input}); -- GitLab