From 3ca3a200ab14454954ba44de3deba5caea229f51 Mon Sep 17 00:00:00 2001 From: "Yang Yang(Tony)" Date: Wed, 18 Oct 2017 19:00:53 -0700 Subject: [PATCH] Prune Design Doc (#4732) * Create prune.md * modification based on comment * remove insertion * rename id to block_id * Update prune.md * formatting --- doc/design/prune.md | 63 +++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 63 insertions(+) create mode 100644 doc/design/prune.md diff --git a/doc/design/prune.md b/doc/design/prune.md new file mode 100644 index 0000000000..4a5cf10c79 --- /dev/null +++ b/doc/design/prune.md @@ -0,0 +1,63 @@ +# Prune + +## Motivation + +We want to support running inference, training and checkpointing in one `ProgramDesc`. We implement +`void Prune(const ProgramDesc* input, ProgramDesc* output)` function, which takes a `ProgramDesc` +and generate a pruned `ProgramDesc`. + +## Challenge + +Pruning need to support both variables and operators being evaluation targets. Consider the following +different situations. + +```python +# Case 1: run foward pass. +cost_np = session.run(target=cost) +# Case 2: run backward passing. +opts_np, _ = session.run(target=[cost, opt]) +# Case 3: run checkpointing +_ = session.run(target=checkpoint) +``` + +## Solution + +To support evaluation of operators, we add `is_target` field in the `OpDesc`. + +```c++ +message OpDesc { + required string type = 3; + repeated Var inputs = 1; + repeated Var outputs = 2; + repeated Attr attrs = 4; + optional bool is_target = 5 [ default = false ]; +}; +``` + +To support evaluation of variables, we add [fetch_op](https://github.com/PaddlePaddle/Paddle/pull/4599). +For each variable in the `target`, we insert a `fetch_op` into the `ProgramDesc` with `variable` being +`fetch_op`'s input. Then we also set `fetch_op` is a target. + +### Algorithm + +If an operator needs to be run, it must fall into one of the following cases: + +1. It is the target. +2. It is depended by some other ops, meaning its output is some other op's input. + +The first case can be checked by `op_desc.is_traget()` . The second case can be implement as + +```c++ +bool HasDependentVar(const OpDesc& op_desc, const std::set& dependent_vars) { + for (auto& var : op_desc.outputs()) { + for (auto& argu : var.arguments()) { + if (dependent_vars.count(argu) != 0) { + return true; + } + } + } + return false; +} +``` + +Then the whole algorithm can be implemented as the following [code](https://github.com/tonyyang-svail/Paddle/blob/prune_impl/paddle/framework/prune.cc). -- GitLab