diff --git a/src/opr-mm/impl/collective_comm.cpp b/src/opr-mm/impl/collective_comm.cpp index 46e2b5417626330af04515689249f69223638695..b8ea6fdfc7b9c1ed86cbe8b1ba4b4ce4476454ac 100644 --- a/src/opr-mm/impl/collective_comm.cpp +++ b/src/opr-mm/impl/collective_comm.cpp @@ -810,7 +810,7 @@ void CollectiveComm::init_output_static_infer_desc() { }; auto get_shape_from_server = [this](TensorShape& dest, const InpVal&) { - if (!m_enable_shape_infer) { + if (!m_enable_shape_infer && !owner_graph()->options().imperative_proxy_graph) { return false; }