# Design Doc: Gradient Operators Registration


## The Problem Posed

Currently, each C++ operator class definition registers a *gradient operator creator* function, which takes a C++ operator instance and returns the corresponding gradient operator instance.

However, we noticed two problems with the current design:

1. As we decided to separate the *compilation* and *execution* phases, we need to change the creator so that it takes an `OpDesc` protobuf message in a `ProgramDesc` and inserts the corresponding gradient `OpDesc` messages into that `ProgramDesc` message.

1. Some operators' gradient computation requires more than one gradient operator. For example, the gradient of *minus* consists of two operators: an identity operator and a scale operator. So the registration mechanism needs to support mapping an operator to a set of operators for gradient computation.

## The Current Implementation

Instances of the C++ class `OpInfo` are stored in an associative map whose key is the operator type. The `grad_op_type_` field indicates the associated gradient operator type. An operator creates its gradient operator through the `OpInfo::creator_` of that gradient operator type. The pseudocode is:

```cpp
struct OpInfo {
  std::function<OperatorBase*(...)> creator_;
  std::string grad_op_type_;
  ...
};

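std::map<std::string, OpInfo> OpInfoMap;

OperatorBase* CreateGradientOperator(const OperatorBase& op) {
  // Look up the gradient operator type, then call that type's creator.
  const std::string& grad_type = OpInfoMap.at(op.Type()).grad_op_type_;
  return OpInfoMap.at(grad_type).creator_(...);
}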
```

## Proposed Solution

The mapping relationship between an operator and its gradient operators is a function. The interface of that function is:

```cpp
// (OpDesc) --> vector<OpDesc>
std::function<std::vector<OpDescBind>(const OpDescBind&)>;
```

The function takes the `OpDescBind` of the forward operator and returns one or more gradient operator descriptions. `OpDescBind` is a C++ wrapper around the protobuf message `OpDesc` that makes manipulating `OpDesc` fast.

The `GradOpDescMaker` will be registered in `OpInfo`, replacing the `grad_op_type_` field. The `OpInfo` should be:

```cpp
struct OpInfo {
  std::function<std::vector<std::unique_ptr<OpDescBind>>(const OpDescBind&)>  grad_op_maker_;
  ...
};
```

The `grad_op_maker_` field is `nullptr` if the operator does not have any associated gradient operators.
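
As a rough illustration of how a caller in the compilation phase might consume `grad_op_maker_`, the sketch below looks up the forward operator's `OpInfo` and treats an empty maker as "no gradient operators". The `OpDescBind` stub, the `OpInfoMap()` accessor, and `MakeGradOpDescs` are hypothetical names introduced only for this example; they are not part of the proposed interface.

```cpp
#include <functional>
#include <map>
#include <memory>
#include <string>
#include <vector>

// Hypothetical stand-in for the real OpDescBind; only the type is modeled.
struct OpDescBind {
  std::string type;
};

using GradOpDescs = std::vector<std::unique_ptr<OpDescBind>>;

struct OpInfo {
  // Empty (compares equal to nullptr) when the operator has no gradient.
  std::function<GradOpDescs(const OpDescBind&)> grad_op_maker_;
};

// Hypothetical accessor for the global map from operator type to OpInfo.
inline std::map<std::string, OpInfo>& OpInfoMap() {
  static std::map<std::string, OpInfo> info_map;
  return info_map;
}

// Returns the gradient operator descriptions for fwd_op, or an empty vector
// when no gradient maker is registered for its type.
// Throws std::out_of_range if the operator type itself is not registered.
inline GradOpDescs MakeGradOpDescs(const OpDescBind& fwd_op) {
  const OpInfo& info = OpInfoMap().at(fwd_op.type);
  if (info.grad_op_maker_ == nullptr) {
    return {};
  }
  return info.grad_op_maker_(fwd_op);
}
```

Returning an empty vector is one possible convention; a real implementation might instead report an error when a gradient is requested for an operator that has none.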

We propose a base class called `GradOpDescMakerBase` to let operator developers generate gradient operators easily. The public interface of that class is:

```cpp
class GradOpDescMakerBase {
public:
  GradOpDescMakerBase(const OpDescBind& );
  virtual std::vector<std::unique_ptr<OpDescBind>> operator()() const = 0;
};
```

We can convert a class derived from `GradOpDescMakerBase` to `std::function<std::vector<std::unique_ptr<OpDescBind>>(const OpDescBind&)>` as follows:

```cpp
using GradOpMaker = ...;
std::function<std::vector<std::unique_ptr<OpDescBind>>(const OpDescBind&)> func;
func = [] (const OpDescBind& fwd_op) {
  GradOpMaker maker(fwd_op);
  return maker();
};
```

We can write many helper functions now that `GradOpDescMakerBase` is a class. The basic helper functions get the `Input`, `Output`, `InputGradient` and `OutputGradient` variables of the forward operator.
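
To make the pieces above concrete, here is a minimal sketch of a maker for *minus* that emits an identity operator and a scale operator, together with a helper that wraps any maker class into the `std::function` form. It rests on several assumptions: `OpDescBind` is replaced by a hypothetical stub, the helper names `Input`, `Output`, `InputGrad`, and `OutputGrad` and the `@GRAD` suffix are chosen for illustration, and the operator and attribute names (`identity`, `scale`) are placeholders rather than confirmed registrations.

```cpp
#include <functional>
#include <map>
#include <memory>
#include <string>
#include <utility>
#include <vector>

// Hypothetical stand-in for the real OpDescBind wrapper.
struct OpDescBind {
  std::string type;
  std::map<std::string, std::string> inputs;   // slot name -> variable name
  std::map<std::string, std::string> outputs;  // slot name -> variable name
  std::map<std::string, float> attrs;
};

using GradOpDescs = std::vector<std::unique_ptr<OpDescBind>>;
using GradOpMakerFn = std::function<GradOpDescs(const OpDescBind&)>;

class GradOpDescMakerBase {
 public:
  explicit GradOpDescMakerBase(const OpDescBind& fwd_op) : fwd_op_(fwd_op) {}
  virtual ~GradOpDescMakerBase() = default;
  virtual GradOpDescs operator()() const = 0;

 protected:
  // Illustrative helpers; the "@GRAD" naming is an assumption of this sketch.
  std::string Input(const std::string& slot) const {
    return fwd_op_.inputs.at(slot);
  }
  std::string Output(const std::string& slot) const {
    return fwd_op_.outputs.at(slot);
  }
  std::string InputGrad(const std::string& slot) const {
    return Input(slot) + "@GRAD";
  }
  std::string OutputGrad(const std::string& slot) const {
    return Output(slot) + "@GRAD";
  }

 private:
  const OpDescBind& fwd_op_;
};

// For Out = X - Y: dX = dOut (identity) and dY = -1 * dOut (scale).
class MinusOpGradMaker : public GradOpDescMakerBase {
 public:
  using GradOpDescMakerBase::GradOpDescMakerBase;

  GradOpDescs operator()() const override {
    GradOpDescs ops;

    std::unique_ptr<OpDescBind> identity(new OpDescBind());
    identity->type = "identity";
    identity->inputs["X"] = OutputGrad("Out");
    identity->outputs["Out"] = InputGrad("X");
    ops.push_back(std::move(identity));

    std::unique_ptr<OpDescBind> scale(new OpDescBind());
    scale->type = "scale";
    scale->inputs["X"] = OutputGrad("Out");
    scale->outputs["Out"] = InputGrad("Y");
    scale->attrs["scale"] = -1.0f;
    ops.push_back(std::move(scale));

    return ops;
  }
};

// Wraps a maker class into the std::function form stored in OpInfo.
template <typename Maker>
GradOpMakerFn MakeGradOpMakerFn() {
  return [](const OpDescBind& fwd_op) {
    Maker maker(fwd_op);
    return maker();
  };
}
```

Calling `MakeGradOpMakerFn<MinusOpGradMaker>()` yields a function that, given the forward `minus` description, returns two gradient operator descriptions.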

We should change the registration macros at the same time. In the current solution, there is no difference between forward operators and backward operators, so `REGISTER_OP` just registers one operator. If `REGISTER_OPERATOR` is given an `OpProtoAndCheckerMaker` and a `GradOpDescMaker`, we just list them in the same macro. This can be done with a macro that uses `__VA_ARGS__`.

The user interface should be

```cpp
vector<OpDesc> MinusOpGradMaker(OpDesc) {...}
REGISTER_OPERATOR(minus, MinusOp, MinusOpProtoAndCheckerMaker, MinusOpGradMaker);
// Developers can still manually implement gradient operator.
REGISTER_OPERATOR(minus_grad, MinusGradOp);
```
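
One way such a `__VA_ARGS__`-based macro could work is sketched below: the macro forwards the operator class and any optional components to a variadic template, which classifies each component by its base class. Everything here is an assumption made for illustration: the simplified `OpInfo` with two boolean flags, the empty tag base classes, `FillOpInfo`, and `OperatorRegistrar` are hypothetical and only show the dispatch idea, not the real registration code.

```cpp
#include <map>
#include <string>
#include <type_traits>

// Hypothetical empty base classes standing in for the real ones.
struct OpProtoAndCheckerMaker {};
struct GradOpDescMakerBase {};

// Simplified stand-in for OpInfo: it only records which parts were supplied.
struct OpInfo {
  bool has_proto_maker = false;
  bool has_grad_op_maker = false;
};

inline std::map<std::string, OpInfo>& OpInfoMap() {
  static std::map<std::string, OpInfo> info_map;
  return info_map;
}

// Records one optional component in OpInfo, chosen by its base class.
template <typename T>
void FillOpInfo(OpInfo* info) {
  if (std::is_base_of<OpProtoAndCheckerMaker, T>::value) {
    info->has_proto_maker = true;
  }
  if (std::is_base_of<GradOpDescMakerBase, T>::value) {
    info->has_grad_op_maker = true;
  }
}

template <typename OpClass, typename... Extras>
struct OperatorRegistrar {
  explicit OperatorRegistrar(const char* op_type) {
    OpInfo info;
    // Expand the pack: call FillOpInfo once per optional component.
    int expand[] = {0, (FillOpInfo<Extras>(&info), 0)...};
    (void)expand;
    OpInfoMap()[op_type] = info;
  }
};

// The operator class and all optional components travel in __VA_ARGS__.
#define REGISTER_OPERATOR(op_type, ...) \
  static OperatorRegistrar<__VA_ARGS__> op_registrar_##op_type(#op_type)

// Placeholder operator classes for the usage example.
struct MinusOp {};
struct MinusGradOp {};
struct MinusOpProtoAndCheckerMaker : OpProtoAndCheckerMaker {};
struct MinusOpGradMaker : GradOpDescMakerBase {};

REGISTER_OPERATOR(minus, MinusOp, MinusOpProtoAndCheckerMaker, MinusOpGradMaker);
REGISTER_OPERATOR(minus_grad, MinusGradOp);
```

Keeping the operator class inside `__VA_ARGS__` means the variadic part is never empty, so the macro stays within standard preprocessor behavior even when no optional components are listed.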

The interface of the current `REGISTER_OP` macro should remain unchanged. Internally, `REGISTER_OP` will invoke `REGISTER_OPERATOR` twice and generate the `GradOpDescMaker` itself.

```cpp
REGISTER_OP(minus, MinusOp, MinusOpProtoAndCheckerMaker, minus_grad, MinusGradOp);
```
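
The backward-compatible behavior described above could look roughly like the following sketch, in which `REGISTER_OP` performs two registrations and supplies a generated maker that returns a single gradient operator of the given type. This is a simplification under stated assumptions: `OpDescBind` is a hypothetical stub, `OperatorRegistrar` and `DefaultGradOpMaker` are illustrative names, each registrar object stands in for one `REGISTER_OPERATOR` invocation, and the operator and proto-maker class arguments are ignored.

```cpp
#include <functional>
#include <map>
#include <memory>
#include <string>
#include <utility>
#include <vector>

// Hypothetical stand-in for OpDescBind; only the operator type is modeled.
struct OpDescBind {
  std::string type;
};

using GradOpDescs = std::vector<std::unique_ptr<OpDescBind>>;
using GradOpMakerFn = std::function<GradOpDescs(const OpDescBind&)>;

struct OpInfo {
  GradOpMakerFn grad_op_maker_;
};

inline std::map<std::string, OpInfo>& OpInfoMap() {
  static std::map<std::string, OpInfo> info_map;
  return info_map;
}

// Stand-in for one REGISTER_OPERATOR invocation in this sketch.
struct OperatorRegistrar {
  OperatorRegistrar(const std::string& op_type,
                    GradOpMakerFn grad_op_maker = nullptr) {
    OpInfoMap()[op_type].grad_op_maker_ = std::move(grad_op_maker);
  }
};

// Generated maker: a single gradient operator of type grad_op_type.
inline GradOpMakerFn DefaultGradOpMaker(const std::string& grad_op_type) {
  return [grad_op_type](const OpDescBind&) {
    GradOpDescs ops;
    std::unique_ptr<OpDescBind> grad(new OpDescBind());
    grad->type = grad_op_type;
    ops.push_back(std::move(grad));
    return ops;
  };
}

// Two registrations: the forward operator (with the generated maker) and
// the gradient operator (without one). Class arguments are unused here.
#define REGISTER_OP(op_type, op_class, op_maker_class, grad_op_type,    \
                    grad_op_class)                                      \
  static OperatorRegistrar op_registrar_##op_type(                      \
      #op_type, DefaultGradOpMaker(#grad_op_type));                     \
  static OperatorRegistrar op_registrar_##grad_op_type(#grad_op_type)

REGISTER_OP(minus, MinusOp, MinusOpProtoAndCheckerMaker, minus_grad, MinusGradOp);
```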