Commit 8543ad64 authored by fengjiayi

update backward doc

Parent bb0427ad
# Backward Building

## Motivation

In neural networks, most models are currently solved by the backpropagation algorithm (known as **BP**). Technically, BP calculates the gradient of the loss function, then propagates it back through the network following the chain rule. However, when configuring the model structure, users do not need to define the backward part. So the framework needs a mechanism that can complete the model's backward part automatically according to the given forward part.

When implementing a certain `op`, the developer is also asked to implement its backward version, called `grad_op`. A `grad_op` takes the gradients of its corresponding `op`'s outputs and calculates the gradients of the `op`'s inputs. During the building of a model's backward part, the framework creates each forward `op`'s `grad_op` and strings them together in the reverse order of the forward part. In this way, gradients spread from the end to the beginning of the model, in other words, from the loss to the parameters.
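To make the `grad_op` contract concrete, here is a toy numeric illustration in plain Python (not Fluid code): a `grad_op` receives the gradient of its op's output and turns it into gradients of the op's inputs.

```python
# Toy illustration only: what a grad_op computes for a multiply op.
def mul_forward(x, y):
    return x * y

def mul_grad(x, y, dz):
    # d(x*y)/dx = y and d(x*y)/dy = x, each scaled by the incoming output gradient dz
    return dz * y, dz * x

x, y = 3.0, 4.0
z = mul_forward(x, y)        # forward op
dz = 1.0                     # gradient arriving from the loss side
dx, dy = mul_grad(x, y, dz)  # grad_op: output gradient in, input gradients out
print(z, dx, dy)             # 12.0 4.0 3.0
```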
## Challenges

The motivation of backward building is obvious. However, implementing it correctly is not so easy. In the **Fluid** design, a deep learning model is described by `Program`, `Block`, `Op` and `Variable`. The `Block` itself can be nested, which means that the `op`s and `variable`s are scattered across different blocks rather than all gathered in a single graph. Our backward building algorithm has to visit blocks in recursive order and be able to insert `grad_op`s and newly created `variable`s into the right places.

## Usage

Although the whole algorithm is comprised of many functions, only one is exposed as an API:
```python
def append_backward(loss, parameter_list=None, no_grad_set=None):
    """
    Append backward part to main_program

    Args:
        loss(Variable): The variable generated by cost function.
        parameter_list(list): Parameters that need to be updated by optimizer.
            If None, it means all parameters need to be updated.
        no_grad_set(set): Variables that have no gradients in Block 0.
            If None, the set will be generated inside the function and
            contains all variables with `step_gradient=True` from all blocks.

    Return:
        (list[Variable]): list of (parameters, gradients) pair.
    """
```
By invoking this API, the framework appends the backward part to the program that contains the `loss`. The API takes three arguments. `loss` is the final loss value; it must be a scalar and is usually the output of the loss layer. It is also where the gradient is generated and where backpropagation starts. `parameter_list` marks all parameters that need updating; if it is `None`, all parameters will be updated by optimizers. `no_grad_set` marks variables without gradients; if all outputs of some `grad_op` are in `no_grad_set`, that `grad_op` will not be run.

This API is invoked automatically before optimizer building. As a result, in most cases users do not need to invoke it themselves to append the backward part.
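For illustration, a hedged usage sketch is shown below; the import paths and layer helpers are assumptions based on the `paddle.v2.fluid` package layout of this era and may differ in other versions.

```python
import paddle.v2.fluid as fluid
from paddle.v2.fluid.backward import append_backward

x = fluid.layers.data(name='x', shape=[13], dtype='float32')
y = fluid.layers.data(name='y', shape=[1], dtype='float32')
y_predict = fluid.layers.fc(input=x, size=1)
cost = fluid.layers.square_error_cost(input=y_predict, label=y)
avg_cost = fluid.layers.mean(x=cost)        # `loss` must be a scalar

# An optimizer's minimize() normally calls append_backward() for you; calling it
# directly returns the (parameter, gradient) pairs appended to the program.
params_and_grads = append_backward(loss=avg_cost)
```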
## Implementation

The implementation of the backward building algorithm is in the `backward.py` file. The whole algorithm can be divided into two independent parts: the creating of `grad_op`s and the creating of new variables.

### Creating `grad_op`s

The creating of `grad_op`s is implemented by:
```python
def _append_backward_ops_(target,
                          block,
                          target_block,
                          no_grad_dict,
                          grad_to_var):
    """
    Create all grad ops, and insert them into given block

    Args:
        target(Variable): the target variable of forward pass
        block(Block): the block where forward ops are
        target_block(Block): the block which is going to hold new generated grad ops
        no_grad_dict(dict):
            key(int) block index
            val(set) a set of varibale names. These varibales have no gradient
        grad_to_var(dict)(output argument):
            key(str): grad variable name
            val(str): corresponding forward variable name
    """
```
Given a `block`, the function traverses all `op`s in this block in reverse order, gets the corresponding `grad_op` from the C++ core via `core.get_grad_op_desc()`, and then appends it to `target_block`.

However, some specific `op`s (e.g. `while_op`, `if_else_op`) can hold their own sub-blocks. Since these sub-blocks contain `op`s as well, the `grad_op` creating should be recursive.

During the reverse traversal, we check each `op` for an attribute named `sub_block`. If it has one, there is a sub-block that we need to deal with first. After creating a new block whose father is the one in the `op`'s attribute, we invoke `_append_backward_ops_()` recursively, assigning the new block to the parameter `target_block` and the one in the `op`'s attribute to `block`. The following *pseudo-code* shows this process:
```
******* pseudo-code ********
for op in reversed(block.ops):
    if op has an attribute named 'sub_block':
        Get the sub-block(`s_block`) from op's attribute.
        Create a new block(`grad_s_block`), whose father is `s_block`.
        Invoke _append_backward_ops_(), with `block=s_block` and `target_block=grad_s_block`
    Invoke `core.get_grad_op_desc()` to get op's grad_op.
    Insert the name correspondence between variables and their gradients of the grad_op into grad_to_var
    Assign grad_s_block to grad_op as its 'sub_block' attribute.
    Append grad_op to current target_block.
```
The first invocation of `_append_backward_ops_()` is initiated by `append_backward()`, in which the parameters `block` and `target_block` are both assigned the root block (the block with index 0).
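The reverse traversal with sub-block recursion can be sketched with toy stand-in classes (an illustration only: `Op` and `Block` below are not Fluid's real classes, and the real code attaches `grad_s_block` to the created `grad_op` rather than discarding it).

```python
class Op(object):
    def __init__(self, type, sub_block=None):
        self.type = type
        self.sub_block = sub_block   # plays the role of the 'sub_block' attribute

class Block(object):
    def __init__(self, ops):
        self.ops = ops

def append_grad_ops(block, target_block):
    for op in reversed(block.ops):
        if op.sub_block is not None:
            grad_s_block = Block([])                     # holds the sub-block's grad ops
            append_grad_ops(op.sub_block, grad_s_block)  # deal with the sub-block first
        # stand-in for core.get_grad_op_desc(): just tag the op type
        target_block.ops.append(Op(op.type + "_grad"))

main = Block([Op("mul"), Op("while", sub_block=Block([Op("add")])), Op("mean")])
grad_block = Block([])
append_grad_ops(main, grad_block)
print([op.type for op in grad_block.ops])   # ['mean_grad', 'while_grad', 'mul_grad']
```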
### Corner Cases of `grad_op` Creating

In the previous section, we showed the regular process of `grad_op` creating. However, in some corner cases the regular algorithm is not enough to get the correct result, and additional handling is required. These additional processes run after the above-mentioned algorithm and make some special adjustments to its output `grad_op`s.

#### Shared Variables

If a variable is read by more than one `op` in the forward pass, its gradient is likely to be written by more than one `grad_op` in the following backward pass. To make the gradient result the sum of all the `grad_op`s' outputs instead of only the last one written, we assign each of those outputs a temporary variable and then add a `sum_op` to add them up, as sketched below.

For debugging convenience, if the final gradient name is `w@GRAD`, its corresponding temporary variables will be named `w@GRAD@RENAME@0`, `w@GRAD@RENAME@1`, and so on.
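A toy sketch of the rename-and-sum idea (plain Python dictionaries standing in for op descriptions; this is not the real `_addup_repetitive_outputs_` code):

```python
# Two grad_ops both want to write w@GRAD; rename their outputs and sum them up.
grad_ops = [
    {"type": "fc_grad",  "outputs": ["w@GRAD"]},
    {"type": "mul_grad", "outputs": ["w@GRAD"]},
]

renamed = []
for i, op in enumerate(grad_ops):
    tmp_name = "w@GRAD@RENAME@%d" % i
    op["outputs"] = [tmp_name]       # redirect this grad_op to a temporary variable
    renamed.append(tmp_name)

# one sum op makes w@GRAD the sum of all contributions
grad_ops.append({"type": "sum", "inputs": renamed, "outputs": ["w@GRAD"]})
for op in grad_ops:
    print(op)
```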
<p align="center"> <figure class="center">
<img src="./images/duplicate_op.png" width="50%" ><br/> <img src="./images/duplicate_op.png" width="45%" >
<img src="images/duplicate_op2.png" width="45%" >
</figure>
​ Figure 1. Sharing variables in operators. See function `_addup_repetitive_outputs_` in `backward.py` for implementation details.
#### No Gradient Variables

In our framework, a variable can be marked as *no_gradient*, which means that the gradient of this variable is unnecessary and can be considered zero in model training. Obviously, when all the outputs of some `grad_op` are marked as *no_gradient*, the `grad_op` itself can be skipped in the backward pass.

However, these unnecessary gradients still need to be created and initialized by something; otherwise, the following `grad_op`s that take these gradients as inputs run the risk of using uninitialized memory. In our code, we employ `fill_zeros_like_op` to initialize them as all zeros.
<img src="images/duplicate_op2.png" width="40%" ><br/>
​ Figure 2. Replace sharing variable's gradient with `Add` operator. This features are implemented in function `_remove_no_grad_branch_`. It checks new created `grad_op`'s one-by-one, removes whose outputs are all in `no_grad_set` or inserts `fill_zeros_like_op` when its necessary. We can get the `no_grad_set` from the `_append_backward_ops_` argument `no_grad_dict` or generate it on fly by scanning all variables' `no_gradient` attribute(True or False).
### Creating Backward Variables

Up to now, we have completed all the creating and adjusting jobs for `grad_op`s. However, the backward variables have not been created yet; so far they are only represented by the `grad_op`s' input and output arguments. The backward variable creating job will be done by:
```python
def _append_backward_vars_(block,
                           start_op_idx,
                           grad_to_var,
                           grad_info_map):
    """
    Create new variables required by backward pass.

    Args:
        block(Block): the block where new variables will be created
        start_op_idx(int): Only variables required by ops in block.ops[start_op_idx : ] will be created
        grad_to_var(dict):
            key(str): grad variable name
            val(str): corresponding forward variable name
            In most cases, this dict is generated by _append_backward_ops_()
        grad_info_map(dict)(output argument):
            key(str): forward variable name
            val(tuple): a tuple of (str, int), str is the corresponding grad name, int is the block index
    """
```
Given a `block`, this function traverses all the `grad_op`s in it (the argument `start_op_idx` indicates where the `grad_op` sequence starts) and creates all the outputs that have not been created yet. The following *pseudo-code* shows this process:
```
for op in block.ops[start_op_idx : ]:
    if op has an attribute named 'sub_block':
        Get the sub-block(`s_block`) from op's attribute.
        Invoke _append_backward_vars_(), with `block=s_block`
    for var_name in op.all_output_names():
        if block.has_var_recursive(var_name) or var_name is the name of empty variable:
            continue
        create a new variable named 'var_name' in block
        if grad_to_var.has_key(var_name):
            set grad_info_map[grad_to_var[var_name]] as a tuple of (var_name, block)
    do op's var type inference
    do op's shape inference
```
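As a toy worked example of this pass (plain Python sets and dictionaries standing in for blocks and variables; block index 0 is assumed for simplicity):

```python
grad_to_var = {"w@GRAD": "w", "x@GRAD": "x"}       # built by the grad_op creating pass
existing_vars = {"w", "x", "out", "out@GRAD"}      # variables already in the block

# output names of the grad_ops appended earlier, in order
grad_op_outputs = [["out@GRAD"], ["w@GRAD"], ["x@GRAD"]]

grad_info_map = {}
for outputs in grad_op_outputs:
    for var_name in outputs:
        if var_name in existing_vars:
            continue                               # already created, nothing to do
        existing_vars.add(var_name)                # "create" the new backward variable
        if var_name in grad_to_var:
            # record where the forward variable's gradient lives (block index 0 here)
            grad_info_map[grad_to_var[var_name]] = (var_name, 0)

print(grad_info_map)   # {'w': ('w@GRAD', 0), 'x': ('x@GRAD', 0)}
```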
The commit also updates `backward.py` itself, adding explanatory comments and rewriting the `append_backward` docstring:

```diff
@@ -176,6 +176,7 @@ def _append_backward_ops_(target,
             key(str): grad variable name
             val(str): corresponding forward variable name
     """
+    # grad_op_descs holds created grad_op, and will be appended to target_block
     grad_op_descs = []
     program = block.program
     for op in reversed(block.ops):
@@ -188,6 +189,7 @@ def _append_backward_ops_(target,
                                   no_grad_dict, grad_to_var, callback)
             grad_sub_block_list.append(grad_sub_block.desc)
+        # Getting op's corresponding grad_op
         grad_op_desc, op_grad_to_var = core.get_grad_op_desc(
             op.desc, no_grad_dict[block.idx], grad_sub_block_list)
         grad_op_descs.extend(grad_op_desc)
@@ -254,18 +256,18 @@ def _append_backward_vars_(block, start_op_idx, grad_to_var, grad_info_map):
 def append_backward(loss, parameter_list=None, no_grad_set=None):
     """
-    Create and add gradient Operators in BlockDesc to compute
-    gradients of `loss` for parameters in parameter_list
-    :param loss: an variable generated by cost function.
-    :type loss: Variable
-    :param no_grad_dict: variable that should not create gradient
-    :type no_grad_dict: set
-    :param parameter_list: parameters that need to compute gradient and
-                           update to optimize the lost.
-    :type: list
-    :return: list of (parameters, gradients) pair.
-    :rtype: list[Variable]
+    Append backward part to main_program
+
+    Args:
+        loss(Variable): The variable generated by cost function.
+        parameter_list(list): Parameters that need to be updated by optimizer.
+            If None, it means all parameters need to be updated.
+        no_grad_set(set): Variables that have no gradients in Block 0.
+            If None, the set will be generated inside the function and
+            contains all variables with `step_gradient=True` from all blocks.
+
+    Return:
+        (list[Variable]): list of (parameters, gradients) pair.
     """
     assert isinstance(loss, framework.Variable)
```