diff --git a/paddle/framework/prune_test.cc b/paddle/framework/prune_test.cc index ab08b851d3dacd7e00f4170110bba09960f69bf8..790fa169244198c41b42d81623504f268f2c91c8 100644 --- a/paddle/framework/prune_test.cc +++ b/paddle/framework/prune_test.cc @@ -161,3 +161,21 @@ TEST(Prune, multi_output_op) { Prune(*pdesc, pruned, 0); PADDLE_ENFORCE_EQ(pruned.blocks(0).ops_size(), 2); } + +TEST(Prune, multi_target) { + f::ProgramDesc *program_desc = GetNewProgramDesc(); + f::ProgramDescBind &program = f::ProgramDescBind::Instance(program_desc); + f::BlockDescBind *block = program.Block(0); + + AddOp("one_two", {{"input", {"a"}}}, {{"output", {"b", "c"}}}, {}, block); + AddOp("one_one", {{"input", {"b"}}}, {{"output", {"b1"}}}, {}, block); + AddOp("one_one", {{"input", {"c"}}}, {{"output", {"c1"}}}, {}, block); + + f::ProgramDesc *pdesc = program.Proto(); + pdesc->mutable_blocks(0)->mutable_ops(1)->set_is_target(true); + pdesc->mutable_blocks(0)->mutable_ops(2)->set_is_target(true); + + f::ProgramDesc pruned; + Prune(*pdesc, pruned, 0); + PADDLE_ENFORCE_EQ(pruned.blocks(0).ops_size(), 3); +}