# Design Doc: Gradient Operators Registration

Committed by Yu Yang.


## The Problem Posed

In our current operator registration mechanism, the programmer must register a *gradient operator creator* function for each operator. This function takes a C++ operator instance and returns the corresponding gradient operator instance.

However, since we decided to separate the *compilation* and *execution* of DL models, the creator needs to take a protobuf `OpDesc` message and return a corresponding message.

Moreover, the new registration mechanism needs to support the fact that an operator's gradient computation might be a composition of several operators.

## Current Implementation

`OpInfo` instances are stored in an associative map whose key is the operator type. The `grad_op_type_` field indicates the type of the associated gradient operator, and an operator creates its gradient operator through the `creator_` of the gradient operator's `OpInfo`. The pseudocode is:

```cpp
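struct OpInfo {
  std::function<OperatorBase*(...)> creator_;  // creates an operator of this type
  std::string grad_op_type_;                   // type of the associated gradient operator
  ...
};

std::map<std::string, OpInfo> OpInfoMap;  // keyed by operator type

OperatorBase* CreateGradientOperator(const OperatorBase& op) {
  // Look up the gradient type of the forward operator, then call the
  // creator registered for that gradient type.
  const std::string& grad_op_type = OpInfoMap.at(op.Type()).grad_op_type_;
  return OpInfoMap.at(grad_op_type).creator_(...);
}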
```
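To make the two-step lookup concrete, here is a small runnable sketch of the current mechanism. It assumes a hypothetical `OperatorBase` that only stores its type and creators that take no arguments (the real creator arguments are elided above); `RegisterOp` is an illustrative helper, not part of the framework.

```cpp
#include <cassert>
#include <functional>
#include <map>
#include <memory>
#include <string>
#include <utility>

// Hypothetical, minimal operator base: it only remembers its type.
struct OperatorBase {
  explicit OperatorBase(std::string type) : type_(std::move(type)) {}
  virtual ~OperatorBase() = default;
  const std::string& Type() const { return type_; }

 private:
  std::string type_;
};

struct OpInfo {
  // Assumption: creators take no arguments in this sketch.
  std::function<std::unique_ptr<OperatorBase>()> creator_;
  std::string grad_op_type_;  // empty if the operator has no gradient
};

// Keyed by operator type; a function-local static avoids init-order issues.
std::map<std::string, OpInfo>& OpInfoMap() {
  static std::map<std::string, OpInfo> map;
  return map;
}

// Illustrative helper: register an operator type and its gradient type.
void RegisterOp(const std::string& type, const std::string& grad_op_type) {
  OpInfo info{[type] { return std::make_unique<OperatorBase>(type); },
              grad_op_type};
  bool inserted = OpInfoMap().emplace(type, std::move(info)).second;
  assert(inserted && "operator type registered twice");
  (void)inserted;  // silence unused-variable warnings when NDEBUG is set
}

std::unique_ptr<OperatorBase> CreateGradientOperator(const OperatorBase& op) {
  // Step 1: forward type -> gradient type. Step 2: gradient type -> creator.
  const std::string& grad_op_type = OpInfoMap().at(op.Type()).grad_op_type_;
  return OpInfoMap().at(grad_op_type).creator_();
}
```

Note that each gradient operator type needs its own `OpInfo` entry, and this design can only ever yield a single gradient operator per forward operator.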

## Proposed Solution

The mapping from an operator to its gradient operators is a function with the following interface:

```cpp
// (OpDesc) --> vector<OpDesc>
using GradOpDescMaker = std::function<std::vector<OpDesc>(const OpDesc&)>;
```

The function takes the `OpDesc` of a forward operator and returns one or more gradient operator descriptions.
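As an illustration of a gradient that is a composition of operators, the sketch below uses a hypothetical, simplified `OpDesc` struct in place of the protobuf message. Assuming `minus` computes `Out = X - Y`, the gradients are `dX = dOut` and `dY = -dOut`, so the maker emits two descriptions of an assumed `scale` operator. The `@GRAD` suffix, the `X`/`Y`/`Out` slot names, and the `scale` operator are assumptions made for this example.

```cpp
#include <cassert>
#include <functional>
#include <map>
#include <string>
#include <vector>

// Hypothetical stand-in for the protobuf OpDesc message.
struct OpDesc {
  std::string type;
  std::map<std::string, std::vector<std::string>> inputs;   // slot -> vars
  std::map<std::string, std::vector<std::string>> outputs;  // slot -> vars
  std::map<std::string, float> attrs;
};

// (OpDesc) --> vector<OpDesc>
using GradOpDescMaker = std::function<std::vector<OpDesc>(const OpDesc&)>;

// Assumed naming convention for gradient variables.
std::string GradVarName(const std::string& var) { return var + "@GRAD"; }

// Builds one description of an assumed "scale" operator: Out = scale * X.
OpDesc MakeScaleOp(const std::string& in, const std::string& out, float scale) {
  OpDesc op;
  op.type = "scale";
  op.inputs["X"] = {in};
  op.outputs["Out"] = {out};
  op.attrs["scale"] = scale;
  return op;
}

// For Out = X - Y: dX = dOut and dY = -dOut, i.e. two gradient operators.
std::vector<OpDesc> MinusOpGradMaker(const OpDesc& fwd) {
  assert(fwd.type == "minus");
  const std::string d_out = GradVarName(fwd.outputs.at("Out").at(0));
  return {MakeScaleOp(d_out, GradVarName(fwd.inputs.at("X").at(0)), 1.0f),
          MakeScaleOp(d_out, GradVarName(fwd.inputs.at("Y").at(0)), -1.0f)};
}
```

Because `MinusOpGradMaker` matches the `GradOpDescMaker` signature, it can be stored directly, e.g. `GradOpDescMaker maker = MinusOpGradMaker;`.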

The `GradOpDescMaker` will be registered in `OpInfo`, replacing the `grad_op_type_` field. `OpInfo` then becomes:

```cpp
struct OpInfo {
  GradOpDescMaker grad_op_maker_;
  ...
};
```

`grad_op_maker_` is `nullptr` if the operator has no associated gradient operators.
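A registry built on this `OpInfo` can then collect the backward operators for a whole forward program, skipping operators whose `grad_op_maker_` is empty. The sketch below is a hypothetical illustration: `OpDesc` is reduced to its type, and `RegisterOpInfo` and `BuildBackward` are invented helper names, not framework APIs.

```cpp
#include <cassert>
#include <functional>
#include <map>
#include <string>
#include <utility>
#include <vector>

// Hypothetical, simplified OpDesc: only the operator type is kept.
struct OpDesc {
  std::string type;
};

using GradOpDescMaker = std::function<std::vector<OpDesc>(const OpDesc&)>;

struct OpInfo {
  GradOpDescMaker grad_op_maker_;  // empty (compares equal to nullptr) if none
};

std::map<std::string, OpInfo>& OpInfoMap() {
  static std::map<std::string, OpInfo> map;  // keyed by operator type
  return map;
}

void RegisterOpInfo(const std::string& type, GradOpDescMaker maker) {
  bool inserted = OpInfoMap().emplace(type, OpInfo{std::move(maker)}).second;
  assert(inserted && "operator type registered twice");
  (void)inserted;
}

// Visits forward operators in reverse order and concatenates the gradient
// operators each maker returns; a maker may return several OpDescs.
std::vector<OpDesc> BuildBackward(const std::vector<OpDesc>& forward) {
  std::vector<OpDesc> backward;
  for (auto it = forward.rbegin(); it != forward.rend(); ++it) {
    const OpInfo& info = OpInfoMap().at(it->type);
    if (info.grad_op_maker_ == nullptr) continue;  // no gradient operators
    std::vector<OpDesc> grads = info.grad_op_maker_(*it);
    backward.insert(backward.end(), grads.begin(), grads.end());
  }
  return backward;
}
```

Since the maker returns a vector, an operator whose gradient is a composition simply contributes several entries, and the caller does not need to distinguish the two cases.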

The registration macros should be changed at the same time. In this solution, there is no difference between forward operators and backward operators, so `REGISTER_OPERATOR` registers exactly one operator. If an operator also has an `OpProtoAndCheckerMaker` and a `GradOpDescMaker`, they are simply listed in the same `REGISTER_OPERATOR` invocation, which can be done with a variadic macro using `__VA_ARGS__`.

The user interface should be:

```cpp
std::vector<OpDesc> MinusOpGradMaker(const OpDesc&) {...}
REGISTER_OPERATOR(minus, MinusOp, MinusOpProtoAndCheckerMaker, MinusOpGradMaker);
// Developers can still implement a gradient operator manually.
REGISTER_OPERATOR(minus_grad, MinusGradOp);
```
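One way to let a single variadic macro accept an operator class plus optional makers is to forward `__VA_ARGS__` as template arguments and dispatch on each type's base class. The sketch below assumes C++17 and, unlike the interface above, assumes the gradient maker is a functor class rather than a plain function, because types and functions cannot be mixed in one template argument list. `OperatorRegistrar`, `GradOpDescMakerBase`, and the simplified types are hypothetical.

```cpp
#include <cassert>
#include <functional>
#include <map>
#include <string>
#include <type_traits>
#include <vector>

// Hypothetical, simplified stand-ins for the framework types.
struct OpDesc {
  std::string type;
};
struct OperatorBase {};
struct OpProtoAndCheckerMaker {};
struct GradOpDescMakerBase {};  // assumed tag base for functor grad makers

struct OpInfo {
  bool has_proto_ = false;  // true if an OpProtoAndCheckerMaker was listed
  std::function<std::vector<OpDesc>(const OpDesc&)> grad_op_maker_;
};

std::map<std::string, OpInfo>& OpInfoMap() {
  static std::map<std::string, OpInfo> map;  // keyed by operator type
  return map;
}

// The first type is the operator class; every following type is inspected
// and used to fill the matching OpInfo field.
template <typename OpClass, typename... Extra>
struct OperatorRegistrar {
  static_assert(std::is_base_of_v<OperatorBase, OpClass>,
                "the first argument must be an operator class");

  explicit OperatorRegistrar(const char* op_type) {
    OpInfo info;
    (Fill<Extra>(&info), ...);  // C++17 fold expression over Extra
    bool inserted = OpInfoMap().emplace(op_type, info).second;
    assert(inserted && "operator type registered twice");
    (void)inserted;
  }

  template <typename T>
  static void Fill(OpInfo* info) {
    if constexpr (std::is_base_of_v<OpProtoAndCheckerMaker, T>) {
      info->has_proto_ = true;
    }
    if constexpr (std::is_base_of_v<GradOpDescMakerBase, T>) {
      info->grad_op_maker_ = T();
    }
  }
};

#define REGISTER_OPERATOR(op_type, ...) \
  static OperatorRegistrar<__VA_ARGS__> registrar_##op_type(#op_type)

// Usage with illustrative classes.
struct MinusOp : OperatorBase {};
struct MinusGradOp : OperatorBase {};
struct MinusOpProtoAndCheckerMaker : OpProtoAndCheckerMaker {};
struct MinusOpGradMaker : GradOpDescMakerBase {
  std::vector<OpDesc> operator()(const OpDesc&) const {
    return {OpDesc{"minus_grad"}};
  }
};

REGISTER_OPERATOR(minus, MinusOp, MinusOpProtoAndCheckerMaker, MinusOpGradMaker);
REGISTER_OPERATOR(minus_grad, MinusGradOp);
```

Passing the makers as template arguments keeps the macro a one-liner; the trade-off is that every maker has to be expressed as a type.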

The interface of the current `REGISTER_OP` macro must remain unchanged. Instead, `REGISTER_OP` will invoke `REGISTER_OPERATOR` twice and generate a `GradOpDescMaker` internally.

```cpp
REGISTER_OP(minus, MinusOp, MinusOpProtoAndCheckerMaker, minus_grad, MinusGradOp);
```
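The following sketch shows how `REGISTER_OP` could keep its five-argument interface while expanding into two `REGISTER_OPERATOR` calls: the macro generates a functor that emits a single `OpDesc` of the named gradient type, then registers the forward and gradient operators separately. It assumes C++17; the compact registrar, the `GradOpDescMakerBase` tag, and the generated `<op_type>_GradOpDescMaker` name are all hypothetical, and a real generated maker would also copy inputs, outputs, and attributes.

```cpp
#include <cassert>
#include <functional>
#include <map>
#include <string>
#include <type_traits>
#include <vector>

// Hypothetical, simplified stand-ins for the framework types.
struct OpDesc {
  std::string type;
};
struct OpProtoAndCheckerMaker {};
struct GradOpDescMakerBase {};  // assumed tag base for functor grad makers

struct OpInfo {
  bool has_proto_ = false;
  std::function<std::vector<OpDesc>(const OpDesc&)> grad_op_maker_;
};

std::map<std::string, OpInfo>& OpInfoMap() {
  static std::map<std::string, OpInfo> map;  // keyed by operator type
  return map;
}

// Compact registrar: OpClass is the operator; Extra types fill OpInfo.
template <typename OpClass, typename... Extra>
struct OperatorRegistrar {
  explicit OperatorRegistrar(const char* op_type) {
    OpInfo info;
    (Fill<Extra>(&info), ...);
    bool inserted = OpInfoMap().emplace(op_type, info).second;
    assert(inserted && "operator type registered twice");
    (void)inserted;
  }
  template <typename T>
  static void Fill(OpInfo* info) {
    if constexpr (std::is_base_of_v<OpProtoAndCheckerMaker, T>) {
      info->has_proto_ = true;
    }
    if constexpr (std::is_base_of_v<GradOpDescMakerBase, T>) {
      info->grad_op_maker_ = T();
    }
  }
};

#define REGISTER_OPERATOR(op_type, ...) \
  static OperatorRegistrar<__VA_ARGS__> registrar_##op_type(#op_type)

// Legacy interface: generate a maker that emits one `grad_op_type` OpDesc,
// then register the forward and the gradient operator separately.
#define REGISTER_OP(op_type, op_class, proto_maker, grad_op_type,  \
                    grad_op_class)                                 \
  struct op_type##_GradOpDescMaker : GradOpDescMakerBase {         \
    std::vector<OpDesc> operator()(const OpDesc&) const {          \
      return {OpDesc{#grad_op_type}};                              \
    }                                                              \
  };                                                               \
  REGISTER_OPERATOR(op_type, op_class, proto_maker,                \
                    op_type##_GradOpDescMaker);                    \
  REGISTER_OPERATOR(grad_op_type, grad_op_class)

// Usage with illustrative classes; the call mirrors the one above.
struct MinusOp {};
struct MinusGradOp {};
struct MinusOpProtoAndCheckerMaker : OpProtoAndCheckerMaker {};

REGISTER_OP(minus, MinusOp, MinusOpProtoAndCheckerMaker, minus_grad, MinusGradOp);
```

Existing `REGISTER_OP` call sites keep compiling unchanged, while the registry underneath only ever sees `REGISTER_OPERATOR` registrations.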