Prune
Motivation
We want to support running inference, training, and checkpointing in one `ProgramDesc`. To this end, we implement the function `void Prune(const ProgramDesc* input, ProgramDesc* output)`, which takes a `ProgramDesc` and generates a pruned `ProgramDesc` that keeps only the operators needed to evaluate the requested targets.
Challenge
Pruning needs to support both variables and operators as evaluation targets. Consider the following situations:
```python
# Case 1: run the forward pass.
cost_np = session.run(target=cost)
# Case 2: run the backward pass.
opts_np, _ = session.run(target=[cost, opt])
# Case 3: run checkpointing.
_ = session.run(target=checkpoint)
```
Solution
To support operators as evaluation targets, we add an `is_target` field to `OpDesc`:
```protobuf
message OpDesc {
  required string type = 3;
  repeated Var inputs = 1;
  repeated Var outputs = 2;
  repeated Attr attrs = 4;
  optional bool is_target = 5 [ default = false ];
}
```
To support variables as evaluation targets, we add a `fetch_op`. For each variable in the target list, we insert a `fetch_op` into the `ProgramDesc` with that variable as the `fetch_op`'s input, and we mark the `fetch_op` as a target.
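The fetch_op insertion step can be sketched as follows. This is a minimal illustration, not the framework's code: it assumes hypothetical plain structs in place of the protobuf-generated `OpDesc` and `Var` types, a hypothetical helper name `AddFetchOps`, and a hypothetical input parameter name `"X"`. The real operator type string and parameter names may differ.

```cpp
#include <string>
#include <vector>

// Hypothetical plain-struct stand-ins for the protobuf-generated Var and
// OpDesc messages; the real types use generated accessors instead.
struct Var {
  std::string parameter;               // parameter name, e.g. "X"
  std::vector<std::string> arguments;  // variable names bound to it
};

struct OpDesc {
  std::string type;
  std::vector<Var> inputs;
  std::vector<Var> outputs;
  bool is_target = false;
};

// Hypothetical helper: for every target variable, append a fetch_op that
// reads the variable and is itself marked as an evaluation target.
void AddFetchOps(std::vector<OpDesc>& ops,
                 const std::vector<std::string>& target_vars) {
  for (const auto& name : target_vars) {
    OpDesc fetch;
    fetch.type = "fetch_op";
    fetch.inputs.push_back(Var{"X", {name}});  // "X" is an assumed name
    fetch.is_target = true;
    ops.push_back(fetch);
  }
}
```

Because pruning only has to check `is_target`, appending fetch ops this way lets variable targets and operator targets go through the same code path.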
Algorithm
If an operator needs to be run, it must fall into one of the following cases:
- It is a target.
- It is depended on by another operator that needs to be run, meaning one of its outputs is that operator's input.
The first case can be checked with `op_desc.is_target()`. The second case can be implemented as:
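```cpp
#include <set>
#include <string>

// Returns true if any output argument of `op_desc` appears in
// `dependent_vars`, i.e. an operator that must run consumes one of its outputs.
bool HasDependentVar(const OpDesc& op_desc,
                     const std::set<std::string>& dependent_vars) {
  for (auto& var : op_desc.outputs()) {
    for (auto& argu : var.arguments()) {
      if (dependent_vars.count(argu) != 0) {
        return true;
      }
    }
  }
  return false;
}
```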
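To exercise this dependency check outside the framework, here is a minimal self-contained sketch. The plain structs stand in for the protobuf-generated types (an assumption: the real `OpDesc` exposes `outputs()` and `arguments()` accessors rather than public members), and `ShouldRun` is a hypothetical helper that combines the two cases listed above.

```cpp
#include <set>
#include <string>
#include <vector>

// Hypothetical plain-struct stand-ins for the protobuf Var/OpDesc messages.
struct Var {
  std::string parameter;
  std::vector<std::string> arguments;
};

struct OpDesc {
  std::string type;
  std::vector<Var> inputs;
  std::vector<Var> outputs;
  bool is_target = false;
};

// Same logic as HasDependentVar, written against the stand-in structs.
bool HasDependentVar(const OpDesc& op_desc,
                     const std::set<std::string>& dependent_vars) {
  for (const auto& var : op_desc.outputs) {
    for (const auto& argu : var.arguments) {
      if (dependent_vars.count(argu) != 0) return true;
    }
  }
  return false;
}

// Hypothetical helper: an op must run if it is a target (case 1) or if an
// op that must run consumes one of its outputs (case 2).
bool ShouldRun(const OpDesc& op_desc,
               const std::set<std::string>& dependent_vars) {
  return op_desc.is_target || HasDependentVar(op_desc, dependent_vars);
}
```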
The whole algorithm can then be implemented with the following code.