From 10cc040d194bee47c5e1693715f2e6e673105b71 Mon Sep 17 00:00:00 2001 From: jiangcheng Date: Mon, 15 Nov 2021 11:01:11 +0800 Subject: [PATCH] add fetch op for cinn graph output node of build_cinn_pass (#37172) --- .../framework/paddle2cinn/build_cinn_pass.cc | 9 ++++++ .../paddle2cinn/build_cinn_pass_test.cc | 29 ++++++++++++------- .../unittests/test_resnet50_with_cinn.py | 10 ++++--- 3 files changed, 34 insertions(+), 14 deletions(-) diff --git a/paddle/fluid/framework/paddle2cinn/build_cinn_pass.cc b/paddle/fluid/framework/paddle2cinn/build_cinn_pass.cc index 0cff68c41eb..f280214ad0b 100644 --- a/paddle/fluid/framework/paddle2cinn/build_cinn_pass.cc +++ b/paddle/fluid/framework/paddle2cinn/build_cinn_pass.cc @@ -193,9 +193,18 @@ void AddOutputVar(const GraphNodeSet& output_vars, const GraphNodeSet& cluster, const GraphNodeMap& old_op2new_op, const GraphNodeMap& old_var2new_var, Graph* graph) { for (auto* old_var : output_vars) { + // create fetch op + OpDesc desc; + desc.SetType("fetch"); + desc.SetInput("X", {old_var->Name()}); + auto op = graph->CreateOpNode(&desc); + auto* var = old_var2new_var.at(old_var); VLOG(4) << "Add Output Var Node: " << var->Name(); + // link fetch op and fetch var + IR_NODE_LINK_TO(var, op); + for (auto* old_op : old_var->inputs) { if (cluster.count(old_op)) { IR_NODE_LINK_TO(old_op2new_op.at(old_op), var); diff --git a/paddle/fluid/framework/paddle2cinn/build_cinn_pass_test.cc b/paddle/fluid/framework/paddle2cinn/build_cinn_pass_test.cc index 79a27dccb4b..d76a855b122 100644 --- a/paddle/fluid/framework/paddle2cinn/build_cinn_pass_test.cc +++ b/paddle/fluid/framework/paddle2cinn/build_cinn_pass_test.cc @@ -264,7 +264,7 @@ TEST(BuildCinnPassTest, AllOpSupportCinn) { // After search, there should has just one cinn subgraph // feed --> v1 -- // | --> mul --> v3 -- - // v2 -- | --> add --> v5 --> relu --> v6 + // v2 -- | --> add --> v5 --> relu --> v6 --> fetch // feed --> v4 -- auto compilation_keys = GetCompilationKeys(*g); ASSERT_EQ(compilation_keys.size(), static_cast(1)); @@ -272,13 +272,14 @@ TEST(BuildCinnPassTest, AllOpSupportCinn) { const auto& subgraph = cinn_compiler->FindGraph(compilation_keys[0]); const auto& subnodes = subgraph.Nodes(); - ASSERT_EQ(subnodes.size(), static_cast(11)); + ASSERT_EQ(subnodes.size(), static_cast(12)); ASSERT_TRUE(CheckGraphIndependence(subnodes)); ASSERT_TRUE(CheckNodeExisted(subnodes, "mul")); ASSERT_TRUE(CheckNodeExisted(subnodes, "add")); ASSERT_TRUE(CheckNodeExisted(subnodes, "relu")); ASSERT_EQ(CountNode(subnodes, "feed"), 2); + ASSERT_EQ(CountNode(subnodes, "fetch"), 1); // No-parameter input should has feed op auto new_v1 = GetNode(subnodes, "var1"); @@ -292,6 +293,13 @@ TEST(BuildCinnPassTest, AllOpSupportCinn) { ASSERT_TRUE(new_v2->inputs.empty()); ASSERT_EQ(new_v2->outputs.size(), static_cast(1)); ASSERT_EQ(new_v2->outputs[0]->Name(), "mul"); + + // output should has fetch op + auto new_v6 = GetNode(subnodes, "var6"); + ASSERT_EQ(new_v6->inputs.size(), static_cast(1)); + ASSERT_EQ(new_v6->outputs.size(), static_cast(1)); + ASSERT_EQ(new_v6->inputs[0]->Name(), "relu"); + ASSERT_EQ(new_v6->outputs[0]->Name(), "fetch"); } std::unique_ptr BuildGraphWithOneCinnSubgraph() { @@ -379,7 +387,7 @@ TEST(BuildCinnPassTest, OneCinnSubgraph) { // After search, there should has just one cinn subgraph // feed --> v1 -- - // | --> mul --> v3 --> relu --> v4 + // | --> mul --> v3 --> relu --> v4 --> fetch // v2 -- auto compilation_keys = GetCompilationKeys(*g); ASSERT_EQ(compilation_keys.size(), static_cast(1)); @@ -387,12 +395,13 @@ TEST(BuildCinnPassTest, OneCinnSubgraph) { const auto& subgraph = cinn_compiler->FindGraph(compilation_keys[0]); const auto& subnodes = subgraph.Nodes(); - ASSERT_EQ(subnodes.size(), static_cast(7)); + ASSERT_EQ(subnodes.size(), static_cast(8)); ASSERT_TRUE(CheckGraphIndependence(subnodes)); ASSERT_TRUE(CheckNodeExisted(subnodes, "mul")); ASSERT_TRUE(CheckNodeExisted(subnodes, "relu")); ASSERT_EQ(CountNode(subnodes, "feed"), 1); + ASSERT_EQ(CountNode(subnodes, "fetch"), 1); } std::unique_ptr BuildGraphWithMultiCinnSubgraph() { @@ -496,10 +505,10 @@ TEST(BuildCinnPassTest, MultiCinnSubgraph) { ASSERT_EQ(compilation_keys.size(), static_cast(2)); // subgraph1: - // feed --> v4 --> relu --> v5 + // feed --> v4 --> relu --> v5 --> fetch // subgraph2: // feed --> v1 -- - // | --> mul --> v3 + // | --> mul --> v3 --> fetch // v2 -- auto* cinn_compiler = CinnCompiler::GetInstance(); const auto& subgraph1 = cinn_compiler->FindGraph(compilation_keys[0]); @@ -511,11 +520,11 @@ TEST(BuildCinnPassTest, MultiCinnSubgraph) { ASSERT_TRUE(CheckGraphIndependence(subnodes2)); if (CheckNodeExisted(subnodes1, "relu")) { - ASSERT_EQ(subnodes1.size(), static_cast(4)); - ASSERT_EQ(subnodes2.size(), static_cast(5)); - } else { - ASSERT_EQ(subnodes2.size(), static_cast(4)); ASSERT_EQ(subnodes1.size(), static_cast(5)); + ASSERT_EQ(subnodes2.size(), static_cast(6)); + } else { + ASSERT_EQ(subnodes2.size(), static_cast(5)); + ASSERT_EQ(subnodes1.size(), static_cast(6)); } } diff --git a/python/paddle/fluid/tests/unittests/test_resnet50_with_cinn.py b/python/paddle/fluid/tests/unittests/test_resnet50_with_cinn.py index 7f865f55878..2f6ca1dfa0c 100644 --- a/python/paddle/fluid/tests/unittests/test_resnet50_with_cinn.py +++ b/python/paddle/fluid/tests/unittests/test_resnet50_with_cinn.py @@ -40,15 +40,14 @@ def set_cinn_flag(val): class TestResnet50Accuracy(unittest.TestCase): def reader(self, limit): for _ in range(limit): - yield np.random.randint(0, 256, size=[32, 3, 224, 224]).astype('float32'), \ - np.random.randint(0, 1000, size=[32]).astype('int64') + yield {'image': np.random.randint(0, 256, size=[32, 3, 224, 224]).astype('float32'), \ + 'label': np.random.randint(0, 1000, size=[32]).astype('int64')} def generate_random_data(self, loop_num=10): feed = [] data = self.reader(loop_num) for _ in range(loop_num): - x, y = next(data) - feed.append({'image': x, 'label': y}) + feed.append(next(data)) return feed def build_program(self, main_program, startup_program): @@ -57,6 +56,9 @@ class TestResnet50Accuracy(unittest.TestCase): name='image', shape=[32, 3, 224, 224], dtype='float32') label = paddle.static.data(name='label', shape=[32], dtype='int64') + # TODO: stop_gradient slower training speed, need fix + image.stop_gradient = False + model = paddle.vision.models.resnet50() prediction = model(image) -- GitLab