diff --git a/paddle/framework/prune.cc b/paddle/framework/prune.cc index ddb9ed7ae061c510f93d64a1d928b850bd9a5a61..284541f199da27fd3071bfbc54426b5faca10c9f 100644 --- a/paddle/framework/prune.cc +++ b/paddle/framework/prune.cc @@ -43,7 +43,7 @@ void Prune(const ProgramDesc& input, ProgramDesc& output, int id) { // TODO(tonyyang-svail): // - will change to use multiple blocks for RNN op and Cond Op - auto& block = input.blocks(0); + auto& block = input.blocks(id); auto& ops = block.ops(); bool expect_feed = true; @@ -67,13 +67,6 @@ void Prune(const ProgramDesc& input, ProgramDesc& output, int id) { auto& op_desc = *op_iter; if (op_desc.is_target() || HasDependentVar(op_desc, dependent_vars)) { - // erase its output to the dependency graph - for (auto& var : op_desc.outputs()) { - for (auto& argu : var.arguments()) { - dependent_vars.erase(argu); - } - } - // insert its input to the dependency graph for (auto& var : op_desc.inputs()) { for (auto& argu : var.arguments()) { diff --git a/paddle/framework/prune_test.cc b/paddle/framework/prune_test.cc index b66db945282cb015ac032d72b400ce9cb5a4bc06..ab08b851d3dacd7e00f4170110bba09960f69bf8 100644 --- a/paddle/framework/prune_test.cc +++ b/paddle/framework/prune_test.cc @@ -28,105 +28,24 @@ namespace framework { using DeviceContext = platform::DeviceContext; -class RowWiseAddOpMaker : public OpProtoAndCheckerMaker { +class OneOneOpMaker : public OpProtoAndCheckerMaker { public: - RowWiseAddOpMaker(OpProto *proto, OpAttrChecker *op_checker) + OneOneOpMaker(OpProto *proto, OpAttrChecker *op_checker) : OpProtoAndCheckerMaker(proto, op_checker) { - AddInput("X", "Input X of Add"); - AddInput("b", "Bias of Add"); - AddOutput("Out", "Out of Add"); - AddComment("Add Op"); + AddInput("input", "input"); + AddOutput("output", "output"); + AddComment("Op has one input and one output"); } }; -class RowWiseAddGradMaker : public SingleGradOpDescMaker { +class TwoOneOpMaker : public OpProtoAndCheckerMaker { public: - using 
SingleGradOpDescMaker::SingleGradOpDescMaker; - - protected: - std::unique_ptr<OpDescBind> Apply() const override { - auto grad_op = new OpDescBind(); - grad_op->SetInput(GradVarName("Out"), OutputGrad("Out")); - grad_op->SetOutput(GradVarName("X"), InputGrad("X")); - grad_op->SetOutput(GradVarName("b"), InputGrad("b")); - grad_op->SetType("rowwise_add_grad"); - return std::unique_ptr<OpDescBind>(grad_op); - } -}; - -class MulOpMaker : public OpProtoAndCheckerMaker { - public: - MulOpMaker(OpProto *proto, OpAttrChecker *op_checker) - : OpProtoAndCheckerMaker(proto, op_checker) { - AddInput("X", "A"); - AddInput("Y", "B"); - AddOutput("Out", "Out"); - AddAttr<int>("x_num_col_dims", "").SetDefault(1).EqualGreaterThan(1); - AddAttr<int>("y_num_col_dims", "").SetDefault(1).EqualGreaterThan(1); - AddComment("Mul"); - } -}; - -class SigmoidOpMaker : public OpProtoAndCheckerMaker { - public: - SigmoidOpMaker(OpProto *proto, OpAttrChecker *op_checker) + TwoOneOpMaker(OpProto *proto, OpAttrChecker *op_checker) : OpProtoAndCheckerMaker(proto, op_checker) { - AddInput("X", "X"); - AddOutput("Out", "Y"); - AddComment("Sigmoid"); - } -}; - -class NoGradOpMaker : public OpProtoAndCheckerMaker { - public: - NoGradOpMaker(OpProto *proto, OpAttrChecker *op_checker) - : OpProtoAndCheckerMaker(proto, op_checker) { - AddInput("X", "X input"); - AddOutput("Out", "Y output"); - AddComment("NoGradOp, same input output. 
no Grad"); - } -}; - -class ManyOutputOpMaker : public OpProtoAndCheckerMaker { - public: - ManyOutputOpMaker(OpProto *proto, OpAttrChecker *op_checker) - : OpProtoAndCheckerMaker(proto, op_checker) { - AddInput("x", "x"); - AddOutput("y", "y"); - AddOutput("z", "z"); - AddComment(""); - } -}; - -class FillZeroOpMaker : public OpProtoAndCheckerMaker { - public: - FillZeroOpMaker(OpProto *proto, OpAttrChecker *op_checker) - : OpProtoAndCheckerMaker(proto, op_checker) { - AddInput("X", "x"); - AddOutput("Y", "out"); - AddComment(""); - } -}; - -class SumOpMaker : public framework::OpProtoAndCheckerMaker { - public: - SumOpMaker(framework::OpProto *proto, framework::OpAttrChecker *op_checker) - : OpProtoAndCheckerMaker(proto, op_checker) { - AddInput("X", "the input tensors of sum operator.").AsDuplicable(); - AddOutput("Out", "the output tensor of sum operator."); - AddComment(""); - } -}; - -class MultInOutOpMaker : public OpProtoAndCheckerMaker { - public: - MultInOutOpMaker(OpProto *proto, OpAttrChecker *op_checker) - : OpProtoAndCheckerMaker(proto, op_checker) { - AddInput("X", "x"); - AddInput("H", "h"); - AddOutput("Y", "y"); - AddOutput("Z", "z"); - AddComment(""); + AddInput("input_1", "input_1"); + AddInput("input_2", "input_2"); + AddOutput("output", "output"); + AddComment("Op has two inputs and one output"); } }; @@ -135,18 +54,8 @@ class MultInOutOpMaker : public OpProtoAndCheckerMaker { namespace f = paddle::framework; namespace ops = paddle::operators; -using EnforceNotMet = paddle::platform::EnforceNotMet; -REGISTER_OPERATOR(rowwise_add, f::NOP, f::RowWiseAddOpMaker, - f::RowWiseAddGradMaker); -REGISTER_OPERATOR(rowwise_add_grad, f::NOP); -REGISTER_OP(mul, f::NOP, f::MulOpMaker, mul_grad, f::NOP); -REGISTER_OP(sigmoid, f::NOP, f::SigmoidOpMaker, sigmoid_grad, f::NOP); -REGISTER_OP_WITHOUT_GRADIENT(nograd, f::NOP, f::NoGradOpMaker); -REGISTER_OP_WITHOUT_GRADIENT(fill_zeros_like, f::NOP, f::FillZeroOpMaker); -REGISTER_OP(sum, f::NOP, f::SumOpMaker, 
sum_grad, f::NOP); -REGISTER_OP(many_output_op, f::NOP, f::ManyOutputOpMaker, many_output_op_grad, - f::NOP); -REGISTER_OP(mult_in_out, f::NOP, f::MultInOutOpMaker, mult_in_out_grad, f::NOP); +REGISTER_OP_WITHOUT_GRADIENT(one_one, f::NOP, f::OneOneOpMaker); +REGISTER_OP_WITHOUT_GRADIENT(two_one, f::NOP, f::TwoOneOpMaker); void AddOp(const std::string &type, const f::VariableNameMap &inputs, const f::VariableNameMap &outputs, f::AttributeMap attrs, @@ -184,7 +93,7 @@ TEST(Prune, one_operator) { f::ProgramDescBind &program = f::ProgramDescBind::Instance(program_desc); f::BlockDescBind *block = program.Block(0); - AddOp("mul", {{"X", {"a"}}, {"Y", {"w1"}}}, {{"Out", {"b"}}}, {}, block); + AddOp("one_one", {{"input", {"a"}}}, {{"output", {"b"}}}, {}, block); f::ProgramDesc *pdesc = program.Proto(); f::ProgramDesc pruned; @@ -197,4 +106,58 @@ TEST(Prune, one_operator) { PADDLE_ENFORCE_EQ(pruned.blocks(0).ops_size(), 1); } -TEST(Prune, simple_optimize) {} +TEST(Prune, forward) { + f::ProgramDesc *program_desc = GetNewProgramDesc(); + f::ProgramDescBind &program = f::ProgramDescBind::Instance(program_desc); + f::BlockDescBind *block = program.Block(0); + + AddOp("one_one", {{"input", {"a"}}}, {{"output", {"b"}}}, {}, block); + AddOp("one_one", {{"input", {"b"}}}, {{"output", {"c"}}}, {}, block); + AddOp("one_one", {{"input", {"c"}}}, {{"output", {"d"}}}, {}, block); + AddOp("one_one", {{"input", {"d"}}}, {{"output", {"e"}}}, {}, block); + + f::ProgramDesc *pdesc = program.Proto(); + + for (int i = 0; i < pdesc->blocks(0).ops_size(); ++i) { + f::ProgramDesc pruned; + pdesc->mutable_blocks(0)->mutable_ops(i)->set_is_target(true); + Prune(*pdesc, pruned, 0); + PADDLE_ENFORCE_EQ(pruned.blocks(0).ops_size(), i + 1); + } +} + +TEST(Prune, multi_input_op) { + f::ProgramDesc *program_desc = GetNewProgramDesc(); + f::ProgramDescBind &program = f::ProgramDescBind::Instance(program_desc); + f::BlockDescBind *block = program.Block(0); + + AddOp("one_one", {{"input", {"a0"}}}, 
{{"output", {"b0"}}}, {}, block); + AddOp("one_one", {{"input", {"a1"}}}, {{"output", {"b1"}}}, {}, block); + AddOp("one_one", {{"input", {"a2"}}}, {{"output", {"b2"}}}, {}, block); + AddOp("three_one", {{"input", {"b0", "b1", "b2"}}}, {{"output", {"c"}}}, {}, + block); + + f::ProgramDesc *pdesc = program.Proto(); + pdesc->mutable_blocks(0)->mutable_ops(3)->set_is_target(true); + + f::ProgramDesc pruned; + Prune(*pdesc, pruned, 0); + PADDLE_ENFORCE_EQ(pruned.blocks(0).ops_size(), 4); +} + +TEST(Prune, multi_output_op) { + f::ProgramDesc *program_desc = GetNewProgramDesc(); + f::ProgramDescBind &program = f::ProgramDescBind::Instance(program_desc); + f::BlockDescBind *block = program.Block(0); + + AddOp("one_two", {{"input", {"a"}}}, {{"output", {"b", "c"}}}, {}, block); + AddOp("one_one", {{"input", {"b"}}}, {{"output", {"b1"}}}, {}, block); + AddOp("one_one", {{"input", {"c"}}}, {{"output", {"c1"}}}, {}, block); + + f::ProgramDesc *pdesc = program.Proto(); + pdesc->mutable_blocks(0)->mutable_ops(2)->set_is_target(true); + + f::ProgramDesc pruned; + Prune(*pdesc, pruned, 0); + PADDLE_ENFORCE_EQ(pruned.blocks(0).ops_size(), 2); +}