diff --git a/paddle/fluid/framework/prune.cc b/paddle/fluid/framework/prune.cc
index a2df5cfcf93865b10124b8f24a96abcdb07a90c7..0c27d41f33ffb3d76a3e024da21ff74514f386d5 100644
--- a/paddle/fluid/framework/prune.cc
+++ b/paddle/fluid/framework/prune.cc
@@ -17,19 +17,40 @@ limitations under the License. */
 #include <glog/logging.h>
 
 #include <algorithm>
+#include <memory>
 #include <queue>
 #include <set>
 #include <string>
+#include <unordered_set>
 #include <vector>
 
+#include "paddle/fluid/framework/block_desc.h"
+#include "paddle/fluid/framework/op_desc.h"
+#include "paddle/fluid/framework/op_proto_maker.h"
+#include "paddle/fluid/framework/program_desc.h"
+
 namespace paddle {
 namespace framework {
 
 const char kFeedOpType[] = "feed";
 const char kFetchOpType[] = "fetch";
 
-bool HasDependentVar(const proto::OpDesc& op_desc,
-                     const std::set<std::string>& dependent_vars) {
+bool HasDependentInputVar(
+    const proto::OpDesc& op_desc,
+    const std::unordered_set<std::string>& dependent_vars) {
+  for (auto& var : op_desc.inputs()) {
+    for (auto& argu : var.arguments()) {
+      if (dependent_vars.count(argu) != 0) {
+        return true;
+      }
+    }
+  }
+  return false;
+}
+
+bool HasDependentOutputVar(
+    const proto::OpDesc& op_desc,
+    const std::unordered_set<std::string>& dependent_vars) {
   for (auto& var : op_desc.outputs()) {
     for (auto& argu : var.arguments()) {
       if (dependent_vars.count(argu) != 0) {
@@ -47,6 +68,14 @@ bool IsTarget(const proto::OpDesc& op_desc) {
   return false;
 }
 
+bool HasTrueTarget(const proto::OpDesc& op_desc) {
+  return op_desc.has_is_target() && op_desc.is_target();
+}
+
+bool HasFalseTarget(const proto::OpDesc& op_desc) {
+  return op_desc.has_is_target() && !op_desc.is_target();
+}
+
 int GetSubBlockIndex(const proto::OpDesc& op_desc) {
   for (auto& attr : op_desc.attrs()) {
     if (attr.type() == proto::AttrType::BLOCK) {
@@ -61,6 +90,24 @@ bool HasSubBlock(const proto::OpDesc& op_desc) {
   return GetSubBlockIndex(op_desc) > 0;
 }
 
+void AppendOpInputVarNames(const proto::OpDesc& op_desc,
+                           std::unordered_set<std::string>* vars_set) {
+  for (auto& var : op_desc.inputs()) {
+    for (auto& arg : var.arguments()) {
+      vars_set->emplace(arg);
+    }
+  }
+}
+
+void AppendOpOutputVarNames(const proto::OpDesc& op_desc,
+                            std::unordered_set<std::string>* vars_set) {
+  for (auto& var : op_desc.outputs()) {
+    for (auto& arg : var.arguments()) {
+      vars_set->emplace(arg);
+    }
+  }
+}
+
 // block_id is the idx of the current block in the input desc
 // parent_block_id is the idx of the parent of the current block
 // in the output desc, -1 means the current block is global block
@@ -68,7 +115,7 @@ bool HasSubBlock(const proto::OpDesc& op_desc) {
 // the child block to help pruning
 void prune_impl(const proto::ProgramDesc& input, proto::ProgramDesc* output,
                 int block_id, int parent_block_id,
-                std::set<std::string>* dependent_vars,
+                std::unordered_set<std::string>* dependent_vars,
                 const std::set<std::string> feed_var_names) {
   auto& block = input.blocks(block_id);
   auto& ops = block.ops();
@@ -91,7 +138,7 @@ void prune_impl(const proto::ProgramDesc& input, proto::ProgramDesc* output,
   std::vector<bool> should_run;
   for (auto op_iter = ops.rbegin(); op_iter != ops.rend(); ++op_iter) {
     auto& op_desc = *op_iter;
-    if (IsTarget(op_desc) || HasDependentVar(op_desc, *dependent_vars)) {
+    if (IsTarget(op_desc) || HasDependentOutputVar(op_desc, *dependent_vars)) {
       // insert its input to the dependency graph
       for (auto& var : op_desc.inputs()) {
         for (auto& argu : var.arguments()) {
@@ -127,7 +174,7 @@ void prune_impl(const proto::ProgramDesc& input, proto::ProgramDesc* output,
       *op = input.blocks(block_id).ops(i);
       if (HasSubBlock(*op)) {
         // create sub_block_dependent_vars here to help prune the sub block
-        std::set<std::string> sub_block_dependent_vars;
+        std::unordered_set<std::string> sub_block_dependent_vars;
         for (auto& var : op->inputs()) {
           for (auto& argu : var.arguments()) {
             if (feed_var_names.count(argu) == 0) {
@@ -188,9 +235,139 @@ void prune_impl(const proto::ProgramDesc& input, proto::ProgramDesc* output,
 void Prune(const proto::ProgramDesc& input,
            const std::set<std::string>& feed_var_names,
            proto::ProgramDesc* output) {
-  std::set<std::string> dependent_vars;
+  std::unordered_set<std::string> dependent_vars;
   output->clear_blocks();
   prune_impl(input, output, 0, -1, &dependent_vars, feed_var_names);
 }
+
+void CloneWholeBlock(proto::ProgramDesc* input, proto::ProgramDesc* output,
+                     int block_id, int parent_block_id) {
+  auto* block_field = output->mutable_blocks();
+  *block_field->Add() = input->blocks(block_id);
+  int output_block_id = output->blocks_size() - 1;
+  auto* output_block = output->mutable_blocks(output_block_id);
+  output_block->set_idx(output_block_id);
+  output_block->set_parent_idx(parent_block_id);
+}
+
+void PruneBackwardImpl(proto::ProgramDesc* input, proto::ProgramDesc* output,
+                       int block_id, int parent_block_id) {
+  // Step 1. Copy the current input block to output
+  CloneWholeBlock(input, output, block_id, parent_block_id);
+  int output_block_id = output->blocks_size() - 1;
+  auto* output_block = output->mutable_blocks(output_block_id);
+
+  // Step 2. Mark forward ops on the main branch
+  auto* ops = input->mutable_blocks(block_id)->mutable_ops();
+  std::unordered_set<std::string> op_input_vars;
+  std::unordered_set<std::string> op_output_vars;
+  for (auto op_iter = ops->rbegin(); op_iter != ops->rend(); ++op_iter) {
+    auto& op_desc = *op_iter;
+    if (HasTrueTarget(op_desc) ||
+        HasDependentOutputVar(op_desc, op_input_vars)) {
+      op_desc.set_is_target(true);
+      AppendOpInputVarNames(op_desc, &op_input_vars);
+      AppendOpOutputVarNames(op_desc, &op_output_vars);
+    }
+  }
+
+  // Step 3. Mark backward & optimize ops on the main branch
+  std::unordered_set<std::string> gradop_input_vars;
+  std::unordered_set<std::string> gradop_output_vars;
+  for (auto op_iter = ops->begin(); op_iter != ops->end(); ++op_iter) {
+    auto& op_desc = *op_iter;
+    if (HasFalseTarget(op_desc) ||
+        HasDependentInputVar(op_desc, gradop_output_vars)) {
+      op_desc.set_is_target(false);
+      AppendOpInputVarNames(op_desc, &gradop_input_vars);
+      AppendOpOutputVarNames(op_desc, &gradop_output_vars);
+    }
+  }
+
+  // Step 4. Mark ops that need to be reserved on sub-branches
+  for (auto op_iter = ops->rbegin(); op_iter != ops->rend(); ++op_iter) {
+    auto& op_desc = *op_iter;
+    if (!op_desc.has_is_target()) {
+      if (HasDependentOutputVar(op_desc, gradop_input_vars)) {
+        op_desc.set_is_target(false);
+        AppendOpInputVarNames(op_desc, &gradop_input_vars);
+      } else {
+        op_desc.set_is_target(true);
+        AppendOpInputVarNames(op_desc, &op_input_vars);
+        AppendOpOutputVarNames(op_desc, &op_output_vars);
+      }
+    }
+  }
+
+  // Step 5. Copy the forward ops to the new ProgramDesc
+  // Note: proto::ProgramDesc doesn't have an interface
+  //       to remove ops and vars
+  auto* op_field = output_block->mutable_ops();
+  op_field->Clear();
+  for (auto op_iter = ops->begin(); op_iter != ops->end(); ++op_iter) {
+    if (IsTarget(*op_iter)) {
+      auto* op = op_field->Add();
+      *op = *op_iter;
+      if (HasSubBlock(*op)) {
+        CloneWholeBlock(input, output, GetSubBlockIndex(*op), output_block_id);
+      }
+    }
+  }
+
+  // Step 6. Copy the forward vars to the new ProgramDesc
+  // construct a map of all vars before clearing
+  auto* var_field = output_block->mutable_vars();
+  std::unordered_map<std::string, proto::VarDesc> var_map;
+  for (const auto& var : *var_field) {
+    var_map[var.name()] = var;
+  }
+  std::unordered_set<std::string> var_names;
+  var_names.insert(op_input_vars.begin(), op_input_vars.end());
+  var_names.insert(op_output_vars.begin(), op_output_vars.end());
+  var_field->Clear();
+  for (const auto& name : var_names) {
+    *var_field->Add() = var_map[name];
+  }
+}
+
+std::unique_ptr<framework::ProgramDesc> PruneBackward(
+    const framework::ProgramDesc& origin) {
+  // Copy the original ProgramDesc; origin can't be changed
+  framework::ProgramDesc origin_clone(origin);
+
+  // Step 1. Update the loss op's role & set the loss op to be the target
+  // The loss op's op_role is (kForward | kLoss)
+  // The input ProgramDesc should have a loss operator.
+  auto ops = origin_clone.Block(0).AllOps();
+  bool has_loss_op = false;
+  for (auto op : ops) {
+    int op_role =
+        boost::get<int>(op->GetAttr(OpProtoAndCheckerMaker::OpRoleAttrName()));
+    if (op_role == (static_cast<int>(OpRole::kForward) |
+                    static_cast<int>(OpRole::kLoss))) {
+      op->SetAttr(OpProtoAndCheckerMaker::OpRoleAttrName(),
+                  static_cast<int>(OpRole::kForward));
+      op->SetIsTarget(true);
+      has_loss_op = true;
+    } else if (op_role == (static_cast<int>(OpRole::kBackward) |
+                           static_cast<int>(OpRole::kLoss))) {
+      op->SetIsTarget(false);
+      break;
+    }
+  }
+  PADDLE_ENFORCE_EQ(has_loss_op, true,
+                    "The program to be pruned of its backward part "
+                    "should have a loss operator.");
+
+  // Step 2. Prune backward
+  proto::ProgramDesc pruned_desc;
+  pruned_desc.clear_blocks();
+  PruneBackwardImpl(origin_clone.Proto(), &pruned_desc, 0, -1);
+
+  // Step 3. Construct a new framework::ProgramDesc
+  return std::unique_ptr<framework::ProgramDesc>(
+      new framework::ProgramDesc(pruned_desc));
+}
+
 }  // namespace framework
 }  // namespace paddle
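
Aside for reviewers (illustration only, not patch content): the marking scheme in PruneBackwardImpl above is three sweeps over the block's op list — a reverse sweep that keeps forward ops feeding an already-kept target, a forward sweep that drops backward/optimizer ops fed by an already-dropped op, and a final reverse sweep that settles undecided side-branch ops. A minimal standalone sketch of that logic in plain Python follows; the Op tuple and prune_backward_toy helper are invented stand-ins, and the real passes operate on proto::OpDesc messages and also record output var names for the var copy in Step 6.

    from collections import namedtuple

    # Toy op record: explicit input/output var names plus the tri-state target
    # flag used by the patch (True = keep, False = drop, None = undecided).
    Op = namedtuple('Op', ['name', 'inputs', 'outputs', 'is_target'])

    def prune_backward_toy(ops):
        """Return the names of the ops kept by the three marking passes."""
        keep = {op.name: op.is_target for op in ops}

        # Pass 1 (reverse, cf. Step 2): keep forward ops whose outputs feed an
        # already-kept op; collect their inputs as new dependencies.
        fwd_inputs = set()
        for op in reversed(ops):
            if keep[op.name] is True or fwd_inputs & set(op.outputs):
                keep[op.name] = True
                fwd_inputs.update(op.inputs)

        # Pass 2 (forward, cf. Step 3): drop backward/optimizer ops that are
        # fed by the outputs of an already-dropped op.
        grad_inputs, grad_outputs = set(), set()
        for op in ops:
            if keep[op.name] is False or grad_outputs & set(op.inputs):
                keep[op.name] = False
                grad_inputs.update(op.inputs)
                grad_outputs.update(op.outputs)

        # Pass 3 (reverse, cf. Step 4): undecided side-branch ops are dropped
        # only if they merely feed the gradient ops; otherwise they are kept.
        for op in reversed(ops):
            if keep[op.name] is None:
                if grad_inputs & set(op.outputs):
                    keep[op.name] = False
                    grad_inputs.update(op.inputs)
                else:
                    keep[op.name] = True
        return [op.name for op in ops if keep[op.name]]

    ops = [
        Op('fc', ['image', 'w'], ['out'], None),
        Op('mean', ['out'], ['loss'], True),             # loss op, marked target
        Op('mean_grad', ['loss'], ['out@GRAD'], False),  # first backward op
        Op('fc_grad', ['out@GRAD', 'w'], ['w@GRAD'], None),
        Op('sgd', ['w', 'w@GRAD'], ['w'], None),
    ]
    print(prune_backward_toy(ops))  # ['fc', 'mean']
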
diff --git a/paddle/fluid/framework/prune.h b/paddle/fluid/framework/prune.h
index 3caa6cde4f4867c472d6fe3fe5bd562e583b6420..f710106a263a4d4350007c1580aaf83560faaa7e 100644
--- a/paddle/fluid/framework/prune.h
+++ b/paddle/fluid/framework/prune.h
@@ -14,9 +14,11 @@ limitations under the License. */
 
 #pragma once
 
+#include <memory>
 #include <set>
 #include <string>
 #include "paddle/fluid/framework/framework.pb.h"
+#include "paddle/fluid/framework/program_desc.h"
 #include "paddle/fluid/platform/enforce.h"
 
 namespace paddle {
@@ -26,5 +28,8 @@ void Prune(const proto::ProgramDesc& input,
            const std::set<std::string>& feed_var_names,
            proto::ProgramDesc* output);
 
+std::unique_ptr<framework::ProgramDesc> PruneBackward(
+    const framework::ProgramDesc& origin);
+
 }  // namespace framework
 }  // namespace paddle
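
Aside for reviewers (illustration only, not patch content): the pybind.cc and framework.py changes below expose PruneBackward as core.prune_backward and route Program.clone(for_test=True) through it once backward ops have been appended. A minimal usage sketch, assuming the Paddle 1.x fluid API already used by the tests in this patch; the program and variable names are arbitrary.

    import paddle.fluid as fluid

    main_prog = fluid.Program()
    startup_prog = fluid.Program()
    with fluid.program_guard(main_prog, startup_prog):
        img = fluid.layers.data(name='image', shape=[784], dtype='float32')
        label = fluid.layers.data(name='label', shape=[1], dtype='int64')
        prediction = fluid.layers.fc(img, size=10, act='softmax')
        loss = fluid.layers.mean(
            fluid.layers.cross_entropy(input=prediction, label=label))
        fluid.optimizer.SGD(learning_rate=0.001).minimize(loss)

    # Cloning after minimize() now goes through core.prune_backward, so the
    # cloned program keeps only the forward ops and vars needed for evaluation.
    test_prog = main_prog.clone(for_test=True)
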
diff --git a/paddle/fluid/pybind/pybind.cc b/paddle/fluid/pybind/pybind.cc
index 768839ec11bba9e735575b7b6dfe9119fc19fb43..3619f1c9e2a52f7f933cc21e82577d3ee9513a11 100644
--- a/paddle/fluid/pybind/pybind.cc
+++ b/paddle/fluid/pybind/pybind.cc
@@ -761,6 +761,9 @@ All parameter, weight, gradient are variables in Paddle.
     Prune(*prog_with_targets.Proto(), feeded_var_names, &pruned_desc);
     return new ProgramDesc(pruned_desc);
   });
+  m.def("prune_backward", [](const framework::ProgramDesc &program) {
+    return PruneBackward(program);
+  });
   m.def("empty_var_name",
         []() { return std::string(framework::kEmptyVarName); });
   m.def("grad_var_suffix",
diff --git a/python/paddle/fluid/framework.py b/python/paddle/fluid/framework.py
index 1f6c18f10cbff3a4c356a227a4943bab80c0c919..bd5f45bba93b560044acbd02583ed4cd4c0f96e4 100644
--- a/python/paddle/fluid/framework.py
+++ b/python/paddle/fluid/framework.py
@@ -3235,9 +3235,13 @@ class Program(object):
         """
         if for_test:
             if self._appending_grad_times > 0:
-                loss_op = self._find_loss_op()
-                assert loss_op is not None, "The optimized network should have loss operator."
-                forward_prog = self._prune([], loss_op)
+                forward_prog = Program()
+                forward_prog.desc = core.prune_backward(self.desc)
+                forward_prog.blocks = [
+                    Block(forward_prog, i)
+                    for i in six.moves.range(forward_prog.desc.num_blocks())
+                ]
+                forward_prog._sync_with_cpp()
                 p = forward_prog._inference_optimize(prune_read_op=False)
             else:
                 p = self._inference_optimize(prune_read_op=False)
@@ -3637,16 +3641,6 @@ class Program(object):
             for each_var in list(each_block.vars.values()):
                 yield each_var
 
-    def _find_loss_op(self):
-        loss_op = None
-        op_role_key = core.op_proto_and_checker_maker.kOpRoleAttrName()
-        forward_loss = int(core.op_proto_and_checker_maker.OpRole.Forward
-                           ) | int(core.op_proto_and_checker_maker.OpRole.Loss)
-        for op in self.global_block().ops:
-            if int(op.all_attrs()[op_role_key]) == forward_loss:
-                loss_op = op
-        return loss_op
-
 
 class Parameter(Variable):
     """
diff --git a/python/paddle/fluid/tests/unittests/test_program_prune_backward.py b/python/paddle/fluid/tests/unittests/test_program_prune_backward.py
index 099bbcddd192a2a6dcdbb2a55cd7487962bed32b..ed259b12f038032d8c8c3e7e6c607d1791e80efe 100755
--- a/python/paddle/fluid/tests/unittests/test_program_prune_backward.py
+++ b/python/paddle/fluid/tests/unittests/test_program_prune_backward.py
@@ -52,6 +52,25 @@ def lstm_net(use_feed):
     return avg_cost
 
 
+def simple_fc_net_with_accuracy(use_feed):
+    img = fluid.layers.data(name='image', shape=[784], dtype='float32')
+    label = fluid.layers.data(name='label', shape=[1], dtype='int64')
+
+    hidden = img
+    for _ in range(4):
+        hidden = fluid.layers.fc(
+            hidden,
+            size=200,
+            act='relu',
+            bias_attr=fluid.ParamAttr(
+                initializer=fluid.initializer.Constant(value=1.0)))
+    prediction = fluid.layers.fc(hidden, size=10, act='softmax')
+    loss = fluid.layers.cross_entropy(input=prediction, label=label)
+    loss = fluid.layers.mean(loss)
+    accuracy_out = fluid.layers.accuracy(input=prediction, label=label, k=5)
+    return loss
+
+
 class TestProgramPruneBackward(unittest.TestCase):
     def program_compare(self, program_a, program_b):
         assert isinstance(
@@ -109,6 +128,21 @@ class TestProgramPruneBackward(unittest.TestCase):
                            "label": label},
                 optimizer=optimizer)
 
+    def test_simple_fc_net_with_accuracy(self):
+        def optimizer():
+            optimizer = fluid.optimizer.SGD(
+                learning_rate=0.001,
+                regularization=fluid.regularizer.L2Decay(1e-4))
+            return optimizer
+
+        with self.program_scope_guard():
+            img, label = init_data()
+            self.check_prune_correctness(
+                method=simple_fc_net_with_accuracy,
+                feed_dict={"image": img,
+                           "label": label},
+                optimizer=optimizer)
+
     def test_batchnorm_fc(self):
         def optimizer():
             optimizer = fluid.optimizer.SGD(