diff --git a/doc/design/register_grad_op.md b/doc/design/register_grad_op.md index 9f1ce4bae7b393cb9f04909e5e4917b8d660771c..8d973eb53178c3e889c845144553a453e11f067c 100644 --- a/doc/design/register_grad_op.md +++ b/doc/design/register_grad_op.md @@ -3,17 +3,17 @@ ## The Problem Posed -Currently, for each C++ operator class definition, there registers a *gradient operator creator* function, which takes a C++ operator instance and returns the corresponding gradient operator instance. +Currently, for each C++ operator class definition, a *gradient operator creator* function is registered, which takes as input a C++ operator instance and returns the corresponding gradient operator instance. -However, we noticed two problems with the current deisgn: +However, we noticed two problems with the current design: -1. As we decided to separate the *compilation* and *execution* phases, we need to change the creator to take an `OpDesc` protobuf message in a `ProgramDesc` and inserts corresponding `OpDesc` messages into the `ProgramDesc` message. +1. As we decided to separate the *compilation* and the *execution* phases, we need to change the creator to take an `OpDesc` protobuf message in a `ProgramDesc` and inserts corresponding `OpDesc` messages into the `ProgramDesc` message. -1. Some operator's gradient computation requires more than one gradient operators. For example, the gradient of *minus* consists of two operators -- an identity operaotr and a scale operator. So we need to make the registration mechanism to support the mapping from an operator to a set of operators for gradient computation. +1. For some operators, the gradient computation can be written in terms of existing operators. For example, the gradient of *minus* operator consists of two operators -- an *identity* operator followed by a *scale* operator. Hence the registration mechanism needs to support mapping from an operator to a set of operators for the gradient computation. ## The Current Implementation -The C++ class `OpInfos` store in a association map which key is the operator type. The `grad_op_type` indicate associated gradient operator type. Operator can create gradient operator by `OpInfo::creator_` of gradient. The pseudo code is +Instances of the C++ class `OpInfo` are stored an associative map whose key is the operator type. The `grad_op_type` indicates the associated gradient operator type. An operator can create the gradient operator by invoking `OpInfo::creator_` of the gradient operator. The pseudo code is as follows ```cpp struct OpInfo { @@ -31,16 +31,16 @@ OperatorBase* CreateGradientOperator(const OperatorBase& op) { ## Proposed Solution -The mapping relationship between an operator and its gradient operators is a function. The interface of that function is: +The mapping relationship between an operator and its gradient operators is a function. The interface of this function is: ```cpp // (OpDesc) --> vector std::function(const OpDescBind&)>; ``` -The function takes an `OpDescBind` of the forward operator and returns one or many gradient operator descriptions. `OpDescBind` is a C++ wrapper for protobuf message `OpDesc` to manipulate `OpDesc` fast. +The function takes an `OpDescBind` of the forward operator and returns one or many gradient operator descriptions. `OpDescBind` is a C++ wrapper for the protobuf message `OpDesc` for rapid manipulation of `OpDesc`. -The `GradOpDescMaker` will be registered in `OpInfo`, to replace `grad_op_type_` field. The `OpInfo` should be +The `GradOpDescMaker` will be registered in `OpInfo` and will replace the `grad_op_type_` field. The `OpInfo` should look like ```cpp struct OpInfo { @@ -49,7 +49,7 @@ struct OpInfo { }; ``` -The `grad_op_maker_ ` is `nullptr` if the operator does not have associated gradient operators. +The `grad_op_maker_ ` is a `nullptr` if the operator does not have any associated gradient operators. We propose a base class called `GradOpDescMakerBase` to let operator developers generate `Gradient Operators` easily. The public interface of that class is @@ -74,7 +74,7 @@ func = [] (const OpDescBind& fwd_op) { We can write many helper functions since the `GradOpDescMakerBase` is a class now. The basic helper functions get the variables of `Input`, `Output`, `InputGradient` and `OutputGradient` in the forwarding operator. -We should chagne register macros at the same time. In the current solution, there is no difference between forwarding operators and backward operators. So `REGISTER_OP` just register one operator. If the `REGISTER_OPERATOR ` contains `OpProtoAndCheckerMaker` and `GradOpDescMaker`, we just list them in the same macro. It can be done by a macro contains `__VA_ARGS__`. +We should change register macros at the same time. In the current solution, there is no difference between forwarding operators and backward operators. So `REGISTER_OP` just register one operator. If the `REGISTER_OPERATOR ` contains `OpProtoAndCheckerMaker` and `GradOpDescMaker`, we just list them in the same macro. It can be done by a macro contains `__VA_ARGS__`. The user interface should be