未验证 提交 10cc040d 编写于 作者: J jiangcheng 提交者: GitHub

add fetch op for cinn graph output node of build_cinn_pass (#37172)

上级 444a7358
...@@ -193,9 +193,18 @@ void AddOutputVar(const GraphNodeSet& output_vars, const GraphNodeSet& cluster, ...@@ -193,9 +193,18 @@ void AddOutputVar(const GraphNodeSet& output_vars, const GraphNodeSet& cluster,
const GraphNodeMap& old_op2new_op, const GraphNodeMap& old_op2new_op,
const GraphNodeMap& old_var2new_var, Graph* graph) { const GraphNodeMap& old_var2new_var, Graph* graph) {
for (auto* old_var : output_vars) { for (auto* old_var : output_vars) {
// create fetch op
OpDesc desc;
desc.SetType("fetch");
desc.SetInput("X", {old_var->Name()});
auto op = graph->CreateOpNode(&desc);
auto* var = old_var2new_var.at(old_var); auto* var = old_var2new_var.at(old_var);
VLOG(4) << "Add Output Var Node: " << var->Name(); VLOG(4) << "Add Output Var Node: " << var->Name();
// link fetch op and fetch var
IR_NODE_LINK_TO(var, op);
for (auto* old_op : old_var->inputs) { for (auto* old_op : old_var->inputs) {
if (cluster.count(old_op)) { if (cluster.count(old_op)) {
IR_NODE_LINK_TO(old_op2new_op.at(old_op), var); IR_NODE_LINK_TO(old_op2new_op.at(old_op), var);
......
...@@ -264,7 +264,7 @@ TEST(BuildCinnPassTest, AllOpSupportCinn) { ...@@ -264,7 +264,7 @@ TEST(BuildCinnPassTest, AllOpSupportCinn) {
// After search, there should be just one cinn subgraph // After search, there should be just one cinn subgraph
// feed --> v1 -- // feed --> v1 --
// | --> mul --> v3 -- // | --> mul --> v3 --
// v2 -- | --> add --> v5 --> relu --> v6 // v2 -- | --> add --> v5 --> relu --> v6 --> fetch
// feed --> v4 -- // feed --> v4 --
auto compilation_keys = GetCompilationKeys(*g); auto compilation_keys = GetCompilationKeys(*g);
ASSERT_EQ(compilation_keys.size(), static_cast<size_t>(1)); ASSERT_EQ(compilation_keys.size(), static_cast<size_t>(1));
...@@ -272,13 +272,14 @@ TEST(BuildCinnPassTest, AllOpSupportCinn) { ...@@ -272,13 +272,14 @@ TEST(BuildCinnPassTest, AllOpSupportCinn) {
const auto& subgraph = cinn_compiler->FindGraph(compilation_keys[0]); const auto& subgraph = cinn_compiler->FindGraph(compilation_keys[0]);
const auto& subnodes = subgraph.Nodes(); const auto& subnodes = subgraph.Nodes();
ASSERT_EQ(subnodes.size(), static_cast<size_t>(11)); ASSERT_EQ(subnodes.size(), static_cast<size_t>(12));
ASSERT_TRUE(CheckGraphIndependence(subnodes)); ASSERT_TRUE(CheckGraphIndependence(subnodes));
ASSERT_TRUE(CheckNodeExisted(subnodes, "mul")); ASSERT_TRUE(CheckNodeExisted(subnodes, "mul"));
ASSERT_TRUE(CheckNodeExisted(subnodes, "add")); ASSERT_TRUE(CheckNodeExisted(subnodes, "add"));
ASSERT_TRUE(CheckNodeExisted(subnodes, "relu")); ASSERT_TRUE(CheckNodeExisted(subnodes, "relu"));
ASSERT_EQ(CountNode(subnodes, "feed"), 2); ASSERT_EQ(CountNode(subnodes, "feed"), 2);
ASSERT_EQ(CountNode(subnodes, "fetch"), 1);
// No-parameter input should have feed op // No-parameter input should have feed op
auto new_v1 = GetNode(subnodes, "var1"); auto new_v1 = GetNode(subnodes, "var1");
...@@ -292,6 +293,13 @@ TEST(BuildCinnPassTest, AllOpSupportCinn) { ...@@ -292,6 +293,13 @@ TEST(BuildCinnPassTest, AllOpSupportCinn) {
ASSERT_TRUE(new_v2->inputs.empty()); ASSERT_TRUE(new_v2->inputs.empty());
ASSERT_EQ(new_v2->outputs.size(), static_cast<size_t>(1)); ASSERT_EQ(new_v2->outputs.size(), static_cast<size_t>(1));
ASSERT_EQ(new_v2->outputs[0]->Name(), "mul"); ASSERT_EQ(new_v2->outputs[0]->Name(), "mul");
// output should have fetch op
auto new_v6 = GetNode(subnodes, "var6");
ASSERT_EQ(new_v6->inputs.size(), static_cast<size_t>(1));
ASSERT_EQ(new_v6->outputs.size(), static_cast<size_t>(1));
ASSERT_EQ(new_v6->inputs[0]->Name(), "relu");
ASSERT_EQ(new_v6->outputs[0]->Name(), "fetch");
} }
std::unique_ptr<Graph> BuildGraphWithOneCinnSubgraph() { std::unique_ptr<Graph> BuildGraphWithOneCinnSubgraph() {
...@@ -379,7 +387,7 @@ TEST(BuildCinnPassTest, OneCinnSubgraph) { ...@@ -379,7 +387,7 @@ TEST(BuildCinnPassTest, OneCinnSubgraph) {
// After search, there should be just one cinn subgraph // After search, there should be just one cinn subgraph
// feed --> v1 -- // feed --> v1 --
// | --> mul --> v3 --> relu --> v4 // | --> mul --> v3 --> relu --> v4 --> fetch
// v2 -- // v2 --
auto compilation_keys = GetCompilationKeys(*g); auto compilation_keys = GetCompilationKeys(*g);
ASSERT_EQ(compilation_keys.size(), static_cast<size_t>(1)); ASSERT_EQ(compilation_keys.size(), static_cast<size_t>(1));
...@@ -387,12 +395,13 @@ TEST(BuildCinnPassTest, OneCinnSubgraph) { ...@@ -387,12 +395,13 @@ TEST(BuildCinnPassTest, OneCinnSubgraph) {
const auto& subgraph = cinn_compiler->FindGraph(compilation_keys[0]); const auto& subgraph = cinn_compiler->FindGraph(compilation_keys[0]);
const auto& subnodes = subgraph.Nodes(); const auto& subnodes = subgraph.Nodes();
ASSERT_EQ(subnodes.size(), static_cast<size_t>(7)); ASSERT_EQ(subnodes.size(), static_cast<size_t>(8));
ASSERT_TRUE(CheckGraphIndependence(subnodes)); ASSERT_TRUE(CheckGraphIndependence(subnodes));
ASSERT_TRUE(CheckNodeExisted(subnodes, "mul")); ASSERT_TRUE(CheckNodeExisted(subnodes, "mul"));
ASSERT_TRUE(CheckNodeExisted(subnodes, "relu")); ASSERT_TRUE(CheckNodeExisted(subnodes, "relu"));
ASSERT_EQ(CountNode(subnodes, "feed"), 1); ASSERT_EQ(CountNode(subnodes, "feed"), 1);
ASSERT_EQ(CountNode(subnodes, "fetch"), 1);
} }
std::unique_ptr<Graph> BuildGraphWithMultiCinnSubgraph() { std::unique_ptr<Graph> BuildGraphWithMultiCinnSubgraph() {
...@@ -496,10 +505,10 @@ TEST(BuildCinnPassTest, MultiCinnSubgraph) { ...@@ -496,10 +505,10 @@ TEST(BuildCinnPassTest, MultiCinnSubgraph) {
ASSERT_EQ(compilation_keys.size(), static_cast<size_t>(2)); ASSERT_EQ(compilation_keys.size(), static_cast<size_t>(2));
// subgraph1: // subgraph1:
// feed --> v4 --> relu --> v5 // feed --> v4 --> relu --> v5 --> fetch
// subgraph2: // subgraph2:
// feed --> v1 -- // feed --> v1 --
// | --> mul --> v3 // | --> mul --> v3 --> fetch
// v2 -- // v2 --
auto* cinn_compiler = CinnCompiler::GetInstance(); auto* cinn_compiler = CinnCompiler::GetInstance();
const auto& subgraph1 = cinn_compiler->FindGraph(compilation_keys[0]); const auto& subgraph1 = cinn_compiler->FindGraph(compilation_keys[0]);
...@@ -511,11 +520,11 @@ TEST(BuildCinnPassTest, MultiCinnSubgraph) { ...@@ -511,11 +520,11 @@ TEST(BuildCinnPassTest, MultiCinnSubgraph) {
ASSERT_TRUE(CheckGraphIndependence(subnodes2)); ASSERT_TRUE(CheckGraphIndependence(subnodes2));
if (CheckNodeExisted(subnodes1, "relu")) { if (CheckNodeExisted(subnodes1, "relu")) {
ASSERT_EQ(subnodes1.size(), static_cast<size_t>(4));
ASSERT_EQ(subnodes2.size(), static_cast<size_t>(5));
} else {
ASSERT_EQ(subnodes2.size(), static_cast<size_t>(4));
ASSERT_EQ(subnodes1.size(), static_cast<size_t>(5)); ASSERT_EQ(subnodes1.size(), static_cast<size_t>(5));
ASSERT_EQ(subnodes2.size(), static_cast<size_t>(6));
} else {
ASSERT_EQ(subnodes2.size(), static_cast<size_t>(5));
ASSERT_EQ(subnodes1.size(), static_cast<size_t>(6));
} }
} }
......
...@@ -40,15 +40,14 @@ def set_cinn_flag(val): ...@@ -40,15 +40,14 @@ def set_cinn_flag(val):
class TestResnet50Accuracy(unittest.TestCase): class TestResnet50Accuracy(unittest.TestCase):
def reader(self, limit): def reader(self, limit):
for _ in range(limit): for _ in range(limit):
yield np.random.randint(0, 256, size=[32, 3, 224, 224]).astype('float32'), \ yield {'image': np.random.randint(0, 256, size=[32, 3, 224, 224]).astype('float32'), \
np.random.randint(0, 1000, size=[32]).astype('int64') 'label': np.random.randint(0, 1000, size=[32]).astype('int64')}
def generate_random_data(self, loop_num=10): def generate_random_data(self, loop_num=10):
feed = [] feed = []
data = self.reader(loop_num) data = self.reader(loop_num)
for _ in range(loop_num): for _ in range(loop_num):
x, y = next(data) feed.append(next(data))
feed.append({'image': x, 'label': y})
return feed return feed
def build_program(self, main_program, startup_program): def build_program(self, main_program, startup_program):
...@@ -57,6 +56,9 @@ class TestResnet50Accuracy(unittest.TestCase): ...@@ -57,6 +56,9 @@ class TestResnet50Accuracy(unittest.TestCase):
name='image', shape=[32, 3, 224, 224], dtype='float32') name='image', shape=[32, 3, 224, 224], dtype='float32')
label = paddle.static.data(name='label', shape=[32], dtype='int64') label = paddle.static.data(name='label', shape=[32], dtype='int64')
# TODO: stop_gradient slows down training speed, need fix
image.stop_gradient = False
model = paddle.vision.models.resnet50() model = paddle.vision.models.resnet50()
prediction = model(image) prediction = model(image)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册