提交 fc8b501b 编写于 作者: M Megvii Engine Team

refactor(mgb/core): refactor cpu compnode so that default cpu has no ability to record

GitOrigin-RevId: 7de4771476e87d3ed11cffbad4c02741591ef9a2
上级 b7176069
......@@ -54,7 +54,9 @@ namespace mgb {
void add_callback(Task&& task) override;
};
class CompNodeImpl;
class CompNodeBaseImpl;
class CompNodeNoRecorderImpl;
class CompNodeRecorderImpl;
static void foreach(thin_function<void(CompNode)> callback);
static void finalize();
......
......@@ -100,6 +100,26 @@ void run_comp_seq_rec_basic_level2(CompNode cn) {
MGB_ASSERT_TENSOR_NEAR(expect, host_z, 1e-3) << "iter " << iter;
}
ASSERT_EQ(executed.size(), 2u);
//! test default_cpu with record2
{
HostTensorND hz;
graph = ComputingGraph::make();
x = opr::Host2DeviceCopy::make(*graph, host_x);
y = opr::Host2DeviceCopy::make(*graph, host_y);
z = opr::ConvBias::make(x, y, param);
z = opr::GetVarShape::make(z);
graph->options().comp_node_seq_record_level = 2;
graph->options().var_sanity_check_first_run = false;
auto func = graph->compile({make_callback_copy(z, hz, true)});
ComputingGraph::assert_destroy(graph);
func->execute();
ASSERT_TRUE(hz.comp_node() == cn);
ASSERT_EQ(hz.ptr<int>()[0], 3);
ASSERT_EQ(hz.ptr<int>()[1], 6);
ASSERT_EQ(hz.ptr<int>()[2], 8);
ASSERT_EQ(hz.ptr<int>()[3], 6);
}
}
void run_comp_seq_rec_dyn_elemwise(CompNode cn, bool fake_first) {
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册