Committed by Yang Yang (Tony).
# Prune

## Motivation

We want to support running inference, training, and checkpointing in one `ProgramDesc`. We implement
the `void Prune(const ProgramDesc* input, ProgramDesc* output)` function, which takes a `ProgramDesc`
and generates a pruned `ProgramDesc`.

## Challenge

Pruning needs to support both variables and operators as evaluation targets. Consider the following
situations.

```python
# Case 1: run the forward pass.
cost_np = session.run(target=cost)
# Case 2: run the backward pass.
cost_np, _ = session.run(target=[cost, opt])
# Case 3: run checkpointing.
_ = session.run(target=checkpoint)
```

## Solution

To support evaluation of operators, we add an `is_target` field to `OpDesc`.

```protobuf
message OpDesc {
  required string type = 3;
  repeated Var inputs = 1;
  repeated Var outputs = 2;
  repeated Attr attrs = 4;
  optional bool is_target = 5 [ default = false ];
}
```

To support evaluation of variables, we add [fetch_op](https://github.com/PaddlePaddle/Paddle/pull/4599).
For each variable in `target`, we insert a `fetch_op` into the `ProgramDesc` that takes the variable as
its input, and mark that `fetch_op` as a target.
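The insertion step can be sketched in C++ as follows. This is a minimal illustration, not the code from the linked PR: it uses simplified, hypothetical structs in place of the protobuf-generated `OpDesc` and `ProgramDesc`, assumes a program with a single block, and makes up the `"fetch"` type string and the `"Input"` slot name. The point it shows is that, after this step, every evaluation target is an operator with `is_target` set, so the pruner only has to handle one kind of target.

```cpp
#include <cassert>
#include <string>
#include <vector>

// Simplified, hypothetical stand-ins for the protobuf messages.
struct Var {
  std::string parameter;               // slot name, e.g. "X"
  std::vector<std::string> arguments;  // variable names bound to the slot
};

struct OpDesc {
  std::string type;
  std::vector<Var> inputs;
  std::vector<Var> outputs;
  bool is_target = false;
};

struct ProgramDesc {
  std::vector<OpDesc> ops;  // a single block, operators in execution order
};

// Appends one fetch operator per target variable and marks it as a target.
void InsertFetchOps(const std::vector<std::string>& target_vars,
                    ProgramDesc* program) {
  assert(program != nullptr);
  for (const auto& name : target_vars) {
    OpDesc fetch;
    fetch.type = "fetch";                          // hypothetical type string
    fetch.inputs.push_back(Var{"Input", {name}});  // hypothetical slot name
    fetch.is_target = true;
    program->ops.push_back(fetch);
  }
}
```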

### Algorithm

If an operator needs to be run, it must fall into one of the following cases:

1. It is the target.
2. Another operator that needs to run depends on it, i.e., one of its outputs is an input of that operator.

The first case can be checked with `op_desc.is_target()`. The second case can be implemented as follows:

```c++
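#include <set>
#include <string>

// Returns true if any output of `op_desc` is in `dependent_vars`, i.e. some
// operator that must run reads a variable this operator produces.
bool HasDependentVar(const OpDesc& op_desc,
                     const std::set<std::string>& dependent_vars) {
  for (const auto& var : op_desc.outputs()) {
    for (const auto& argu : var.arguments()) {
      if (dependent_vars.count(argu) != 0) {
        return true;
      }
    }
  }
  return false;
}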
```

The complete algorithm is implemented in this [code](https://github.com/tonyyang-svail/Paddle/blob/prune_impl/paddle/framework/prune.cc).
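One way to combine the two checks is sketched below. It is a minimal, self-contained illustration under stated assumptions, not the linked implementation: it uses hypothetical simplified structs (a single block whose operators are stored in a valid execution order) and walks the operators backwards, so that every consumer of an operator's outputs is visited before the operator itself. Each kept operator adds its inputs to `dependent_vars`, which is exactly the set the second check consults.

```cpp
#include <cassert>
#include <cstddef>
#include <set>
#include <string>
#include <utility>
#include <vector>

// Simplified, hypothetical stand-ins for the protobuf messages.
struct Var {
  std::string parameter;
  std::vector<std::string> arguments;
};

struct OpDesc {
  std::string type;
  std::vector<Var> inputs;
  std::vector<Var> outputs;
  bool is_target = false;
};

struct ProgramDesc {
  std::vector<OpDesc> ops;  // one block, operators in execution order
};

// Case 2: does any output of op_desc feed an operator that was already kept?
bool HasDependentVar(const OpDesc& op_desc,
                     const std::set<std::string>& dependent_vars) {
  for (const auto& var : op_desc.outputs) {
    for (const auto& argu : var.arguments) {
      if (dependent_vars.count(argu) != 0) return true;
    }
  }
  return false;
}

void Prune(const ProgramDesc* input, ProgramDesc* output) {
  assert(input != nullptr && output != nullptr);
  const auto& ops = input->ops;
  std::vector<bool> should_run(ops.size(), false);
  std::set<std::string> dependent_vars;  // inputs of kept operators

  // Walk backwards: consumers come after producers in execution order, so by
  // the time we reach an operator, all kept consumers have added their inputs.
  for (std::size_t i = ops.size(); i-- > 0;) {
    const OpDesc& op = ops[i];
    if (op.is_target || HasDependentVar(op, dependent_vars)) {
      should_run[i] = true;
      for (const auto& var : op.inputs) {
        dependent_vars.insert(var.arguments.begin(), var.arguments.end());
      }
    }
  }

  // Collect the kept operators in their original order.
  std::vector<OpDesc> kept;
  for (std::size_t i = 0; i < ops.size(); ++i) {
    if (should_run[i]) kept.push_back(ops[i]);
  }
  output->ops = std::move(kept);
}
```

Because the kept operators are gathered into a local vector and assigned at the end, calling `Prune(&prog, &prog)` in place also works in this sketch. A single backward pass is enough only because the operators are assumed to be topologically ordered; a program with multiple blocks would need additional handling.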