提交 5cfe7365 编写于 作者: M Megvii Engine Team

fix(mgb/core): ensure all VarNodes would be handled in init_ready_event

GitOrigin-RevId: 0b6cb240211f9d4f1acafbecd7fd199f2c3d9153
上级 2d0c9690
......@@ -303,9 +303,8 @@ void SeqCompNodeOptimizerImpl::init_ready_event(
}
opr->input_waiting_spec(std::move(waiting_spec));
auto&& usable_output = opr->usable_output();
for (size_t i = 0; i < usable_output.size(); ++ i) {
var2step[usable_output[i]] = {cur_step, i};
for (size_t i = 0; i < opr->output().size(); ++ i) {
var2step[opr->output(i)] = {cur_step, i};
}
cur_step ++;
}
......
......@@ -1085,6 +1085,55 @@ TEST(TestGraph, DynShapeDepCrossCN) {
ASSERT_EQ(24.f, host_b.ptr<int>()[0]);
}
namespace {
MGB_DEFINE_OPR_CLASS(CustomCopy, cg::SingleCNOperatorNodeBase) // {
std::shared_ptr<DeviceTensorND> m_data;
void scn_do_execute() override {
using namespace std::literals;
std::this_thread::sleep_for(100ms);
m_data->copy_from(input(0)->dev_tensor());
}
void init_output_static_infer_desc() override {
using namespace cg::static_infer;
owner_graph()->static_infer_manager().register_shape_infer(
output(0), ShapeInferDesc::make_const({}));
}
public:
CustomCopy(VarNode* x, std::shared_ptr<DeviceTensorND> dv)
: Super{x->owner_graph(), {dv->comp_node()}, "d2h", {x}},
m_data(dv) {
add_input({x});
using F = VarNode::Flag;
add_output(None)
->add_flag(F::ALLOW_EMPTY_SHAPE)
.add_flag(F::VOLATILE_CONTENT);
}
};
MGB_DYN_TYPE_OBJ_FINAL_IMPL(CustomCopy);
}
TEST(TestGraph, DependentOnVolatileContent) {
HostTensorGenerator<> gen;
auto cn0 = CompNode::load("xpu0"),
cn1 = cn0.change_stream(1);
auto host_x = gen({233}, cn0);
auto dev_y = std::make_shared<DeviceTensorND>(cn1);
auto graph = ComputingGraph::make();
auto x = opr::SharedDeviceTensor::make(*graph, *host_x),
y = x.insert_single_output_opr<CustomCopy>(x.node(), dev_y),
x_new = opr::AddUpdate::make(x, x.make_scalar(1));
auto func = graph->compile({{y, {}}, {x_new, {}}});
func->execute().wait();
HostTensorND host_y;
host_y.copy_from(*dev_y).sync();
MGB_ASSERT_TENSOR_EQ(*host_x, host_y);
}
namespace {
void check_wait(SymbolVar dest, SymbolVar dep) {
if (!dep.node()) {
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册