From 12713e92e4a28fa7de762d9be76125efdace7ae2 Mon Sep 17 00:00:00 2001 From: Yu Yang Date: Fri, 29 Sep 2017 13:49:18 -0700 Subject: [PATCH] Design doc of compile time register gradient operators --- doc/design/register_grad_op.md | 63 ++++++++++++++++++++++++++++++++++ 1 file changed, 63 insertions(+) create mode 100644 doc/design/register_grad_op.md diff --git a/doc/design/register_grad_op.md b/doc/design/register_grad_op.md new file mode 100644 index 000000000..1c961a615 --- /dev/null +++ b/doc/design/register_grad_op.md @@ -0,0 +1,63 @@ +# Design Doc: Register Gradient Operator + +## Problem + +Since we separate users program in two stages, compile time and runtime, we should record and look up the mapping relationship between an operator and its gradient operators when compile. However, we register this relationship in runtime by these `OpInfo` fields. + +```cpp +struct OpInfo { + std::function creator_; + std::string grad_op_type_; + ... +}; +``` + +OpInfos store in a association map which key is the operator type. The `grad_op_type` indicate associated gradient operator type. Operator can create gradient operator by `OpInfo::creator_` of gradient. The pseudo code is + +```cpp +map OpInfoMap; + +OperatorBase* CreateGradientOperator(const OperatorBase& op) { + return OpInfoMap.at(op.Type()).creator_(...); +} +``` + +At the same time, an operator's gradient operator could be composed of many forward operators. For example, the gradient operator of `minus_op` could consist of an `identity` operator and a `scale` operator. To compose a gradient operator by forwarding operators could: 1) Reuse forwarding operator; 2) Calculate second derivative, third derivative, etc. + +We use `NetOp` to represent a composed operator since the `NetOp` is `vector`. However, `NetOp` is also a runtime concept. We should provide a mechanism to compose operators as a gradient operator. + +In conclusion, the problem that we want to resolve in this design doc is to register the mapping relationship between the forward operator and its gradient operators during compile time. + + +## Solution + +The mapping relationship between an operator and its gradient operators is a function. The interface of that function is: + +```cpp +// (OpDesc) --> vector +using GradOpDescMaker = std::function(const OpDesc&)>; +``` + +The function take a `OpDesc` of the forward operator and return one or many gradient operator descriptions. + +The `GradOpDescMaker` will be registered in `OpInfo`, to replace `grad_op_type_` field. The `OpInfo` should be + +```cpp +struct OpInfo { + GradOpDescMaker grad_op_maker_; + ... +}; +``` + +The `grad_op_maker_ ` is `nullptr` if the operator does not have associated gradient operators. + +We should chagne register macros at the same time. In the current solution, there is no difference between forwarding operators and backward operators. So `REGISTER_OP` just register one operator. If the `REGISTER_OP` contains `OpProtoAndCheckerMaker` and `GradOpDescMaker Æ’`, we just list them in the same macro. It can be done by a macro contains `__VA_ARGS__`. + +The user interface should be + +```cpp +vector SumOpGradMakerÆ’(OpDesc) {...} +REGISTER_OP(sum, SumOp, SumOpProtoAndCheckerMaker, SumOpGradMaker); +// Developers can still manually implement gradient operator. +REGISTER_OP(sum_grad, SumGradOp); +``` -- GitLab