未验证 提交 10cc040d 编写于 作者: J jiangcheng 提交者: GitHub

add fetch op for cinn graph output node of build_cinn_pass (#37172)

上级 444a7358
......@@ -193,9 +193,18 @@ void AddOutputVar(const GraphNodeSet& output_vars, const GraphNodeSet& cluster,
const GraphNodeMap& old_op2new_op,
const GraphNodeMap& old_var2new_var, Graph* graph) {
for (auto* old_var : output_vars) {
// create fetch op
OpDesc desc;
desc.SetType("fetch");
desc.SetInput("X", {old_var->Name()});
auto op = graph->CreateOpNode(&desc);
auto* var = old_var2new_var.at(old_var);
VLOG(4) << "Add Output Var Node: " << var->Name();
// link fetch op and fetch var
IR_NODE_LINK_TO(var, op);
for (auto* old_op : old_var->inputs) {
if (cluster.count(old_op)) {
IR_NODE_LINK_TO(old_op2new_op.at(old_op), var);
......
......@@ -264,7 +264,7 @@ TEST(BuildCinnPassTest, AllOpSupportCinn) {
// After search, there should has just one cinn subgraph
// feed --> v1 --
// | --> mul --> v3 --
// v2 -- | --> add --> v5 --> relu --> v6
// v2 -- | --> add --> v5 --> relu --> v6 --> fetch
// feed --> v4 --
auto compilation_keys = GetCompilationKeys(*g);
ASSERT_EQ(compilation_keys.size(), static_cast<size_t>(1));
......@@ -272,13 +272,14 @@ TEST(BuildCinnPassTest, AllOpSupportCinn) {
const auto& subgraph = cinn_compiler->FindGraph(compilation_keys[0]);
const auto& subnodes = subgraph.Nodes();
ASSERT_EQ(subnodes.size(), static_cast<size_t>(11));
ASSERT_EQ(subnodes.size(), static_cast<size_t>(12));
ASSERT_TRUE(CheckGraphIndependence(subnodes));
ASSERT_TRUE(CheckNodeExisted(subnodes, "mul"));
ASSERT_TRUE(CheckNodeExisted(subnodes, "add"));
ASSERT_TRUE(CheckNodeExisted(subnodes, "relu"));
ASSERT_EQ(CountNode(subnodes, "feed"), 2);
ASSERT_EQ(CountNode(subnodes, "fetch"), 1);
// No-parameter input should has feed op
auto new_v1 = GetNode(subnodes, "var1");
......@@ -292,6 +293,13 @@ TEST(BuildCinnPassTest, AllOpSupportCinn) {
ASSERT_TRUE(new_v2->inputs.empty());
ASSERT_EQ(new_v2->outputs.size(), static_cast<size_t>(1));
ASSERT_EQ(new_v2->outputs[0]->Name(), "mul");
// output should has fetch op
auto new_v6 = GetNode(subnodes, "var6");
ASSERT_EQ(new_v6->inputs.size(), static_cast<size_t>(1));
ASSERT_EQ(new_v6->outputs.size(), static_cast<size_t>(1));
ASSERT_EQ(new_v6->inputs[0]->Name(), "relu");
ASSERT_EQ(new_v6->outputs[0]->Name(), "fetch");
}
std::unique_ptr<Graph> BuildGraphWithOneCinnSubgraph() {
......@@ -379,7 +387,7 @@ TEST(BuildCinnPassTest, OneCinnSubgraph) {
// After search, there should has just one cinn subgraph
// feed --> v1 --
// | --> mul --> v3 --> relu --> v4
// | --> mul --> v3 --> relu --> v4 --> fetch
// v2 --
auto compilation_keys = GetCompilationKeys(*g);
ASSERT_EQ(compilation_keys.size(), static_cast<size_t>(1));
......@@ -387,12 +395,13 @@ TEST(BuildCinnPassTest, OneCinnSubgraph) {
const auto& subgraph = cinn_compiler->FindGraph(compilation_keys[0]);
const auto& subnodes = subgraph.Nodes();
ASSERT_EQ(subnodes.size(), static_cast<size_t>(7));
ASSERT_EQ(subnodes.size(), static_cast<size_t>(8));
ASSERT_TRUE(CheckGraphIndependence(subnodes));
ASSERT_TRUE(CheckNodeExisted(subnodes, "mul"));
ASSERT_TRUE(CheckNodeExisted(subnodes, "relu"));
ASSERT_EQ(CountNode(subnodes, "feed"), 1);
ASSERT_EQ(CountNode(subnodes, "fetch"), 1);
}
std::unique_ptr<Graph> BuildGraphWithMultiCinnSubgraph() {
......@@ -496,10 +505,10 @@ TEST(BuildCinnPassTest, MultiCinnSubgraph) {
ASSERT_EQ(compilation_keys.size(), static_cast<size_t>(2));
// subgraph1:
// feed --> v4 --> relu --> v5
// feed --> v4 --> relu --> v5 --> fetch
// subgraph2:
// feed --> v1 --
// | --> mul --> v3
// | --> mul --> v3 --> fetch
// v2 --
auto* cinn_compiler = CinnCompiler::GetInstance();
const auto& subgraph1 = cinn_compiler->FindGraph(compilation_keys[0]);
......@@ -511,11 +520,11 @@ TEST(BuildCinnPassTest, MultiCinnSubgraph) {
ASSERT_TRUE(CheckGraphIndependence(subnodes2));
if (CheckNodeExisted(subnodes1, "relu")) {
ASSERT_EQ(subnodes1.size(), static_cast<size_t>(4));
ASSERT_EQ(subnodes2.size(), static_cast<size_t>(5));
} else {
ASSERT_EQ(subnodes2.size(), static_cast<size_t>(4));
ASSERT_EQ(subnodes1.size(), static_cast<size_t>(5));
ASSERT_EQ(subnodes2.size(), static_cast<size_t>(6));
} else {
ASSERT_EQ(subnodes2.size(), static_cast<size_t>(5));
ASSERT_EQ(subnodes1.size(), static_cast<size_t>(6));
}
}
......
......@@ -40,15 +40,14 @@ def set_cinn_flag(val):
class TestResnet50Accuracy(unittest.TestCase):
def reader(self, limit):
for _ in range(limit):
yield np.random.randint(0, 256, size=[32, 3, 224, 224]).astype('float32'), \
np.random.randint(0, 1000, size=[32]).astype('int64')
yield {'image': np.random.randint(0, 256, size=[32, 3, 224, 224]).astype('float32'), \
'label': np.random.randint(0, 1000, size=[32]).astype('int64')}
def generate_random_data(self, loop_num=10):
feed = []
data = self.reader(loop_num)
for _ in range(loop_num):
x, y = next(data)
feed.append({'image': x, 'label': y})
feed.append(next(data))
return feed
def build_program(self, main_program, startup_program):
......@@ -57,6 +56,9 @@ class TestResnet50Accuracy(unittest.TestCase):
name='image', shape=[32, 3, 224, 224], dtype='float32')
label = paddle.static.data(name='label', shape=[32], dtype='int64')
# TODO: stop_gradient slower training speed, need fix
image.stop_gradient = False
model = paddle.vision.models.resnet50()
prediction = model(image)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册