# Operator/expression's Backward

## Motivation

In a neural network, the backpropagation algorithm follows the chain rule, so we need to compose the fundamental gradient operators/expressions together according to the chain rule. Every forward network needs a backward network to construct the full computation lineage; the operator/expression's backward feature generates the backward pass with respect to the forward pass.
 
## Backward Operator Registry

A backward network is built up from several backward operators. A backward operator takes the forward operator's inputs, outputs, and output gradients, and then calculates the forward operator's input gradients. In most cases there is a one-to-one correspondence between forward and backward operators. We use a registry mechanism to record these correspondences, which is quite similar to the operator registry itself.
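To make this concrete, here is a minimal, self-contained sketch (hypothetical names and plain `std::vector` buffers, not the framework's real operator interface) of the backward computation for an element-wise multiply `Out = X * Y`. It shows why a backward operator needs the forward inputs: the input gradients are `dX = dOut * Y` and `dY = dOut * X`.

```cpp
#include <cstddef>
#include <vector>

// Hypothetical, simplified example -- not the actual operator interface.
// Forward:  Out[i] = X[i] * Y[i]
// Backward: dX[i] = dOut[i] * Y[i],  dY[i] = dOut[i] * X[i]
void MulBackward(const std::vector<float>& x, const std::vector<float>& y,
                 const std::vector<float>& d_out,
                 std::vector<float>* d_x, std::vector<float>* d_y) {
  d_x->resize(x.size());
  d_y->resize(y.size());
  for (std::size_t i = 0; i < d_out.size(); ++i) {
    (*d_x)[i] = d_out[i] * y[i];  // the backward pass reads the forward input Y
    (*d_y)[i] = d_out[i] * x[i];  // and the forward input X
  }
}
```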

For example, we have an `add_two_op`, which is registered by the following code:

```cpp
REGISTER_OP(add_two, AddTwoOp, AddTwoOpMaker);
```

`add_two` is the operator's type. `AddTwoOp` and `AddTwoOpMaker` are the operator class and the operator maker class respectively.

Assume that we also have the backward operator of `add_two_op`, which calculates the gradients of `add_two_op`'s inputs. Then we register it in the following way:

```cpp
REGISTER_GRADIENT_OP(add_two, add_two_grad, AddTwoGradOp);
```

`add_two_grad` is the type of backward operator, and `AddTwoGradOp` is its class name.
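Conceptually, `REGISTER_GRADIENT_OP` only needs to record the correspondence between the forward and the backward operator type so that it can be looked up when the backward pass is built. A minimal sketch of such a lookup table (simplified, not the framework's actual registry implementation) could be:

```cpp
#include <map>
#include <stdexcept>
#include <string>

// Simplified sketch of a gradient-operator registry: for every forward
// operator type (e.g. "add_two") it stores the type of the corresponding
// backward operator (e.g. "add_two_grad"). REGISTER_GRADIENT_OP can be
// thought of as inserting one entry at static-initialization time.
class GradOpRegistry {
 public:
  static void Register(const std::string& fwd_type, const std::string& grad_type) {
    Table()[fwd_type] = grad_type;
  }

  static const std::string& GradOpType(const std::string& fwd_type) {
    auto it = Table().find(fwd_type);
    if (it == Table().end()) {
      throw std::runtime_error("no gradient operator registered for " + fwd_type);
    }
    return it->second;
  }

 private:
  static std::map<std::string, std::string>& Table() {
    static std::map<std::string, std::string> table;
    return table;
  }
};
```

With such a table in place, building the backward pass for an `add_two` operator reduces to looking up `add_two_grad`, creating an instance of `AddTwoGradOp`, and wiring up its inputs and outputs; the next section describes that wiring.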

## Backward Operator Creation

### Usage

Given a certain forward operator, we can get its corresponding backward operator by calling:

```cpp
// Signature: OperatorBase* BuildGradOp(const OperatorBase* fwd_op);
OperatorBase* bwd_op = BuildGradOp(fwd_op);
```

The function `BuildGradOp` will sequentially execute the following steps (a simplified sketch follows the list):

1. Getting the `type_` of the given forward operator, and then creating the corresponding backward operator.

2. Copying all the attributes of the forward operator except `input_format` and `output_format` (if present), because their elements differ between forward and backward operators.

3. Copying the forward operator's `inputs_` and `outputs_` into the backward operator's `inputs_`, then adding the forward outputs' gradient variables into the backward `inputs_` and the forward inputs' gradient variables into the backward `outputs_`.

4. Building the backward operator's `input_format`, `output_format` (if necessary) and `in_out_idxs_` according to the `inputs_` and `outputs_` just created.
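The following sketch mirrors these four steps on a deliberately simplified operator description (a hypothetical `OpDesc` with plain string lists and a `@GRAD` naming convention assumed for illustration; the real `input_format`/`output_format`/`in_out_idxs_` bookkeeping of step 4 is omitted):

```cpp
#include <map>
#include <string>
#include <vector>

// A deliberately simplified operator description, used only for this sketch.
struct OpDesc {
  std::string type;
  std::vector<std::string> inputs;
  std::vector<std::string> outputs;
  std::map<std::string, int> attrs;  // attributes, reduced to ints here
};

// Assumed naming convention for gradient variables.
std::string GradVarName(const std::string& var) { return var + "@GRAD"; }

OpDesc BuildGradOpSketch(const OpDesc& fwd) {
  OpDesc bwd;
  // Step 1: look up the backward operator type registered for fwd.type
  // (approximated here by appending "_grad", e.g. add_two -> add_two_grad).
  bwd.type = fwd.type + "_grad";

  // Step 2: copy the attributes (format-related attributes would be rebuilt
  // rather than copied; that part is skipped in this sketch).
  bwd.attrs = fwd.attrs;

  // Step 3: the backward operator reads the forward inputs, the forward
  // outputs and the gradients of the forward outputs ...
  for (const auto& in : fwd.inputs) bwd.inputs.push_back(in);
  for (const auto& out : fwd.outputs) bwd.inputs.push_back(out);
  for (const auto& out : fwd.outputs) bwd.inputs.push_back(GradVarName(out));
  // ... and writes the gradients of the forward inputs.
  for (const auto& in : fwd.inputs) bwd.outputs.push_back(GradVarName(in));

  // Step 4 (building input_format, output_format and in_out_idxs_) depends on
  // the real operator representation and is omitted here.
  return bwd;
}
```
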

## Backward Network Building

A backward network is a series of backward operators. The main idea of building a backward network is to create backward operators in the reverse of the forward order and chain them together.

In our design, the network itself is also a kind of operator, so an operator contained in a big network may itself be a smaller network.
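This is essentially the composite pattern: a network is an operator that simply runs the operators it contains, each of which may again be a network. A minimal sketch (hypothetical interfaces, not the real classes) looks like this:

```cpp
#include <memory>
#include <vector>

// Minimal sketch of the composite structure used in this design.
class OperatorBase {
 public:
  virtual ~OperatorBase() = default;
  virtual void Run() const = 0;
};

class NetOp : public OperatorBase {
 public:
  void AppendOp(std::unique_ptr<OperatorBase> op) { ops_.push_back(std::move(op)); }
  void Run() const override {
    // A network runs its sub-operators in order; a sub-operator may itself
    // be another NetOp.
    for (const auto& op : ops_) op->Run();
  }

 private:
  std::vector<std::unique_ptr<OperatorBase>> ops_;
};
```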

Given a forward network, the building process generates its backward network. We only care about the gradients: `OutputGradients` and `InputGradients`.
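At a high level, the whole building process is one recursive function over this structure; the two cases below describe what it does for a plain operator and for a `NetOp`. A simplified, self-contained sketch (hypothetical types and helpers, with the gradient-name bookkeeping only hinted at in comments) is:

```cpp
#include <memory>
#include <vector>

// Minimal stand-ins for this sketch only (not the real framework classes).
struct Op {
  virtual ~Op() = default;
  virtual bool IsNet() const { return false; }
};

struct NetOp : Op {
  std::vector<std::unique_ptr<Op>> ops;  // sub-operators, possibly nested nets
  bool IsNet() const override { return true; }
};

// Placeholder for the per-operator construction described in the previous
// section; the real version looks up and instantiates the registered
// gradient operator.
std::unique_ptr<Op> BuildGradOp(const Op& /*fwd_op*/) { return std::make_unique<Op>(); }

std::unique_ptr<Op> Backward(const Op& fwd_op) {
  if (!fwd_op.IsNet()) {
    // Case 1: a plain operator -- return its gradient operator immediately.
    return BuildGradOp(fwd_op);
  }
  // Case 2: a NetOp -- recurse into the sub-operators in reverse order.
  const auto& fwd_net = static_cast<const NetOp&>(fwd_op);
  auto bwd_net = std::make_unique<NetOp>();
  for (auto it = fwd_net.ops.rbegin(); it != fwd_net.ops.rend(); ++it) {
    bwd_net->ops.push_back(Backward(**it));
  }
  // Where several backward operators write the gradient of the same (shared)
  // variable, rename the duplicates and append a generic add operator that
  // sums them -- see the shared-variable discussion below.
  // Finally, collect the sub-network's OutputGradients/InputGradients as the
  // bwd_net's own and return it.
  return bwd_net;
}
```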

1. Op 

   When the input forward network is a plain `Op`, return its gradient operator immediately.

2. NetOp 

   When the input forward network is a `NetOp`, we need to call the backward functions of its sub `NetOp`s/operators recursively. During this process, we need to collect the `OutputGradients` names according to the forward `NetOp`.

   **Shared variables**. As illustrated in the pictures below, two operators' output gradients would overwrite their shared input variable.

   <p align="center">
   <img src="./images/duplicate_op.png" width="70%" ><br/>

   1. Shared variable in two operators.

   </p>

   Sharing a variable between operators, or using the same input variable in multiple operators, leads to duplicated gradient variables. As the demo above shows, we need to rename the gradient variables recursively and add a generic add operator in place of the overwriting links.

   <p align="center">
   <img src="images/duplicate_op2.png" width="90%" ><br/>

   2. Replace the shared variable's gradient with an `Add` operator.

   </p>

   Then collect the sub-graph's `OutputGradients`/`InputGradients` as the `NetOp`'s own, and return them.
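To make the renaming step above concrete, here is a small self-contained sketch (hypothetical names such as the `@RENAME@` suffix; each backward operator is reduced to the list of gradient names it writes). Every duplicated gradient output gets a unique temporary name, and the caller then appends one generic add operator per original name to sum the renamed copies back together.

```cpp
#include <map>
#include <string>
#include <vector>

// Simplified view for this sketch: a backward operator is reduced to the
// list of gradient variable names it writes.
struct GradOutputs {
  std::vector<std::string> outputs;
};

// Rename duplicated gradient outputs (e.g. "W@GRAD" -> "W@GRAD@RENAME@0") and
// return, for every original name, the renamed copies that a generic add
// operator has to sum back into it.
std::map<std::string, std::vector<std::string>> DeduplicateGradients(
    std::vector<GradOutputs>* bwd_ops) {
  // First pass: count how many backward operators write each gradient name.
  std::map<std::string, int> writers;
  for (const auto& op : *bwd_ops)
    for (const auto& name : op.outputs) ++writers[name];

  // Second pass: rename every duplicated output.
  std::map<std::string, int> next_id;
  std::map<std::string, std::vector<std::string>> to_sum;
  for (auto& op : *bwd_ops) {
    for (auto& name : op.outputs) {
      if (writers[name] < 2) continue;  // a unique writer keeps its name
      std::string renamed = name + "@RENAME@" + std::to_string(next_id[name]++);
      to_sum[name].push_back(renamed);  // later: name = sum(renamed copies)
      name = renamed;
    }
  }
  return to_sum;  // the caller appends one add operator per entry
}
```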