diff --git a/paddle/framework/framework.proto b/paddle/framework/framework.proto index fd4c0440eb449f19d0e3528b1f1eddc42a88f674..008fb45fb7bcb2f9b3d02376b15d2f88515f86d9 100644 --- a/paddle/framework/framework.proto +++ b/paddle/framework/framework.proto @@ -55,7 +55,7 @@ message OpDesc { repeated Var inputs = 1; repeated Var outputs = 2; repeated Attr attrs = 4; - required bool is_target = 5 [ default = false ]; + optional bool is_target = 5 [ default = false ]; }; // OpProto describes a C++ framework::OperatorBase derived class. diff --git a/paddle/framework/prune.cc b/paddle/framework/prune.cc index c9a1d7d5cf99059ea2a2766a550f23fc639c8ece..b08e0116b7db6a37ebc45f0976968f0f5c79e950 100644 --- a/paddle/framework/prune.cc +++ b/paddle/framework/prune.cc @@ -39,6 +39,13 @@ bool HasDependentVar(const OpDesc& op_desc, return false; } +bool IsTarget(const OpDesc& op_desc) { + if (op_desc.has_is_target()) { + return op_desc.is_target(); + } + return false; +} + void Prune(const ProgramDesc& input, ProgramDesc& output, int block_id) { // TODO(tonyyang-svail): // - will change to use multiple blocks for RNN op and Cond Op @@ -66,7 +73,7 @@ void Prune(const ProgramDesc& input, ProgramDesc& output, int block_id) { for (auto op_iter = ops.rbegin(); op_iter != ops.rend(); ++op_iter) { auto& op_desc = *op_iter; - if (op_desc.is_target() || HasDependentVar(op_desc, dependent_vars)) { + if (IsTarget(op_desc) || HasDependentVar(op_desc, dependent_vars)) { // insert its input to the dependency graph for (auto& var : op_desc.inputs()) { for (auto& argu : var.arguments()) {