Commit 4a5e3170 authored by: Megvii Engine Team

fix(mgb/opr-mm): remove stream -4 for CollectiveComm

GitOrigin-RevId: 41ea88dfa12dd385db122e0777ca18060105839f
Parent a599725c
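
This commit drops the reserved NCCL stream (-4) from CompNode::Stream, removes the matching branch in the sequence comp-node optimizer, and leaves CollectiveComm::init_output_comp_node() empty, so collective outputs are no longer pinned to a dedicated stream. Below is a minimal sketch of the resulting usage, assuming only the CompNode API already visible in this diff (load() and change_stream()); the device string, stream number, and header path are illustrative, not prescribed by the commit.

// Sketch: after this change there is no reserved CompNode::Stream::NCCL;
// code that wants a side stream requests an ordinary numbered stream,
// mirroring the updated AddUpdateOtherStream test in the last hunk below.
#include "megbrain/comp_node.h"   // header path assumed

using mgb::CompNode;

void run_on_side_stream() {
    // Before: CompNode::load("gpu0").change_stream(CompNode::Stream::NCCL)
    // After:  pick an explicit, non-reserved stream (stream 1 here).
    CompNode cn1 = CompNode::load("gpu0:0").change_stream(1);
    (void)cn1;  // operators created on cn1 would execute on that stream
}
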
@@ -109,11 +109,8 @@ void SeqCompNodeOptimizerImpl::change_to_specific_stream(
             type = any_strong_changed ?
                 StreamPropType::STRONG : StreamPropType::WEAK;
             int copy_stream = CompNode::Stream::COPY;
-            int nccl_stream = CompNode::Stream::NCCL;
             if (inp_streams.count(copy_stream))
                 stream = copy_stream;
-            else if (inp_streams.count(nccl_stream))
-                stream = nccl_stream;
             mgb_assert(type != StreamPropType::NONE && stream != 0);
         }
         return prop_type_storage.second = StreamPropType{stream, type};
@@ -188,8 +185,7 @@ void SeqCompNodeOptimizerImpl::register_stream_var(
     mgb_assert(var->owner_graph() == m_owner_graph &&
             (prop_type == StreamPropType::WEAK ||
              prop_type == StreamPropType::STRONG));
-    mgb_assert(stream == CompNode::Stream::COPY || stream ==
-            CompNode::Stream::NCCL);
+    mgb_assert(stream == CompNode::Stream::COPY);
     auto ins = m_var2prop_type.insert({var, {stream, prop_type}});
     if (!ins.second) {
......
@@ -207,8 +207,7 @@ class CompNode {
         static constexpr int
                 COPY = -1,
                 REMOTE_SEND = -2,
-                LOOP_SWAP = -3,
-                NCCL = -4;
+                LOOP_SWAP = -3;
     };
     CompNode() = default;
......
@@ -630,11 +630,7 @@ void CollectiveComm::get_output_var_shape(const TensorShapeArray& inp_shape,
             inp_shape, out_shape);
 }
-void CollectiveComm::init_output_comp_node() {
-    mgb_assert(output().size() == 1, "exactly one output expected, got %zu", output().size());
-    owner_graph()->seq_comp_node_optimizer().register_stream_var(output()[0],
-            {CompNode::Stream::NCCL, cg::SeqCompNodeOptimizer::StreamPropType::WEAK});
-}
+void CollectiveComm::init_output_comp_node() {}
 void CollectiveComm::init_output_mem_plan(bool dynamic) {
     for (size_t i = 0; i < output().size(); i++) {
......
@@ -269,13 +269,13 @@ TEST(TestOprBasicArith, AddUpdateOtherStream) {
     };
     std::shared_ptr<HostTensorND> host_val = gen({SIZE});
-    auto cn_nccl = CompNode::load("gpu0").change_stream(CompNode::Stream::NCCL);
+    auto cn1 = CompNode::load("gpu0:0").change_stream(1);
     auto param = opr::SharedDeviceTensor::make(*graph, *host_val);
     param.node()->owner_opr()->node_prop().attribute().priority =
             std::numeric_limits<int>::max();
-    auto copy = opr::Copy::make(param, cn_nccl);
+    auto copy = opr::Copy::make(param, cn1);
     auto add = (copy + 3) * 5;
-    auto add_update = opr::AddUpdate::make(param, add, {}, {cn_nccl});
+    auto add_update = opr::AddUpdate::make(param, add, {}, {cn1});
     auto callback = opr::CallbackInjector::make(add_update, set_flag);
......