diff --git a/doc/design/backward.md b/doc/design/backward.md
new file mode 100644
index 0000000000000000000000000000000000000000..35f03692bb052e0a04db18d28f6f8d901215a553
--- /dev/null
+++ b/doc/design/backward.md
@@ -0,0 +1,156 @@
+# Backward Building
+
+## Motivation
+
+In neural networks, most models are solved by the backpropagation algorithm (known as **BP**) at present. Technically, BP calculates the gradient of the loss function, then propagates it back through the network following the chain rule. However, when configuring the model structure, users do not need to define the backward part. So the framework needs a mechanism to complete the model's backward part automatically according to the given forward part.
+
+When implementing a specific `op`, the developer is also asked to implement its backward version, called `grad_op`. A `grad_op` takes gradients of its corresponding `op`'s outputs, and calculates gradients of the `op`'s inputs. During the building of a model's backward part, the framework creates each forward `op`'s `grad_op`, and then strings them together in the reverse order of the forward part. In this way, gradients spread from the end to the beginning of the model, in other words, from the loss to the parameters.
+
+## Challenges
+
+The motivation of backward building is apparent. However, implementing it correctly is not so easy. In the **Fluid** design, a deep learning model is described by `Program`, `Block`, `Op` and `Variable`. The `Block` itself can be nested. It means that the `op`s and `variable`s are scattered across different blocks rather than being gathered in a single graph. Our backward building algorithm shall visit blocks recursively and be able to insert `grad_op`s and newly created `variable`s into the right places.
+
+## Usage
+
+Although the whole algorithm is comprised of many functions, only one is exposed as the API:
+
+```python
+def append_backward(loss, parameter_list=None, no_grad_set=None):
+    """
+    Append the backward part to main_program
+
+    Args:
+        loss(Variable): The variable generated by the cost function.
+        parameter_list(list): Parameters that need to be updated by optimizers.
+            If None, it means all parameters need to be updated.
+
+        no_grad_set(set): Variables that have no gradients in Block 0.
+            If None, the set will be generated inside the function and
+            contains all variables with `stop_gradient=True` from all blocks.
+
+    Return:
+        (list[Variable]): list of (parameters, gradients) pairs.
+    """
+```
+
+By invoking this API, the framework appends the backward part to the program where the `loss` is. It takes three arguments. `loss` means the final loss value. It must be a scalar and is usually the output of the loss layer. It is also where the gradient is generated and where backpropagation starts. `parameter_list` marks all parameters that need updating. If it is `None`, all parameters will be updated by optimizers. `no_grad_set` marks variables without gradients. If all outputs of some `grad_op` are in `no_grad_set`, the `grad_op` will not be run.
+
+This API will be invoked automatically before optimizer building.
+As a result, in most cases, users do not need to invoke the API by themselves to append the backward part.
+
+## Implementation
+
+The backward building algorithm is implemented in the `backward.py` file. The whole algorithm can be divided into two independent parts: the creation of `grad_op`s and the creation of new variables.
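+
+The following sketch shows how these two parts could be glued together inside `append_backward()`. It is a simplified illustration rather than the real implementation: only the function names and their arguments are taken from this document; the glue code and the way the root block is obtained are assumptions.
+
+```python
+# Illustrative sketch only; not the actual backward.py implementation.
+def append_backward(loss, parameter_list=None, no_grad_set=None):
+    root_block = loss.block.program.block(0)   # assumed way to reach Block 0
+    fwd_op_num = len(root_block.ops)           # grad_ops will be appended after this index
+
+    no_grad_dict = {0: no_grad_set or set()}   # block index -> no-gradient variables
+    grad_to_var = {}                           # filled in by _append_backward_ops_()
+
+    # Part 1: create grad_ops, visiting the forward ops in reverse order.
+    _append_backward_ops_(loss, root_block, root_block, no_grad_dict, grad_to_var)
+
+    # Part 2: create the backward variables referenced by the new grad_ops.
+    grad_info_map = {}
+    _append_backward_vars_(root_block, fwd_op_num, grad_to_var, grad_info_map)
+
+    # Finally, pair each parameter in parameter_list with its gradient variable.
+    ...
+```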
+
+### Creating `grad_op`s
+
+The creation of `grad_op`s is implemented by:
+
+```python
+def _append_backward_ops_(target,
+                          block,
+                          target_block,
+                          no_grad_dict,
+                          grad_to_var):
+    """
+    Create all grad ops, and insert them into the given block
+
+    Args:
+        target(Variable): the target variable of the forward pass
+        block(Block): the block where forward ops are
+        target_block(Block): the block which is going to hold newly generated grad ops
+        no_grad_dict(dict):
+            key(int) block index
+            val(set) a set of variable names. These variables have no gradient
+        grad_to_var(dict)(output argument):
+            key(str): grad variable name
+            val(str): corresponding forward variable name
+    """
+```
+
+Given a `block`, the function traverses all `op`s in this block in reverse order, gets the corresponding `grad_op` from the C++ core via `core.get_grad_op_desc()`, and then appends it to `target_block`.
+
+However, some specific `op`s (e.g. `while_op`, `if_else_op`) can hold their own sub-blocks. Since these sub-blocks contain `op`s as well, the `grad_op` creation should be recursive.
+
+During the reverse traversal, we check whether each `op` has an attribute named `sub_block`. If so, it means there is a sub-block and we need to deal with it first. After creating a new block whose father is the one in the `op`'s attribute, we invoke `_append_backward_ops_()` recursively, assigning the new block to the parameter `target_block` and the one in the `op`'s attribute to `block`. The *pseudo-code* shows this process:
+
+```
+******* pseudo-code ********
+for op in reversed(block.ops):
+    if op has an attribute named 'sub_block':
+        Get the sub-block(`s_block`) from op's attribute.
+        Create a new block(`grad_s_block`), whose father is `s_block`.
+        Invoke _append_backward_ops_(), with `block=s_block` and `target_block=grad_s_block`
+
+    Invoke `core.get_grad_op_desc()` to get op's grad_op.
+    Insert the name correspondences between the grad_op's variables and their gradients into grad_to_var.
+    Assign grad_s_block to grad_op as its 'sub_block' attribute.
+    Append grad_op to the current target_block.
+```
+
+The first invocation of `_append_backward_ops_()` is initiated by `append_backward()`, in which the parameters `block` and `target_block` are both assigned the root block (the block with index 0).
+
+### Corner Cases of `grad_op` Creating
+
+In the previous section, we showed the regular process of `grad_op` creation. However, in some corner cases, the conventional algorithm is not enough to get the correct result and additional handling is required. These additional processes run after the algorithm mentioned above and make some special adjustments to its output `grad_op`s.
+
+#### Shared Variables
+
+If a variable is read by more than one `op` in the forward pass, its gradient is likely to be written by more than one `grad_op` in the backward pass. To make the gradient result the sum of all `grad_op`s' outputs instead of only the last one written, we assign each output a temporary variable and then add a `sum_op` to add them up.
+
+For debugging convenience, if the final gradient name is `w@GRAD`, its corresponding temporary variables will be named `w@GRAD@RENAME@0`, `w@GRAD@RENAME@1`...
+
+See the function `_addup_repetitive_outputs_` in `backward.py` for implementation details.
+
+#### No Gradient Variables
+
+In our framework, variables can be marked as *no_gradient*, which means that the gradient of the variable is unnecessary and can be considered zero in model training.
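+
+For illustration, a variable might be marked like this (a hypothetical sketch: `label` stands for any forward-pass variable, and the attribute name `stop_gradient` is taken from the `append_backward` docstring above):
+
+```python
+# Hypothetical sketch: `label` is a forward-pass variable whose gradient is
+# not needed during training.
+label.stop_gradient = True   # append_backward() collects such variables into no_grad_set
+```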
+Obviously, when all the outputs of some `grad_op` are marked as *no_gradient*, the `grad_op` itself can be skipped in the backward pass.
+
+But these unnecessary gradients still need to be created and initialized by something; otherwise, the following `grad_op`s that take these gradients as inputs risk using uninitialized memory. In our code, we employ `fill_zeros_like_op` to initialize them with all zeros.
+
+This feature is implemented in the function `_remove_no_grad_branch_`. It checks newly created `grad_op`s one by one, removes those whose outputs are all in `no_grad_set`, and inserts `fill_zeros_like_op` when necessary. We can get the `no_grad_set` from the `_append_backward_ops_` argument `no_grad_dict` or generate it on the fly by scanning all variables' `no_gradient` attribute (True or False).
+
+### Creating Backward Variables
+
+Up to now, we have completed all the creation and adjustment of `grad_op`s. However, backward variables have not been created yet. Now they are only represented by the `grad_op`s' input and output arguments. The backward variable creation will be done by:
+
+```python
+def _append_backward_vars_(block,
+                           start_op_idx,
+                           grad_to_var,
+                           grad_info_map):
+    """
+    Create new variables required by the backward pass.
+
+    Args:
+        block(Block): the block where new variables will be created
+        start_op_idx(int): Only variables required by ops in block.ops[start_op_idx : ] will be created
+        grad_to_var(dict):
+            key(str): grad variable name
+            val(str): corresponding forward variable name
+            In most cases, this dict is generated by _append_backward_ops_()
+        grad_info_map(dict)(output argument):
+            key(str): forward variable name
+            val(tuple): a tuple of (str, int), str is the corresponding grad name, int is the block index
+    """
+```
+
+Given a `block`, this function traverses all the `grad_op`s in it (the argument `start_op_idx` indicates where the grad_op sequence starts) and creates all the uncreated outputs. The *pseudo-code* shows this process:
+
+```
+for op in block.ops[start_op_idx : ]:
+
+    if op has an attribute named 'sub_block':
+        Get the sub-block(`s_block`) from op's attribute.
+        Invoke _append_backward_vars_(), with `block=s_block`
+
+    for var_name in op.all_output_names():
+        if block.has_var_recursive(var_name) or var_name is the name of empty variable:
+            continue
+        create a new variable named 'var_name' in block
+        if grad_to_var.has_key(var_name):
+            set grad_info_map[grad_to_var[var_name]] as a tuple of (var_name,
block) + + do op's var type inference + do op's shape inference +``` diff --git a/paddle/framework/images/duplicate_op.graffle b/doc/design/images/duplicate_op.graffle similarity index 100% rename from paddle/framework/images/duplicate_op.graffle rename to doc/design/images/duplicate_op.graffle diff --git a/paddle/framework/images/duplicate_op.png b/doc/design/images/duplicate_op.png similarity index 100% rename from paddle/framework/images/duplicate_op.png rename to doc/design/images/duplicate_op.png diff --git a/paddle/framework/images/duplicate_op2.graffle b/doc/design/images/duplicate_op2.graffle similarity index 100% rename from paddle/framework/images/duplicate_op2.graffle rename to doc/design/images/duplicate_op2.graffle diff --git a/paddle/framework/images/duplicate_op2.png b/doc/design/images/duplicate_op2.png similarity index 100% rename from paddle/framework/images/duplicate_op2.png rename to doc/design/images/duplicate_op2.png diff --git a/paddle/framework/backward.md b/paddle/framework/backward.md deleted file mode 100644 index ac60be572419b62f4beb644ff192d413c35e19bb..0000000000000000000000000000000000000000 --- a/paddle/framework/backward.md +++ /dev/null @@ -1,100 +0,0 @@ -# Operator/expression 's Backward - -## Motivation - -In Neural Network, most models are solved by the backpropagation algorithm(known as **BP**) at present. Technically, BP calculates the gradient of the loss function, then propagates it back through the networks following the chain rule. Hence we need a module that chains the gradient operators/expressions together to construct the backward pass. Every forward network needs a backward network to construct the full computation graph. The operator/expression's backward pass will be generated with respect to the forward pass. - -## Implementation - -In this design doc, we exported only one API for generating the backward pass. - -```c++ -std::unique_ptr Backward(const OperatorBase& forwardOp, - const std::unordered_set& no_grad_vars); -``` - -The implementation behind it can be divided into two parts, **Backward Operator Creating** and **Backward Operator Building**. - -### Backward Operator Registry - -A backward network is built up with several backward operators. Backward operators take forward operators' inputs, outputs, and output gradients and then calculate its input gradients. - -| | forward operator | backward operator -| ---------------------- | ---------------- |------------------------- | -| **Operator::inputs_** | Inputs | Inputs, Outputs, OutputGradients | -| **Operator::outputs_** | Outputs | InputGradients | - - In most cases, there is a one-to-one relation between the forward and backward operators. These relations are recorded by a global hash map(`OpInfoMap`). To follow the philosophy of minimum core and to make operators pluggable, the registry mechanism is introduced. - -For example, we have `mul_op`, and we can register its information and corresponding backward operator by the following macro: - -```cpp -REGISTER_OP(mul, MulOp, MulOpMaker, mul_grad, MulOpGrad); -``` - -`mul` is the operator's type. `MulOp` and `MulOpMaker` are the operator class and the operator maker class respectively. - -`mul_grad` is the type of backward operator, and `MulOpGrad` is its class name. 
- -### Backward Opeartor Creating - -Given a certain forward operator, we can get its corresponding backward operator by calling: - -```cpp -OperatorBase* bwd_op = BuildGradOp(const OperatorBase* fwd_op); -``` - -The function `BuildGradOp` will sequentially execute following processes: - -1. Get the `type_` of given forward operator, and then get the corresponding backward operator's type by looking up the `OpInfoMap`. - -2. Build two maps named `inputs` and `outputs` to temporarily store backward operator's inputs and outputs. Copy forward operator's `inputs_` and `outputs_` to map `inputs`, except these, are not necessary for gradient computing. - -3. Add forward inputs' gradient variables into map `output`, adding forward outputs' gradient variables into map `input`. - -4. Building backward operator with `inputs`, `outputs` and forward operator's attributes. - -### Backward Network Building - -A backward network is a series of backward operators. The main idea of building a backward network is creating backward operators in the inverted sequence and appending them together one by one. There are some corner cases that need special processing. - -1. Op - - When the input forward network is an Op, return its gradient Operator immediately. If all of its outputs are in no gradient set, then return a special `NOP`. - -2. NetOp - - In our design, the network itself is also a kind of operator(**NetOp**). So the operators contained by a big network may be some small network. When the input forward network is a NetOp, it needs to call the sub NetOp/Operators backward function recursively. During the process, we need to collect the `OutputGradients` name according to the forward NetOp. - -3. RnnOp - - RnnOp is a nested stepnet operator. Backward module needs to recusively call `Backward` for every stepnet. - -4. Sharing Variables - - As illustrated in the figure 1 and figure 2, two operators share the same variable name **W@GRAD**, which will overwrite their shared input variable. - -

-
- -​ Figure 1. Sharing variables in operators. - -

- -​ Sharing variable between operators or same input variable used in multiple operators can lead to duplicate gradient variables. As illustrated in figure 2, we need to rename the gradient names recursively and add a generic add operator to prevent overwriting. - -

-
- -​ Figure 2. Replace sharing variable's gradient with `Add` operator. - -

- -​ Because the framework finds variables according to their names, we need to rename the output links. We add an integer suffix to represent its position in the clockwise direction. - -5. Part of the Gradient is Zero. - - In the whole graph, there is some case of that one operator's gradient is not needed, but its input's gradient is a dependency link of other operator, we need to fill a same shape gradient matrix in the position. In our implementation, we insert a special `fillZeroLike` operator. - - -Follow these rules above, then collect the sub graph `OutputGradients`/`InputGradients` as the NetOp's and return it. diff --git a/paddle/operators/CMakeLists.txt b/paddle/operators/CMakeLists.txt index 0c47a71f73a8e46af8f64964af1719516dcf6333..bfcc70b31defd378fba0e505b1522471ca285ac4 100644 --- a/paddle/operators/CMakeLists.txt +++ b/paddle/operators/CMakeLists.txt @@ -259,6 +259,7 @@ op_library(lstm_op DEPS sequence2batch lstm_compute) op_library(conv_transpose_op DEPS vol2col) op_library(gru_op DEPS sequence2batch gru_compute) op_library(recurrent_op DEPS executor) +op_library(cos_sim_op DEPS cos_sim_functor) # FIXME(typhoonzero): save/load depends lodtensor serialization functions op_library(save_op DEPS lod_tensor) op_library(load_op DEPS lod_tensor) diff --git a/paddle/operators/adagrad_op.cc b/paddle/operators/adagrad_op.cc index 052c793a01907abdc7784d1290f43543ae81bdb1..c83318a272302a474c37ce86365201acf56b9cad 100644 --- a/paddle/operators/adagrad_op.cc +++ b/paddle/operators/adagrad_op.cc @@ -105,48 +105,18 @@ struct SparseAdagradFunctor { const framework::Tensor& learning_rate, T epsilon, framework::Tensor* moment, framework::Tensor* param) { // 1. g_m.rows = set(g.rows) - auto grad_rows = grad.rows(); - std::set row_set(grad_rows.begin(), grad_rows.end()); - std::vector merge_rows(row_set.begin(), row_set.end()); - auto grad_width = grad.value().dims()[1]; - std::unique_ptr grad_merge{ - new framework::SelectedRows()}; - grad_merge->set_rows(merge_rows); - grad_merge->set_height(grad.height()); - grad_merge->mutable_value()->mutable_data( - framework::make_ddim( - {static_cast(merge_rows.size()), grad_width}), - context.GetPlace()); - - math::SetConstant constant_functor; - constant_functor(context, grad_merge->mutable_value(), 0.0); - - auto* grad_merge_data = grad_merge->mutable_value()->data(); - auto* grad_data = grad.value().data(); - - for (size_t i = 0; i < grad_rows.size(); i++) { - size_t grad_merge_i = FindPos(merge_rows, grad_rows[i]); - for (int64_t j = 0; j < grad_width; j++) { - grad_merge_data[grad_merge_i * grad_width + j] += - grad_data[i * grad_width + j]; - } - } + math::scatter::MergeAdd merge_func; + auto grad_merge = merge_func(context, grad); + auto& merge_rows = grad_merge.rows(); + auto* grad_merge_data = grad_merge.mutable_value()->template data(); // 2. m += g_m * g_m - std::unique_ptr grad_square{ - new framework::SelectedRows()}; - grad_square->set_rows(grad_merge->rows()); - grad_square->set_height(grad_merge->height()); - grad_square->mutable_value()->mutable_data(grad_merge->value().dims(), - context.GetPlace()); - auto gs = - framework::EigenVector::Flatten(*(grad_square->mutable_value())); - auto gm = framework::EigenVector::Flatten(grad_merge->value()); - gs.device(*context.eigen_device()) = gm * gm; + math::scatter::Mul sqare_func; + auto grad_square = sqare_func(context, grad_merge, grad_merge); math::SelectedRowsAddToTensor functor; - functor(context, *grad_square, moment); + functor(context, grad_square, moment); // 3. 
update parameter auto* lr = learning_rate.data(); diff --git a/paddle/operators/adagrad_op.cu b/paddle/operators/adagrad_op.cu index 75bc7affd6c78beb783e01682b4538f2c259df26..4e579387924a5b0499f29609bc6b1322030a3c0d 100644 --- a/paddle/operators/adagrad_op.cu +++ b/paddle/operators/adagrad_op.cu @@ -78,62 +78,30 @@ struct SparseAdagradFunctor { const framework::Tensor& learning_rate, T epsilon, framework::Tensor* moment, framework::Tensor* param) { // 1. g_m.rows = set(g.rows) - auto grad_rows = grad.rows(); - std::set row_set(grad_rows.begin(), grad_rows.end()); - std::vector merge_rows(row_set.begin(), row_set.end()); - auto grad_width = grad.value().dims()[1]; - std::unique_ptr grad_merge{ - new framework::SelectedRows()}; - grad_merge->set_rows(merge_rows); - grad_merge->set_height(grad.height()); - grad_merge->mutable_value()->mutable_data( - framework::make_ddim( - {static_cast(merge_rows.size()), grad_width}), - context.GetPlace()); - - math::SetConstant constant_functor; - constant_functor(context, grad_merge->mutable_value(), 0.0); - - auto* grad_merge_data = grad_merge->mutable_value()->data(); - auto* grad_data = grad.value().data(); - - const int block_size = 256; - dim3 threads(block_size, 1); - dim3 grid1(1, grad_rows.size()); - - MergeGradKernel< - T, 256><<(context) - .stream()>>>(grad_data, grad.rows().data(), - grad_merge_data, grad_merge->rows().data(), - grad_merge->rows().size(), grad_width); - + math::scatter::MergeAdd merge_func; + auto grad_merge = merge_func(context, grad); + auto* grad_merge_data = grad_merge.mutable_value()->template data(); + auto& merge_rows = grad_merge.rows(); // 2. m += g_m * g_m - std::unique_ptr grad_square{ - new framework::SelectedRows()}; - grad_square->set_rows(grad_merge->rows()); - grad_square->set_height(grad_merge->height()); - grad_square->mutable_value()->mutable_data(grad_merge->value().dims(), - context.GetPlace()); - auto gs = - framework::EigenVector::Flatten(*(grad_square->mutable_value())); - auto gm = framework::EigenVector::Flatten(grad_merge->value()); - gs.device(*context.eigen_device()) = gm * gm; + math::scatter::Mul sqare_func; + auto grad_square = sqare_func(context, grad_merge, grad_merge); math::SelectedRowsAddToTensor functor; - functor(context, *grad_square, moment); + functor(context, grad_square, moment); // 3. update parameter auto* lr = learning_rate.data(); auto* param_data = param->data(); auto* moment_data = moment->data(); + const int block_size = 256; + dim3 threads(block_size, 1); dim3 grid2(1, merge_rows.size()); SparseAdagradFunctorKernel< T, 256><<(context) - .stream()>>>(grad_merge_data, grad_merge->rows().data(), + .stream()>>>(grad_merge_data, grad_merge.rows().data(), lr, param_data, moment_data, grad_width, epsilon); } diff --git a/paddle/operators/adam_op.h b/paddle/operators/adam_op.h index c4e2c8bb88ec9c74bd782570c10fb217178c8e48..9cc34bdded780e61e8700eb4fa4a295c84fb48bc 100644 --- a/paddle/operators/adam_op.h +++ b/paddle/operators/adam_op.h @@ -16,11 +16,14 @@ limitations under the License. 
*/ #include // for sqrt in CPU and CUDA #include "paddle/framework/op_registry.h" #include "paddle/operators/detail/safe_ref.h" +#include "paddle/operators/math/selected_rows_functor.h" #include "paddle/platform/for_range.h" namespace paddle { namespace operators { +namespace scatter = paddle::operators::math::scatter; + template struct AdamFunctor { T beta1_; @@ -79,6 +82,69 @@ struct AdamFunctor { } }; +template +struct SparseAdamFunctor { + T beta1_; + T beta2_; + T epsilon_; + + const T* beta1_pow_; + const T* beta2_pow_; + const T* moment1_; + T* moment1_out_; + const T* moment2_; + T* moment2_out_; + const T* lr_; + const T* grad_; + const T* param_; + T* param_out_; + + const int64_t* rows_; + int64_t row_numel_; + + SparseAdamFunctor(T beta1, T beta2, T epsilon, const T* beta1_pow, + const T* beta2_pow, const T* mom1, T* mom1_out, + const T* mom2, T* mom2_out, const T* lr, const T* grad, + const T* param, T* param_out, const int64_t* rows, + int64_t row_numel) + : beta1_(beta1), + beta2_(beta2), + epsilon_(epsilon), + beta1_pow_(beta1_pow), + beta2_pow_(beta2_pow), + moment1_(mom1), + moment1_out_(mom1_out), + moment2_(mom2), + moment2_out_(mom2_out), + lr_(lr), + grad_(grad), + param_(param), + param_out_(param_out), + rows_(rows), + row_numel_(row_numel) {} + + inline HOSTDEVICE void operator()(size_t i) const { + T beta1_pow = *beta1_pow_; + T beta2_pow = *beta2_pow_; + for (int64_t j = 0; j < row_numel_; ++j) { + T g = grad_[i * row_numel_ + j]; + T mom1 = moment1_[rows_[i] * row_numel_ + j]; + T mom2 = moment2_[rows_[i] * row_numel_ + j]; + T lr = *lr_; + T p = param_[rows_[i] * row_numel_ + j]; + + lr *= sqrt(1 - beta2_pow) / (1 - beta1_pow); + mom1 = beta1_ * mom1 + (1 - beta1_) * g; + mom2 = beta2_ * mom2 + (1 - beta2_) * g * g; + p -= lr * (mom1 / (sqrt(mom2) + epsilon_)); + + moment1_out_[rows_[i] * row_numel_ + j] = mom1; + moment2_out_[rows_[i] * row_numel_ + j] = mom2; + param_out_[rows_[i] * row_numel_ + j] = p; + } // for col id + } +}; + template class AdamOpKernel : public framework::OpKernel { public: @@ -90,7 +156,8 @@ class AdamOpKernel : public framework::OpKernel { T beta2 = static_cast(ctx.Attr("beta2")); T epsilon = static_cast(ctx.Attr("epsilon")); auto& param = Ref(ctx.Input("Param"), "Must set Param"); - auto& grad = Ref(ctx.Input("Grad"), "Must set Grad"); + // auto& grad = Ref(ctx.Input("Grad"), "Must set Grad"); + auto* grad_var = ctx.InputVar("Grad"); auto& mom1 = Ref(ctx.Input("Moment1"), "Must set Moment1"); auto& mom2 = Ref(ctx.Input("Moment2"), "Must set Moment2"); auto& lr = @@ -108,18 +175,48 @@ class AdamOpKernel : public framework::OpKernel { auto& mom2_out = Ref(ctx.Output("Moment2Out"), "Must set Moment1Out"); - AdamFunctor functor(beta1, beta2, epsilon, beta1_pow.template data(), - beta2_pow.template data(), - mom1.template data(), - mom1_out.template mutable_data(ctx.GetPlace()), - mom2.template data(), - mom2_out.template mutable_data(ctx.GetPlace()), - lr.template data(), grad.template data(), - param.template data(), - param_out.template mutable_data(ctx.GetPlace())); - platform::ForRange for_range( - static_cast(ctx.device_context()), param.numel()); - for_range(functor); + if (grad_var->IsType()) { + auto& grad = Ref(ctx.Input("Grad"), "Must set Grad"); + AdamFunctor functor( + beta1, beta2, epsilon, beta1_pow.template data(), + beta2_pow.template data(), mom1.template data(), + mom1_out.template mutable_data(ctx.GetPlace()), + mom2.template data(), + mom2_out.template mutable_data(ctx.GetPlace()), + lr.template data(), grad.template 
data(), + param.template data(), + param_out.template mutable_data(ctx.GetPlace())); + platform::ForRange for_range( + static_cast(ctx.device_context()), + param.numel()); + for_range(functor); + } else if (grad_var->IsType()) { + auto& grad = + Ref(ctx.Input("Grad"), "Must set Grad"); + // merge duplicated rows if any. + scatter::MergeAdd merge_func; + auto grad_merge = + merge_func(ctx.template device_context(), grad); + auto& grad_tensor = grad_merge.value(); + const T* grad_data = grad_tensor.template data(); + auto* rows = grad_merge.rows().data(); + auto row_numel = grad_tensor.numel() / grad_merge.rows().size(); + + SparseAdamFunctor functor( + beta1, beta2, epsilon, beta1_pow.template data(), + beta2_pow.template data(), mom1.template data(), + mom1_out.template mutable_data(ctx.GetPlace()), + mom2.template data(), + mom2_out.template mutable_data(ctx.GetPlace()), + lr.template data(), grad_data, param.template data(), + param_out.template mutable_data(ctx.GetPlace()), rows, row_numel); + platform::ForRange for_range( + static_cast(ctx.device_context()), + grad_merge.rows().size()); + for_range(functor); + } else { + PADDLE_THROW("Variable type not supported by adam_op"); + } } }; diff --git a/paddle/operators/cos_sim_op.h b/paddle/operators/cos_sim_op.h index e2b6282c0913e8ad16f8e3f6c3054f9567822d15..eadcca55f9bfc3e59f329df8ff419ad4c5a29007 100644 --- a/paddle/operators/cos_sim_op.h +++ b/paddle/operators/cos_sim_op.h @@ -13,19 +13,15 @@ See the License for the specific language governing permissions and limitations under the License. */ #pragma once -#include "paddle/framework/eigen.h" #include "paddle/framework/op_registry.h" +#include "paddle/operators/math/cos_sim_functor.h" +#include "paddle/operators/math/math_function.h" +#include "paddle/platform/for_range.h" namespace paddle { namespace operators { using Tensor = framework::Tensor; -template -using EigenMatrix = framework::EigenMatrix; -template -using EigenVector = framework::EigenVector; template class CosSimKernel : public framework::OpKernel { @@ -41,28 +37,25 @@ class CosSimKernel : public framework::OpKernel { out_x_norm->mutable_data(context.GetPlace()); out_y_norm->mutable_data(context.GetPlace()); - // convert Tensor to Eigen Tensor int rows_x = in_x->dims()[0]; int rows_y = in_y->dims()[0]; - auto x = EigenMatrix::Reshape(*in_x, 1); - auto y = EigenMatrix::Reshape(*in_y, 1); - auto z = EigenVector::Flatten(*out_z); - auto x_norm = EigenVector::Flatten(*out_x_norm); - auto y_norm = EigenVector::Flatten(*out_y_norm); - // compute - auto& place = - *context.template device_context().eigen_device(); - auto row_along = Eigen::array({{1}}); - x_norm.device(place) = x.square().sum(row_along).sqrt(); - y_norm.device(place) = y.square().sum(row_along).sqrt(); + int cols = framework::product(in_x->dims()) / rows_x; + if (rows_x == rows_y) { - auto xy = (x * y).sum(Eigen::array({{1}})); - z.device(place) = xy / x_norm / y_norm; + math::CosSimFunctor functor( + in_x->data(), in_y->data(), out_x_norm->data(), + out_y_norm->data(), out_z->data(), cols); + platform::ForRange for_range( + static_cast(context.device_context()), rows_x); + for_range(functor); } else { - Eigen::DSizes bcast(rows_x, 1); - auto xy = (x * y.broadcast(bcast)).sum(row_along); - z.device(place) = xy / x_norm / y_norm.broadcast(bcast); + math::CosSimFunctor functor( + in_x->data(), in_y->data(), out_x_norm->data(), + out_y_norm->data(), out_z->data(), cols); + platform::ForRange for_range( + static_cast(context.device_context()), rows_x); + 
for_range(functor); } } }; @@ -81,62 +74,54 @@ class CosSimGradKernel : public framework::OpKernel { auto* out_grad_y = context.Output(framework::GradVarName("Y")); auto* in_grad_z = context.Input(framework::GradVarName("Out")); - // convert Tensor to Eigen Tensor - auto x = EigenMatrix::Reshape(*in_x, 1); - auto y = EigenMatrix::Reshape(*in_y, 1); - auto z = EigenMatrix::Reshape(*in_z, 1); - auto x_norm = EigenMatrix::Reshape(*in_x_norm, 1); - auto y_norm = EigenMatrix::Reshape(*in_y_norm, 1); - auto dz = EigenMatrix::Reshape(*in_grad_z, 1); - // compute gradident int rows_x = in_x->dims()[0]; int rows_y = in_y->dims()[0]; int cols = framework::product(in_x->dims()) / rows_x; - Eigen::DSizes bcast_cols(1, cols); - auto z_bcast = z.broadcast(bcast_cols); - auto dz_bcast = dz.broadcast(bcast_cols); - auto x_snorm_bcast = x_norm.square().eval().broadcast(bcast_cols); - auto& place = - *context.template device_context().eigen_device(); + if (rows_x == rows_y) { - auto y_snorm_bcast = y_norm.square().eval().broadcast(bcast_cols); - auto norm_prod_bcast = (x_norm * y_norm).eval().broadcast(bcast_cols); - // compute dx if (out_grad_x) { - out_grad_x->mutable_data(context.GetPlace()); - auto dx = EigenMatrix::Reshape(*out_grad_x, 1); - auto grad = y / norm_prod_bcast - z_bcast * x / x_snorm_bcast; - dx.device(place) = dz_bcast * grad; + math::CosSimGradFunctor functor( + in_x_norm->data(), in_y_norm->data(), in_x->data(), + in_y->data(), in_z->data(), in_grad_z->data(), + out_grad_x->mutable_data(context.GetPlace()), cols); + platform::ForRange for_range( + static_cast(context.device_context()), + rows_x); + for_range(functor); } - // compute dy if (out_grad_y) { - out_grad_y->mutable_data(context.GetPlace()); - auto dy = EigenMatrix::Reshape(*out_grad_y, 1); - auto grad = x / norm_prod_bcast - z_bcast * y / y_snorm_bcast; - dy.device(place) = dz_bcast * grad; + math::CosSimGradFunctor functor( + in_y_norm->data(), in_x_norm->data(), in_y->data(), + in_x->data(), in_z->data(), in_grad_z->data(), + out_grad_y->mutable_data(context.GetPlace()), cols); + platform::ForRange for_range( + static_cast(context.device_context()), + rows_x); + for_range(functor); } } else { - Eigen::DSizes bcast_rows(rows_x, 1); - Eigen::DSizes bcast_rows_cols(rows_x, cols); - auto y_bcast = y.broadcast(bcast_rows); - auto y_snorm_bcast = y_norm.square().eval().broadcast(bcast_rows_cols); - auto norm_prod_bcast = (x_norm * y_norm.eval().broadcast(bcast_rows)) - .eval() - .broadcast(bcast_cols); - // compute dx if (out_grad_x) { - out_grad_x->mutable_data(context.GetPlace()); - auto dx = EigenMatrix::Reshape(*out_grad_x, 1); - auto grad = y_bcast / norm_prod_bcast - z_bcast * x / x_snorm_bcast; - dx.device(place) = dz_bcast * grad; + math::CosSimDxFunctor functor( + in_x_norm->data(), in_y_norm->data(), in_x->data(), + in_y->data(), in_z->data(), in_grad_z->data(), + out_grad_x->mutable_data(context.GetPlace()), cols); + platform::ForRange for_range( + static_cast(context.device_context()), + rows_x); + for_range(functor); } - // compute dy if (out_grad_y) { out_grad_y->mutable_data(context.GetPlace()); - auto dy = EigenVector::Flatten(*out_grad_y); - auto grad = x / norm_prod_bcast - z_bcast * y_bcast / y_snorm_bcast; - dy.device(place) = (dz_bcast * grad).sum(Eigen::array({{0}})); + math::SetConstant set_zero; + auto& dev_ctx = context.template device_context(); + set_zero(dev_ctx, out_grad_y, static_cast(0)); + + math::CosSimDyFunctor functor; + functor(dev_ctx, in_x_norm->data(), in_y_norm->data(), + in_x->data(), 
in_y->data(), in_z->data(), + in_grad_z->data(), static_cast(rows_x), + static_cast(cols), out_grad_y->data()); } } } diff --git a/paddle/operators/gru_op.h b/paddle/operators/gru_op.h index c6228864d7ec042ff99e4521d1d707ba091e8ed5..b1957fb9ce6add8628cb206abf2c569d3f615c85 100644 --- a/paddle/operators/gru_op.h +++ b/paddle/operators/gru_op.h @@ -14,6 +14,7 @@ limitations under the License. */ #pragma once +#include "paddle/operators/math/detail/activation_functions.h" #include "paddle/operators/math/gru_compute.h" #include "paddle/operators/math/math_function.h" #include "paddle/operators/math/sequence2batch.h" @@ -70,7 +71,7 @@ class GRUKernel : public framework::OpKernel { } int frame_size = hidden_dims[1]; - math::hl_gru_value gru_value; + math::GRUMetaValue gru_value; gru_value.gate_weight = const_cast(weight_data); gru_value.state_weight = const_cast(weight_data + 2 * frame_size * frame_size); @@ -89,6 +90,10 @@ class GRUKernel : public framework::OpKernel { } auto batch_starts = batch_gate->lod()[0]; size_t num_batch = batch_starts.size() - 1; + auto active_node = math::detail::GetActivationType( + context.Attr("activation")); + auto active_gate = math::detail::GetActivationType( + context.Attr("gate_activation")); for (size_t n = 0; n < num_batch; n++) { int bstart = static_cast(batch_starts[n]); int bend = static_cast(batch_starts[n + 1]); @@ -101,9 +106,8 @@ class GRUKernel : public framework::OpKernel { gru_value.gate_value = gate_t.data(); gru_value.reset_output_value = reset_hidden_prev_t.data(); math::GRUUnitFunctor::compute( - dev_ctx, gru_value, frame_size, cur_batch_size, - math::ActiveType(context.Attr("activation")), - math::ActiveType(context.Attr("gate_activation"))); + dev_ctx, gru_value, frame_size, cur_batch_size, active_node, + active_gate); gru_value.prev_out_value = gru_value.output_value; } @@ -170,12 +174,12 @@ class GRUGradKernel : public framework::OpKernel { batch_hidden_grad.set_lod(batch_hidden->lod()); to_batch(dev_ctx, *hidden_grad, batch_hidden_grad, false, is_reverse); - math::hl_gru_value gru_value; + math::GRUMetaValue gru_value; gru_value.gate_weight = const_cast(weight_data); gru_value.state_weight = const_cast(weight_data + 2 * frame_size * frame_size); - math::hl_gru_grad gru_grad; + math::GRUMetaGrad gru_grad; if (weight_grad) { gru_grad.gate_weight_grad = weight_grad->mutable_data(context.GetPlace()); @@ -189,6 +193,10 @@ class GRUGradKernel : public framework::OpKernel { auto batch_starts = batch_hidden_grad.lod()[0]; size_t num_batch = batch_starts.size() - 1; + auto active_node = math::detail::GetActivationType( + context.Attr("activation")); + auto active_gate = math::detail::GetActivationType( + context.Attr("gate_activation")); for (int n = static_cast(num_batch) - 1; n >= 0; n--) { int bstart = static_cast(batch_starts[n]); int bend = static_cast(batch_starts[n + 1]); @@ -219,9 +227,8 @@ class GRUGradKernel : public framework::OpKernel { } math::GRUUnitGradFunctor::compute( - dev_ctx, gru_value, gru_grad, frame_size, cur_batch_size, - math::ActiveType(context.Attr("activation")), - math::ActiveType(context.Attr("gate_activation"))); + dev_ctx, gru_value, gru_grad, frame_size, cur_batch_size, active_node, + active_gate); } if (input_grad) { input_grad->mutable_data(context.GetPlace()); diff --git a/paddle/operators/math/CMakeLists.txt b/paddle/operators/math/CMakeLists.txt index b97faec4ed687c1cf8d746cdf615e86fd79ca921..7ebcfb9ab9f30e3b0f13d3646a59d008335b232d 100644 --- a/paddle/operators/math/CMakeLists.txt +++ 
b/paddle/operators/math/CMakeLists.txt @@ -16,6 +16,7 @@ if(WITH_GPU) nv_library(maxouting SRCS maxouting.cc maxouting.cu DEPS device_context) nv_library(unpooling SRCS unpooling.cc unpooling.cu DEPS device_context) nv_library(gru_compute SRCS gru_compute.cc gru_compute.cu DEPS device_context activation_functions math_function) + nv_library(cos_sim_functor SRCS cos_sim_functor.cc cos_sim_functor.cu DEPS device_context) else() cc_library(math_function SRCS math_function.cc im2col.cc DEPS cblas device_context framework_proto) cc_library(selected_rows_functor SRCS selected_rows_functor.cc DEPS selected_rows math_function) @@ -30,6 +31,7 @@ else() cc_library(maxouting SRCS maxouting.cc DEPS device_context) cc_library(unpooling SRCS unpooling.cc DEPS device_context) cc_library(gru_compute SRCS gru_compute.cc DEPS device_context activation_functions math_function) + cc_library(cos_sim_functor SRCS cos_sim_functor.cc DEPS device_context) endif() cc_test(math_function_test SRCS math_function_test.cc DEPS math_function tensor) diff --git a/paddle/operators/math/cos_sim_functor.cc b/paddle/operators/math/cos_sim_functor.cc new file mode 100644 index 0000000000000000000000000000000000000000..6af9f0fcd967b4e8e9e338c155d5a8ee2866fbfa --- /dev/null +++ b/paddle/operators/math/cos_sim_functor.cc @@ -0,0 +1,48 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/operators/math/cos_sim_functor.h" + +namespace paddle { +namespace operators { +namespace math { + +template +struct CosSimDyFunctor { + void operator()(const platform::CPUDeviceContext& ctx, const T* x_norm, + const T* y_norm, const T* x, const T* y, const T* z, + const T* dz, const size_t rows, const size_t cols, + T* dy) const { + for (size_t row_id = 0; row_id < rows; ++row_id) { + auto xy_norm_prod = x_norm[row_id] * y_norm[0]; + auto dz_data = dz[row_id]; + auto z_data = z[row_id]; + auto* x_data = x + cols * row_id; + auto reciprocal_xy_norm_prod = 1 / xy_norm_prod; + + auto y_norm_square = y_norm[0] * y_norm[0]; + auto reciprocal_y_norm_square = 1 / y_norm_square; + for (size_t i = 0; i < cols; ++i) { + dy[i] += dz_data * (x_data[i] * reciprocal_xy_norm_prod - + z_data * y[i] * reciprocal_y_norm_square); + } + } + } +}; + +template struct CosSimDyFunctor; +template struct CosSimDyFunctor; +} // namespace math +} // namespace operators +} // namespace paddle diff --git a/paddle/operators/math/cos_sim_functor.cu b/paddle/operators/math/cos_sim_functor.cu new file mode 100644 index 0000000000000000000000000000000000000000..6eb0a4ea4c5b86f84c93b97615255adf55e9e042 --- /dev/null +++ b/paddle/operators/math/cos_sim_functor.cu @@ -0,0 +1,64 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/operators/math/cos_sim_functor.h" +#include "paddle/platform/cuda_helper.h" + +namespace paddle { +namespace operators { +namespace math { + +template +__global__ void CosSimDyKernel(const T* x_norm, const T* y_norm, const T* x, + const T* y, const T* z, const T* dz, + const size_t rows, const size_t cols, T* dy) { + int grid_size = blockDim.x * gridDim.x; + T y_norm_data = y_norm[0]; + for (int row_id = blockIdx.x * blockDim.x + threadIdx.x; row_id < rows; + row_id += grid_size) { + T xy_norm_prod = x_norm[row_id] * y_norm_data; + T dz_data = dz[row_id]; + T z_data = z[row_id]; + const T* x_data = x + cols * row_id; + T reciprocal_xy_norm_prod = 1 / xy_norm_prod; + + T y_norm_square = y_norm_data * y_norm_data; + T reciprocal_y_norm_square = 1 / y_norm_square; + for (size_t i = 0; i < cols; ++i) { + T dy_data = dz_data * (x_data[i] * reciprocal_xy_norm_prod - + z_data * y[i] * reciprocal_y_norm_square); + platform::CudaAtomicAdd(dy + i, dy_data); + } + } +} + +template +struct CosSimDyFunctor { + void operator()(const platform::CUDADeviceContext& ctx, const T* x_norm, + const T* y_norm, const T* x, const T* y, const T* z, + const T* dz, const size_t rows, const size_t cols, + T* dy) const { + const int block_size = 512; + dim3 threads(block_size, 1); + dim3 grid(1, (rows + block_size - 1) / block_size); + CosSimDyKernel<<>>( + x_norm, y_norm, x, y, z, dz, rows, cols, dy); + } +}; + +template struct CosSimDyFunctor; +template struct CosSimDyFunctor; +} // namespace math +} // namespace operators +} // namespace paddle diff --git a/paddle/operators/math/cos_sim_functor.h b/paddle/operators/math/cos_sim_functor.h new file mode 100644 index 0000000000000000000000000000000000000000..aae8ab5b7a937c016e8a45e34b22aba7a1df3066 --- /dev/null +++ b/paddle/operators/math/cos_sim_functor.h @@ -0,0 +1,166 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
*/ + +#pragma once +#include +#include +#include "paddle/platform/device_context.h" +#include "paddle/platform/hostdevice.h" + +namespace paddle { +namespace operators { +namespace math { + +template +struct CosSimFunctor { + CosSimFunctor(const T* x, const T* y, T* x_norm, T* y_norm, T* z, int cols) + : x_norm_(x_norm), + y_norm_(y_norm), + x_(x), + y_(y), + z_(z), + cols_(static_cast(cols)) {} + + inline HOSTDEVICE void operator()(size_t row_id) const { + auto* x = x_ + cols_ * row_id; + T xx = 0, xy = 0, yy = 0; + if (same_row) { + auto* y = y_ + cols_ * row_id; + T tep_x, tep_y; + for (size_t i = 0; i < cols_; ++i) { + tep_x = x[i]; + tep_y = y[i]; + xx += tep_x * tep_x; + yy += tep_y * tep_y; + xy += tep_x * tep_y; + } + xx = sqrt(xx); + yy = sqrt(yy); + y_norm_[row_id] = yy; + x_norm_[row_id] = xx; + z_[row_id] = xy / (xx * yy); + } else { // This can be wrote in a better way. + T tep_x, tep_y; + for (size_t i = 0; i < cols_; ++i) { + tep_x = x[i]; + tep_y = y_[i]; + xx += tep_x * tep_x; + yy += tep_y * tep_y; + xy += tep_x * tep_y; + } + xx = sqrt(xx); + yy = sqrt(yy); + if (row_id == 0) y_norm_[0] = yy; + x_norm_[row_id] = xx; + z_[row_id] = xy / (xx * yy); + } + } + + T* x_norm_; + T* y_norm_; + const T* x_; + const T* y_; + T* z_; + const size_t cols_; +}; + +template +struct CosSimGradFunctor { + CosSimGradFunctor(const T* x_norm, const T* y_norm, const T* x, const T* y, + const T* z, const T* dz, T* dx, int cols) + : x_norm_(x_norm), + y_norm_(y_norm), + x_(x), + y_(y), + z_(z), + dz_(dz), + dx_(dx), + cols_(static_cast(cols)) {} + + inline HOSTDEVICE void operator()(size_t row_id) const { + auto x_norm_square = x_norm_[row_id] * x_norm_[row_id]; + auto xy_norm_prod = x_norm_[row_id] * y_norm_[row_id]; + auto dz = dz_[row_id]; + auto z = z_[row_id]; + + auto* dx = dx_ + cols_ * row_id; + auto* x = x_ + cols_ * row_id; + auto* y = y_ + cols_ * row_id; + + auto reciprocal_xy_norm_prod = 1 / xy_norm_prod; + auto reciprocal_x_norm_square = 1 / x_norm_square; + for (size_t i = 0; i < cols_; ++i) { + dx[i] = dz * (y[i] * reciprocal_xy_norm_prod - + z * x[i] * reciprocal_x_norm_square); + } + } + + const T* x_norm_; + const T* y_norm_; + const T* x_; + const T* y_; + const T* z_; + const T* dz_; + T* dx_; + const size_t cols_; +}; + +template +struct CosSimDxFunctor { + CosSimDxFunctor(const T* x_norm, const T* y_norm, const T* x, const T* y, + const T* z, const T* dz, T* dx, int cols) + : x_norm_(x_norm), + y_norm_(y_norm), + x_(x), + y_(y), + z_(z), + dz_(dz), + dx_(dx), + cols_(static_cast(cols)) {} + + inline HOSTDEVICE void operator()(size_t row_id) const { + auto xy_norm_prod = x_norm_[row_id] * y_norm_[0]; + auto dz = dz_[row_id]; + auto z = z_[row_id]; + auto* x = x_ + cols_ * row_id; + auto reciprocal_xy_norm_prod = 1 / xy_norm_prod; + auto x_norm_square = x_norm_[row_id] * x_norm_[row_id]; + auto* dx = dx_ + cols_ * row_id; + auto reciprocal_x_norm_square = 1 / x_norm_square; + + for (size_t i = 0; i < cols_; ++i) { + dx[i] = dz * (y_[i] * reciprocal_xy_norm_prod - + z * x[i] * reciprocal_x_norm_square); + } + } + const T* x_norm_; + const T* y_norm_; + const T* x_; + const T* y_; + const T* z_; + const T* dz_; + T* dx_; + const size_t cols_; +}; + +template +struct CosSimDyFunctor { + void operator()(const DeviceContext& ctx, const T* x_norm, const T* y_norm, + const T* x, const T* y, const T* z, const T* dz, + const size_t rows, const size_t cols, T* dy) const; +}; + +} // namespace math +} // namespace operators +} // namespace paddle diff --git 
a/paddle/operators/math/detail/gru_cpu_kernel.h b/paddle/operators/math/detail/gru_cpu_kernel.h index 4c67dec9cbeb48f400f79f5ed7ba3c939fa2540c..a61b232f4275d93cae1d9a71d49a779216c3555b 100644 --- a/paddle/operators/math/detail/gru_cpu_kernel.h +++ b/paddle/operators/math/detail/gru_cpu_kernel.h @@ -28,7 +28,7 @@ template void hl_naive_gru_forward_reset_output(OpResetOutput op_reset_output, T *gate_value, T *reset_output_value, T *prev_output_value, int frame_size, - activation_mode_t active_gate) { + ActivationType active_gate) { T r_value_update_gate; T r_value_reset_gate; T r_value_reset_output; @@ -56,7 +56,7 @@ template void hl_naive_gru_forward_final_output(OpFinalOutput op_final_output, T *gate_value, T *prev_output_value, T *output_value, int frame_size, - activation_mode_t active_node) { + ActivationType active_node) { T r_value_update_gate; T r_value_frame_state; T r_prev_out = 0; @@ -83,7 +83,7 @@ template void hl_avx_gru_forward_reset_output(OpResetOutput op_reset_output, T *gate_value, T *reset_output_value, T *prev_output_value, int frame_size, - activation_mode_t active_gate) { + ActivationType active_gate) { #ifdef __AVX__ __m256 r_value_update_gate; __m256 r_value_reset_gate; @@ -113,7 +113,7 @@ template void hl_avx_gru_forward_final_output(OpFinalOutput op_final_output, T *gate_value, T *prev_output_value, T *output_value, int frame_size, - activation_mode_t active_node) { + ActivationType active_node) { #ifdef __AVX__ __m256 r_value_update_gate; __m256 r_value_frame_state; @@ -140,9 +140,8 @@ void hl_avx_gru_forward_final_output(OpFinalOutput op_final_output, template inline void forward_reset_output(OpResetOutput op_reset_output, - hl_gru_value value, int frame_size, - int batch_size, - activation_mode_t active_gate) { + GRUMetaValue value, int frame_size, + int batch_size, ActivationType active_gate) { for (int b = 0; b < batch_size; b++) { if (OpResetOutput::avx && !(frame_size & (8 - 1)) && (sizeof(T) == 4)) { hl_avx_gru_forward_reset_output( @@ -164,9 +163,8 @@ inline void forward_reset_output(OpResetOutput op_reset_output, template inline void forward_final_output(OpFinalOutput op_final_output, - hl_gru_value value, int frame_size, - int batch_size, - activation_mode_t active_node) { + GRUMetaValue value, int frame_size, + int batch_size, ActivationType active_node) { for (int b = 0; b < batch_size; b++) { if (OpFinalOutput::avx && !(frame_size & (8 - 1)) && (sizeof(T) == 4)) { hl_avx_gru_forward_final_output(op_final_output, value.gate_value, @@ -191,7 +189,7 @@ void hl_naive_gru_backward_state_grad(OpStateGrad op_state_grad, T *gate_value, T *gate_grad, T *prev_out_value, T *prev_out_grad, T *output_grad, int frame_size, - activation_mode_t active_node) { + ActivationType active_node) { T r_update_gate_value; T r_update_gate_grad; T r_frame_state_value; @@ -232,7 +230,7 @@ void hl_naive_gru_backward_reset_grad(OpResetGrad op_reset_grad, T *gate_value, T *gate_grad, T *prev_out_value, T *prev_out_grad, T *reset_output_grad, int frame_size, - activation_mode_t active_gate) { + ActivationType active_gate) { T r_update_gate_value; T r_update_gate_grad; T r_reset_gate_value; @@ -277,7 +275,7 @@ void hl_avx_gru_backward_state_grad(OpStateGrad op_state_grad, T *gate_value, T *gate_grad, T *prev_out_value, T *prev_out_grad, T *output_grad, int frame_size, - activation_mode_t active_node) { + ActivationType active_node) { #ifdef __AVX__ __m256 r_update_gate_value; __m256 r_update_gate_grad; @@ -320,7 +318,7 @@ void hl_avx_gru_backward_reset_grad(OpResetGrad op_reset_grad, 
T *gate_value, T *gate_grad, T *prev_out_value, T *prev_out_grad, T *reset_output_grad, int frame_size, - activation_mode_t active_gate) { + ActivationType active_gate) { #ifdef __AVX__ __m256 r_update_gate_value; __m256 r_update_gate_grad; @@ -364,9 +362,9 @@ void hl_avx_gru_backward_reset_grad(OpResetGrad op_reset_grad, T *gate_value, template inline void backward_state_grad(OpStateGrad op_state_grad, - hl_gru_value value, hl_gru_grad grad, + GRUMetaValue value, GRUMetaGrad grad, int frame_size, int batch_size, - activation_mode_t active_node) { + ActivationType active_node) { for (int b = 0; b < batch_size; b++) { if (OpStateGrad::avx && !(frame_size & (8 - 1)) && (sizeof(T) == 4)) { hl_avx_gru_backward_state_grad( @@ -393,9 +391,9 @@ inline void backward_state_grad(OpStateGrad op_state_grad, template inline void backward_reset_grad(OpResetGrad op_reset_grad, - hl_gru_value value, hl_gru_grad grad, + GRUMetaValue value, GRUMetaGrad grad, int frame_size, int batch_size, - activation_mode_t active_gate) { + ActivationType active_gate) { for (int b = 0; b < batch_size; b++) { if (OpResetGrad::avx && !(frame_size & (8 - 1)) && (sizeof(T) == 4)) { hl_avx_gru_backward_reset_grad( diff --git a/paddle/operators/math/detail/gru_gpu_kernel.h b/paddle/operators/math/detail/gru_gpu_kernel.h index d2edcb7f258b387530799b967fc0fff61acc5b83..1783d46096858c874b27ce75760342082835b180 100644 --- a/paddle/operators/math/detail/gru_gpu_kernel.h +++ b/paddle/operators/math/detail/gru_gpu_kernel.h @@ -19,8 +19,6 @@ limitations under the License. */ #include "paddle/platform/cuda_helper.h" #include "paddle/platform/device_context.h" -#include - namespace paddle { namespace operators { namespace math { @@ -35,7 +33,7 @@ __global__ void KeGruForwardResetOutput(OpResetOutput op_reset_output, T *gate_value, T *reset_output_value, T *prev_output_value, int frame_size, int batch_size, - activation_mode_t active_gate) { + ActivationType active_gate) { const int frame_idx = blockIdx.x * blockDim.x + threadIdx.x; if (frame_idx >= frame_size) return; @@ -74,7 +72,7 @@ __global__ void KeGruForwardFinalOutput(OpFinalOutput op_final_output, T *gate_value, T *prev_output_value, T *output_value, int frame_size, int batch_size, - activation_mode_t active_node) { + ActivationType active_node) { const int frame_idx = blockIdx.x * blockDim.x + threadIdx.x; if (frame_idx >= frame_size) return; int batch_idx = 0; @@ -111,7 +109,7 @@ __global__ void KeGruBackwardStateGrad(OpStateGrad op_state_grad, T *gate_value, T *gate_grad, T *prev_out_value, T *prev_out_grad, T *output_grad, int frame_size, int batch_size, - activation_mode_t active_node) { + ActivationType active_node) { const int frame_idx = blockIdx.x * blockDim.x + threadIdx.x; if (frame_idx >= frame_size) return; int batch_idx = 0; @@ -159,7 +157,7 @@ __global__ void KeGruBackwardResetGrad(OpResetGrad op_reset_grad, T *gate_value, T *gate_grad, T *prev_out_value, T *prev_out_grad, T *reset_output_grad, int frame_size, int batch_size, - activation_mode_t active_gate) { + ActivationType active_gate) { const int frame_idx = blockIdx.x * blockDim.x + threadIdx.x; if (frame_idx >= frame_size) return; int batch_idx = 0; diff --git a/paddle/operators/math/detail/gru_kernel.h b/paddle/operators/math/detail/gru_kernel.h index acd84be01db9ddaf06d165d8be353b253f324dd2..4d8245cb5d03b33edbda5d8350be02b4fa87ab95 100644 --- a/paddle/operators/math/detail/gru_kernel.h +++ b/paddle/operators/math/detail/gru_kernel.h @@ -30,7 +30,7 @@ class gru_resetOutput { public: HOSTDEVICE void 
operator()(T &value_update_gate, T &value_reset_gate, T &prev_out, T &value_reset_output, - activation_mode_t act_gate) { + ActivationType act_gate) { value_update_gate = activation(value_update_gate, act_gate); value_reset_gate = activation(value_reset_gate, act_gate); value_reset_output = prev_out * value_reset_gate; @@ -43,7 +43,7 @@ class gru_resetOutput { HOSTDEVICE void operator()(__m256 &value_update_gate, __m256 &value_reset_gate, __m256 &prev_out, __m256 &value_reset_output, - activation_mode_t act_gate) { + ActivationType act_gate) { value_update_gate = activation(value_update_gate, act_gate); value_reset_gate = activation(value_reset_gate, act_gate); value_reset_output = _mm256_mul_ps(prev_out, value_reset_gate); @@ -57,7 +57,7 @@ class gru_finalOutput { public: HOSTDEVICE void operator()(T &value_update_gate, T &value_frame_state, T &prev_out, T &value_output, - activation_mode_t act_input) { + ActivationType act_input) { value_frame_state = activation(value_frame_state, act_input); value_output = prev_out - (value_update_gate * prev_out) + (value_update_gate * value_frame_state); @@ -69,8 +69,7 @@ class gru_finalOutput { static const bool avx = true; HOSTDEVICE void operator()(__m256 &value_update_gate, __m256 &value_frame_state, __m256 &prev_out, - __m256 &value_output, - activation_mode_t act_input) { + __m256 &value_output, ActivationType act_input) { value_frame_state = activation(value_frame_state, act_input); value_output = _mm256_add_ps( _mm256_sub_ps(prev_out, _mm256_mul_ps(value_update_gate, prev_out)), @@ -89,7 +88,7 @@ class gru_stateGrad { HOSTDEVICE void operator()(T &value_update_gate, T &grad_update_gate, T &value_frame_state, T &grad_frame_state, T &value_prev_out, T &grad_prev_out, - T &grad_output, activation_mode_t act_input) { + T &grad_output, ActivationType act_input) { grad_update_gate = (grad_output * value_frame_state); grad_update_gate -= (grad_output * value_prev_out); grad_prev_out -= (grad_output * value_update_gate); @@ -107,7 +106,7 @@ class gru_stateGrad { __m256 &value_frame_state, __m256 &grad_frame_state, __m256 &value_prev_out, __m256 &grad_prev_out, __m256 &grad_output, - activation_mode_t act_input) { + ActivationType act_input) { grad_update_gate = _mm256_mul_ps(grad_output, value_frame_state); grad_update_gate = _mm256_sub_ps( grad_update_gate, _mm256_mul_ps(grad_output, value_prev_out)); @@ -128,7 +127,7 @@ class gru_resetGrad { HOSTDEVICE void operator()(T &value_update_gate, T &grad_update_gate, T &value_reset_gate, T &grad_reset_gate, T &value_prev_out, T &grad_prev_out, - T &grad_reset_output, activation_mode_t act_gate) { + T &grad_reset_output, ActivationType act_gate) { grad_reset_gate = (grad_reset_output * value_prev_out); grad_prev_out += (grad_reset_output * value_reset_gate); grad_update_gate = @@ -144,7 +143,7 @@ class gru_resetGrad { __m256 &grad_update_gate, __m256 &value_reset_gate, __m256 &grad_reset_gate, __m256 &value_prev_out, __m256 &grad_prev_out, __m256 &grad_reset_output, - activation_mode_t act_gate) { + ActivationType act_gate) { grad_reset_gate = _mm256_mul_ps(grad_reset_output, value_prev_out); grad_prev_out = _mm256_add_ps( grad_prev_out, _mm256_mul_ps(grad_reset_output, value_reset_gate)); diff --git a/paddle/operators/math/gru_compute.cc b/paddle/operators/math/gru_compute.cc index d570c68cd458914c8951c4ce50a02e3c5b1acab0..101ab859624869bf34d171cd42d46d0c5bdac29c 100644 --- a/paddle/operators/math/gru_compute.cc +++ b/paddle/operators/math/gru_compute.cc @@ -21,9 +21,9 @@ namespace math { template struct 
GRUUnitFunctor { static void compute(const platform::CPUDeviceContext &context, - hl_gru_value value, int frame_size, int batch_size, - activation_mode_t active_node, - activation_mode_t active_gate) { + GRUMetaValue value, int frame_size, int batch_size, + const detail::ActivationType active_node, + const detail::ActivationType active_gate) { #ifndef __NVCC__ if (value.prev_out_value) { math::gemm( @@ -51,10 +51,10 @@ struct GRUUnitFunctor { template struct GRUUnitGradFunctor { static void compute(const platform::CPUDeviceContext &context, - hl_gru_value value, hl_gru_grad grad, + GRUMetaValue value, GRUMetaGrad grad, int frame_size, int batch_size, - activation_mode_t active_node, - activation_mode_t active_gate) { + const detail::ActivationType active_node, + const detail::ActivationType active_gate) { #ifndef __NVCC__ detail::backward_state_grad(detail::backward::gru_stateGrad(), value, grad, frame_size, batch_size, active_node); diff --git a/paddle/operators/math/gru_compute.cu b/paddle/operators/math/gru_compute.cu index dd518cd1e4bea52f0d463150114feed3ceea0ccb..d5a0e630ea0eadea990988c3170395c842a91900 100644 --- a/paddle/operators/math/gru_compute.cu +++ b/paddle/operators/math/gru_compute.cu @@ -21,9 +21,9 @@ namespace math { template struct GRUUnitFunctor { static void compute(const platform::CUDADeviceContext &context, - hl_gru_value value, int frame_size, int batch_size, - activation_mode_t active_node, - activation_mode_t active_gate) { + GRUMetaValue value, int frame_size, int batch_size, + const detail::ActivationType active_node, + const detail::ActivationType active_gate) { auto stream = context.stream(); dim3 threads; dim3 grid; @@ -88,10 +88,10 @@ struct GRUUnitFunctor { template struct GRUUnitGradFunctor { static void compute(const platform::CUDADeviceContext &context, - hl_gru_value value, hl_gru_grad grad, + GRUMetaValue value, GRUMetaGrad grad, int frame_size, int batch_size, - activation_mode_t active_node, - activation_mode_t active_gate) { + const detail::ActivationType active_node, + const detail::ActivationType active_gate) { auto stream = context.stream(); dim3 threads; dim3 grid; diff --git a/paddle/operators/math/gru_compute.h b/paddle/operators/math/gru_compute.h index ca1343cb2c5c1eb8da92c2f06b25902c1c2fe8b3..bf69147b506661692a6d71823043cd3506ea8b5d 100644 --- a/paddle/operators/math/gru_compute.h +++ b/paddle/operators/math/gru_compute.h @@ -11,7 +11,7 @@ limitations under the License. 
*/ #pragma once -#include "paddle/operators/math/lstm_compute.h" +#include "paddle/operators/math/detail/activation_functions.h" #include "paddle/platform/device_context.h" #include "paddle/platform/enforce.h" @@ -19,9 +19,8 @@ namespace paddle { namespace operators { namespace math { -// TODO(guosheng): refine code style in gru_compute template -struct hl_gru_value { +struct GRUMetaValue { T *gate_weight; T *state_weight; T *gate_value; @@ -31,7 +30,7 @@ struct hl_gru_value { }; template -struct hl_gru_grad { +struct GRUMetaGrad { T *gate_weight_grad; T *state_weight_grad; T *gate_grad; @@ -42,18 +41,18 @@ struct hl_gru_grad { template struct GRUUnitFunctor { - static void compute(const DeviceContext &context, hl_gru_value value, + static void compute(const DeviceContext &context, GRUMetaValue value, int frame_size, int batch_size, - activation_mode_t active_node, - activation_mode_t active_gate); + const detail::ActivationType active_node, + const detail::ActivationType active_gate); }; template struct GRUUnitGradFunctor { - static void compute(const DeviceContext &context, hl_gru_value value, - hl_gru_grad grad, int frame_size, int batch_size, - activation_mode_t active_node, - activation_mode_t active_gate); + static void compute(const DeviceContext &context, GRUMetaValue value, + GRUMetaGrad grad, int frame_size, int batch_size, + const detail::ActivationType active_node, + const detail::ActivationType active_gate); }; } // namespace math diff --git a/paddle/operators/math/lstm_compute.h b/paddle/operators/math/lstm_compute.h index 954762f92286fe13bd2c08ec03c3ac96bb663cca..e1ad6b64d201ef99d83eaa2a821356008dcc9c8e 100644 --- a/paddle/operators/math/lstm_compute.h +++ b/paddle/operators/math/lstm_compute.h @@ -22,14 +22,6 @@ namespace paddle { namespace operators { namespace math { -typedef enum { - HL_ACTIVATION_SIGMOID = 0, - HL_ACTIVATION_RELU = 1, - HL_ACTIVATION_TANH = 2, - HL_ACTIVATION_LINEAR = 3, - HL_ACTIVATION_END -} activation_mode_t; - template struct LstmMetaValue { T *gate_value; @@ -54,20 +46,6 @@ struct LstmMetaGrad { T *check_og_grad; }; -inline activation_mode_t ActiveType(const std::string &type) { - if (type == "sigmoid") { - return HL_ACTIVATION_SIGMOID; - } else if (type == "relu") { - return HL_ACTIVATION_RELU; - } else if (type == "tanh") { - return HL_ACTIVATION_TANH; - } else if (type == "linear" || type == "identity" || type == "") { - return HL_ACTIVATION_LINEAR; - } else { - PADDLE_THROW("Do not support activation type."); - } -} - template class LstmUnitFunctor { public: diff --git a/paddle/operators/math/selected_rows_functor.cc b/paddle/operators/math/selected_rows_functor.cc index ab758d1e7fd8ab361948b28e8cb735b9a742a339..8a1ebb58c26578f076bf243adfbd51d10c682b99 100644 --- a/paddle/operators/math/selected_rows_functor.cc +++ b/paddle/operators/math/selected_rows_functor.cc @@ -12,8 +12,10 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ -#include "paddle/operators/math/selected_rows_functor.h" +#include + #include "paddle/operators/math/math_function.h" +#include "paddle/operators/math/selected_rows_functor.h" namespace paddle { namespace operators { @@ -179,6 +181,118 @@ template struct SelectedRowsAddToTensor; template struct SelectedRowsAddToTensor; template struct SelectedRowsAddToTensor; +// This is a separated namespace for manipulate SelectedRows typed +// data. Like merge duplicated rows, adding two SelectedRows etc. 
+// +// Another group of functors is called "scatter updates", which means +// use SelectedRows to update a dense tensor with different Ops, like +// add or mul. +namespace scatter { + +size_t FindPos(const std::vector& rows, int64_t value) { + return std::find(rows.begin(), rows.end(), value) - rows.begin(); +} + +template +struct MergeAdd { + framework::SelectedRows operator()(const platform::CPUDeviceContext& context, + const framework::SelectedRows& input) { + framework::SelectedRows out; + auto input_rows = input.rows(); + std::set row_set(input_rows.begin(), input_rows.end()); + std::vector merge_rows(row_set.begin(), row_set.end()); + + auto input_width = input.value().dims()[1]; + out.set_rows(merge_rows); + out.set_height(input.height()); + out.mutable_value()->mutable_data( + framework::make_ddim( + {static_cast(merge_rows.size()), input_width}), + context.GetPlace()); + + math::SetConstant constant_functor; + constant_functor(context, out.mutable_value(), 0.0); + + auto* out_data = out.mutable_value()->data(); + auto* input_data = input.value().data(); + + for (size_t i = 0; i < input_rows.size(); i++) { + size_t out_i = FindPos(merge_rows, input_rows[i]); + for (int64_t j = 0; j < input_width; j++) { + out_data[out_i * input_width + j] += input_data[i * input_width + j]; + } + } + return out; + } +}; + +template struct MergeAdd; +template struct MergeAdd; +template struct MergeAdd; +template struct MergeAdd; + +template +struct UpdateToTensor { + void operator()(const platform::CPUDeviceContext& context, + const ScatterOps& op, const framework::SelectedRows& input1, + framework::Tensor* input2) { + auto in1_height = input1.height(); + auto in2_dims = input2->dims(); + PADDLE_ENFORCE_EQ(in1_height, in2_dims[0]); + + auto& in1_value = input1.value(); + auto& in1_rows = input1.rows(); + + int64_t in1_row_numel = in1_value.numel() / in1_rows.size(); + PADDLE_ENFORCE_EQ(in1_row_numel, input2->numel() / in1_height); + + auto* in1_data = in1_value.data(); + auto* input2_data = input2->data(); + + // FIXME(typhoonzero): use macro fix the below messy code. 
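+    // Each case updates the dense tensor in place, row by row; e.g.
+    // ScatterOps::ADD performs
+    //   input2_data[in1_rows[i] * in1_row_numel + j] += in1_data[i * in1_row_numel + j]
+    // while the *BY variants (SUBBY, DIVBY) apply the operands in reversed
+    // order and write the result back into input2_data.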
+ switch (op) { + case ScatterOps::ASSIGN: + INLINE_FOR2(in1_rows.size(), in1_row_numel) + input2_data[in1_rows[i] * in1_row_numel + j] = + in1_data[i * in1_row_numel + j]; + break; + case ScatterOps::ADD: + INLINE_FOR2(in1_rows.size(), in1_row_numel) + input2_data[in1_rows[i] * in1_row_numel + j] += + in1_data[i * in1_row_numel + j]; + break; + case ScatterOps::SUB: + INLINE_FOR2(in1_rows.size(), in1_row_numel) + input2_data[in1_rows[i] * in1_row_numel + j] -= + in1_data[i * in1_row_numel + j]; + break; + case ScatterOps::SUBBY: + INLINE_FOR2(in1_rows.size(), in1_row_numel) + input2_data[in1_rows[i] * in1_row_numel + j] = + in1_data[i * in1_row_numel + j] - + input2_data[in1_rows[i] * in1_row_numel + j]; + break; + case ScatterOps::MUL: + INLINE_FOR2(in1_rows.size(), in1_row_numel) + input2_data[in1_rows[i] * in1_row_numel + j] *= + in1_data[i * in1_row_numel + j]; + break; + case ScatterOps::DIV: + INLINE_FOR2(in1_rows.size(), in1_row_numel) + input2_data[in1_rows[i] * in1_row_numel + j] /= + in1_data[i * in1_row_numel + j]; + break; + case ScatterOps::DIVBY: + INLINE_FOR2(in1_rows.size(), in1_row_numel) + input2_data[in1_rows[i] * in1_row_numel + j] = + in1_data[i * in1_row_numel + j] / + input2_data[in1_rows[i] * in1_row_numel + j]; + break; + } + } +}; + +} // namespace scatter } // namespace math } // namespace operators } // namespace paddle diff --git a/paddle/operators/math/selected_rows_functor.cu b/paddle/operators/math/selected_rows_functor.cu index 9fddd97a36f7fdb6628d6eeb192cb216fdae3e5b..0ee456f9bc61436bd0f2f8ef20dd1654e7e56d56 100644 --- a/paddle/operators/math/selected_rows_functor.cu +++ b/paddle/operators/math/selected_rows_functor.cu @@ -12,6 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ +#include + #include "paddle/operators/math/math_function.h" #include "paddle/operators/math/selected_rows_functor.h" #include "paddle/platform/cuda_helper.h" @@ -222,6 +224,157 @@ template struct SelectedRowsAddToTensor; template struct SelectedRowsAddToTensor; template struct SelectedRowsAddToTensor; template struct SelectedRowsAddToTensor; + +namespace scatter { + +template +__global__ void MergeAddKernel(const T* input, const int64_t* input_rows, + T* out, const int64_t* out_rows, + size_t out_rows_size, int64_t row_numel) { + const int ty = blockIdx.y; + int tid = threadIdx.x; + __shared__ size_t out_idx; + + if (tid == 0) { + for (size_t i = 0; i < out_rows_size; i++) { + if (input_rows[ty] == out_rows[i]) { + out_idx = i; + } + } + } + + __syncthreads(); + + input += ty * row_numel; + out += out_idx * row_numel; + for (int index = tid; index < row_numel; index += block_size) { + paddle::platform::CudaAtomicAdd(out + index, input[index]); + } +} + +template +struct MergeAdd { + framework::SelectedRows operator()(const platform::CUDADeviceContext& context, + const framework::SelectedRows& input) { + framework::SelectedRows out; + auto input_rows = input.rows(); + std::set row_set(input_rows.begin(), input_rows.end()); + std::vector merge_rows(row_set.begin(), row_set.end()); + + auto input_width = input.value().dims()[1]; + + out.set_rows(merge_rows); + out.set_height(input.height()); + out.mutable_value()->mutable_data( + framework::make_ddim( + {static_cast(merge_rows.size()), input_width}), + context.GetPlace()); + + math::SetConstant constant_functor; + constant_functor(context, out.mutable_value(), 0.0); + + auto* out_data = out.mutable_value()->data(); + auto* input_data = input.value().data(); + + const int block_size = 256; + dim3 threads(block_size, 1); + dim3 grid1(1, input_rows.size()); + + MergeAddKernel< + T, 256><<(context) + .stream()>>>(input_data, input.rows().data(), out_data, + out.rows().data(), out.rows().size(), + input_width); + return out; + } +}; + +template struct MergeAdd; +template struct MergeAdd; +template struct MergeAdd; +template struct MergeAdd; + +template +__global__ void UpdateToTensorKernel(const T* selected_rows, + const int64_t* rows, const ScatterOps& op, + T* tensor_out, int64_t row_numel) { + const int ty = blockIdx.y; + int tid = threadIdx.x; + + selected_rows += ty * row_numel; + tensor_out += rows[ty] * row_numel; + // FIXME(typhoonzero): use macro fix the below messy code. 
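+  // Same operation table as the CPU UpdateToTensor functor, but parallel:
+  // blockIdx.y selects the SelectedRows row, and the threads of the block
+  // stride over its row_numel elements (index += block_size).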
+ switch (op) { + case ScatterOps::ASSIGN: + for (int index = tid; index < row_numel; index += block_size) { + tensor_out[index] = selected_rows[index]; + } + break; + case ScatterOps::ADD: + for (int index = tid; index < row_numel; index += block_size) { + tensor_out[index] += selected_rows[index]; + } + break; + case ScatterOps::SUB: + for (int index = tid; index < row_numel; index += block_size) { + tensor_out[index] -= selected_rows[index]; + } + break; + case ScatterOps::SUBBY: + for (int index = tid; index < row_numel; index += block_size) { + tensor_out[index] = selected_rows[index] - tensor_out[index]; + } + break; + case ScatterOps::MUL: + for (int index = tid; index < row_numel; index += block_size) { + tensor_out[index] *= selected_rows[index]; + } + break; + case ScatterOps::DIV: + for (int index = tid; index < row_numel; index += block_size) { + tensor_out[index] /= selected_rows[index]; + } + break; + case ScatterOps::DIVBY: + for (int index = tid; index < row_numel; index += block_size) { + tensor_out[index] = selected_rows[index] / tensor_out[index]; + } + break; + } +} + +template +struct UpdateToTensor { + void operator()(const platform::CUDADeviceContext& context, + const ScatterOps& op, const framework::SelectedRows& input1, + framework::Tensor* input2) { + // NOTE: Use SelectedRowsAddToTensor for better performance + // no additional MergeAdd called. + MergeAdd merge_func; + auto merged_in1 = merge_func(context, input1); + + auto in1_height = merged_in1.height(); + auto in2_dims = input2->dims(); + PADDLE_ENFORCE_EQ(in1_height, in2_dims[0]); + + auto& in1_value = merged_in1.value(); + auto& in1_rows = merged_in1.rows(); + + int64_t in1_row_numel = in1_value.numel() / in1_rows.size(); + PADDLE_ENFORCE_EQ(in1_row_numel, input2->numel() / in1_height); + + auto* in1_data = in1_value.template data(); + auto* in2_data = input2->data(); + + dim3 threads(platform::PADDLE_CUDA_NUM_THREADS, 1); + dim3 grid(1, in1_rows.size()); + UpdateToTensorKernel<<< + grid, threads, 0, context.stream()>>>(in1_data, in1_rows.data(), op, + in2_data, in1_row_numel); + } +}; +} // namespace scatter } // namespace math } // namespace operators } // namespace paddle diff --git a/paddle/operators/math/selected_rows_functor.h b/paddle/operators/math/selected_rows_functor.h index 1149075abf16547a120ac8928c45b4972409fc72..09d4631905f90f78772368ad71b11826877bdc34 100644 --- a/paddle/operators/math/selected_rows_functor.h +++ b/paddle/operators/math/selected_rows_functor.h @@ -12,9 +12,14 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #pragma once +#include "paddle/framework/eigen.h" #include "paddle/framework/selected_rows.h" #include "paddle/platform/device_context.h" +#define INLINE_FOR2(sizei, sizej) \ + for (int64_t i = 0; i < sizei; i++) \ + for (int64_t j = 0; j < sizej; j++) + namespace paddle { namespace operators { namespace math { @@ -52,6 +57,78 @@ struct SelectedRowsAddToTensor { framework::Tensor* input2); }; +namespace scatter { +// functors for manuplating SelectedRows data +template +struct MergeAdd { + // unary functor, merge by adding duplicated rows in + // the input SelectedRows object. 
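+  // e.g. an input with rows {4, 0, 4} and row values {v0, v1, v2} merges
+  // into rows {0, 4} with row values {v1, v0 + v2}.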
+ framework::SelectedRows operator()(const DeviceContext& context, + const framework::SelectedRows& input); +}; + +template +struct Add { + framework::SelectedRows operator()(const DeviceContext& context, + const framework::SelectedRows& input1, + const framework::SelectedRows& input2) { + framework::SelectedRows out; + out.set_rows(input1.rows()); + out.set_height(input1.height()); + out.mutable_value()->mutable_data(input1.value().dims(), + context.GetPlace()); + auto e_out = framework::EigenVector::Flatten(*(out.mutable_value())); + auto e_in1 = framework::EigenVector::Flatten(input1.value()); + auto e_in2 = framework::EigenVector::Flatten(input2.value()); + e_out.device(*context.eigen_device()) = e_in1 + e_in2; + return out; + } +}; + +template +struct Mul { + // multiply two SelectedRows + framework::SelectedRows operator()(const DeviceContext& context, + const framework::SelectedRows& input1, + const framework::SelectedRows& input2) { + framework::SelectedRows out; + out.set_rows(input1.rows()); + out.set_height(input1.height()); + out.mutable_value()->mutable_data(input1.value().dims(), + context.GetPlace()); + auto e_out = framework::EigenVector::Flatten(*(out.mutable_value())); + auto e_in1 = framework::EigenVector::Flatten(input1.value()); + auto e_in2 = framework::EigenVector::Flatten(input2.value()); + e_out.device(*context.eigen_device()) = e_in1 * e_in2; + return out; + } + // multiply scalar to SelectedRows + framework::SelectedRows operator()(const DeviceContext& context, + const framework::SelectedRows& input1, + const T input2) { + framework::SelectedRows out; + out.set_rows(input1.rows()); + out.set_height(input1.height()); + out.mutable_value()->mutable_data(input1.value().dims(), + context.GetPlace()); + auto e_out = framework::EigenVector::Flatten(*(out.mutable_value())); + auto e_in1 = framework::EigenVector::Flatten(input1.value()); + e_out.device(*context.eigen_device()) = input2 * e_in1; + return out; + } +}; + +enum class ScatterOps { ASSIGN, ADD, SUB, SUBBY, MUL, DIV, DIVBY }; + +// out = seleted_rows_in / tensor +template +struct UpdateToTensor { + void operator()(const DeviceContext& context, const ScatterOps& op, + const framework::SelectedRows& input1, + framework::Tensor* input2); +}; + +} // namespace scatter } // namespace math } // namespace operators } // namespace paddle diff --git a/paddle/operators/shrink_rnn_memory_op.cc b/paddle/operators/shrink_rnn_memory_op.cc index e5ef0740b6f385de7f17a3a419000cb8c897d986..b37269b471b4d71b42c41641fd14c7a64d2719d6 100644 --- a/paddle/operators/shrink_rnn_memory_op.cc +++ b/paddle/operators/shrink_rnn_memory_op.cc @@ -116,9 +116,9 @@ class ShrinkRNNMemoryGradOp : public ArrayOp { auto height = dout_tensor.dims()[0]; auto slice = dx_tensor.Slice(0, static_cast(height)); framework::CopyFrom(dout_tensor, dout_tensor.place(), dev_ctx, &slice); - if (dx_tensor.dims()[0] < height) { + if (dx_tensor.dims()[0] > height) { auto rest_tensor = dx_tensor.Slice( - static_cast(height), static_cast(dout_tensor.dims()[0])); + static_cast(height), static_cast(dx_tensor.dims()[0])); math::set_constant(dev_ctx, &rest_tensor, 0.0f); } } diff --git a/paddle/operators/sum_op.h b/paddle/operators/sum_op.h index eaa36aa1aea53e0b37ef6c578d8bb1cda230ded0..552b48f608b7e0248f03dbea940a83f112a67712 100644 --- a/paddle/operators/sum_op.h +++ b/paddle/operators/sum_op.h @@ -37,11 +37,11 @@ class SumKernel : public framework::OpKernel { bool in_place = out_var == in_vars[0]; if (out_var->IsType()) { - auto *out = context.Output("Out"); - 
out->mutable_data(context.GetPlace()); - + auto *out = context.Output("Out"); + if (!in_place) { + out->mutable_data(context.GetPlace()); + } auto result = EigenVector::Flatten(*out); - if (!in_place) { math::SetConstant constant_functor; constant_functor(context.template device_context(), out, diff --git a/paddle/operators/tensor_array_read_write_op.cc b/paddle/operators/tensor_array_read_write_op.cc index 53e38ec70336ca7f2d7c142e5fb1bbe427ab2957..d5ff3e3fce29b1a888b2cd4d307c2655669e3e4c 100644 --- a/paddle/operators/tensor_array_read_write_op.cc +++ b/paddle/operators/tensor_array_read_write_op.cc @@ -130,9 +130,9 @@ class ReadFromArrayOp : public ArrayOp { auto &x_array = x->Get(); auto *out = scope.FindVar(Output("Out")); PADDLE_ENFORCE(out != nullptr, "Out must be set"); - auto *out_tensor = out->GetMutable(); size_t offset = GetOffset(scope, place); if (offset < x_array.size()) { + auto *out_tensor = out->GetMutable(); platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance(); auto &dev_ctx = *pool.Get(place); diff --git a/paddle/pybind/tensor_py.h b/paddle/pybind/tensor_py.h index 4d5e73e2c266b301de4f19e09be7ab4009c936d3..6b4290972bade585d1a0c2ae919a2e712bdf308c 100644 --- a/paddle/pybind/tensor_py.h +++ b/paddle/pybind/tensor_py.h @@ -77,10 +77,10 @@ struct CastToPyBufferImpl { } else if (paddle::platform::is_cpu_place(tensor.place())) { dst_tensor = tensor; } - return py::buffer_info( - dst_tensor.mutable_data(dst_tensor.place()), - sizeof(CUR_TYPE), py::format_descriptor::format(), - (size_t)framework::arity(dst_tensor.dims()), dims_outside, strides); + return py::buffer_info(dst_tensor.data(), sizeof(CUR_TYPE), + py::format_descriptor::format(), + (size_t)framework::arity(dst_tensor.dims()), + dims_outside, strides); } else { constexpr bool less = I + 1 < std::tuple_size>::value; return CastToPyBufferImpl()(tensor); diff --git a/python/paddle/v2/fluid/backward.py b/python/paddle/v2/fluid/backward.py index 6966cc75804b6b5a49ceb45a26994c23d2936bdb..f11c83f59c930784ca355acc3193c3c352db10e5 100644 --- a/python/paddle/v2/fluid/backward.py +++ b/python/paddle/v2/fluid/backward.py @@ -5,14 +5,17 @@ import collections __all__ = ['append_backward'] -def _rename_arg_(op_desc_list, old_name, new_name, begin_idx=None, - end_idx=None): +def _rename_arg_(op_descs, old_name, new_name, begin_idx=None, end_idx=None): + """ + Traverse all ops in op_descs[begin_idx : end_idx], + if any op has inputs/outputs named "old_name", rename it as 'new_name' + """ if begin_idx is None: begin_idx = 0 if end_idx is None: - end_idx = len(op_desc_list) + end_idx = len(op_descs) for i in range(begin_idx, end_idx): - op_desc = op_desc_list[i] + op_desc = op_descs[i] if isinstance(op_desc, tuple): op_desc = op_desc[0] op_desc.rename_input(old_name, new_name) @@ -20,6 +23,9 @@ def _rename_arg_(op_desc_list, old_name, new_name, begin_idx=None, def _create_op_desc_(op_type, inputs, outputs, attrs): + """ + Create a C++ OpDesc object with specified inputs, outputs and attributes. 
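+    e.g. the call below (hypothetical arguments) builds the desc of a sum
+    op that adds x0 and x1 into out:
+        _create_op_desc_("sum", {"X": ["x0", "x1"]}, {"Out": ["out"]}, {})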
+ """ op_desc = core.OpDesc() op_desc.set_type(op_type) for para, args in inputs.iteritems(): @@ -34,9 +40,12 @@ def _create_op_desc_(op_type, inputs, outputs, attrs): return op_desc -def _infer_var_data_type_(var_name, block): - grad_var = block.desc.find_var(var_name.encode("ascii")) - fwd_name = _strip_grad_suffix_(var_name.encode("ascii")) +def _infer_var_data_type_(grad_var_name, block): + """ + Infer the data type of given grad variable + """ + grad_var = block.desc.find_var(grad_var_name.encode("ascii")) + fwd_name = _strip_grad_suffix_(grad_var_name.encode("ascii")) if block.desc.has_var_recursive(fwd_name): fwd_var = block.desc.find_var_recursive(fwd_name.encode("ascii")) grad_var.set_dtype(fwd_var.dtype()) @@ -45,6 +54,9 @@ def _infer_var_data_type_(var_name, block): def _all_in_set_(cands, s): + """ + Test if all elements of 'cands' are in set 's' + """ for c in cands: if not c in s: return False @@ -52,18 +64,29 @@ def _all_in_set_(cands, s): def _strip_grad_suffix_(name): + """ + Strip the grad suffix from the given varibale name + e.g. x@GRAD ==> x + y@GRAD@RENAME@1 ==> y + """ pos = name.find(core.grad_var_suffix()) return name[:pos] if pos != -1 else name def _append_grad_suffix_(name): + """ + Append grad suffix to the given variable name + e.g. x ==> x@GRAD + """ return name + core.grad_var_suffix() def _addup_repetitive_outputs_(op_descs): - # In backward part, an variable my be the output of more than one ops. - # In this case, the variable should be the accumulation of all the outputs. - # We adopt adding `sum_op`s to implement the accumulate. + """ + In backward part, an variable may be the output of more than one ops. + In this case, the variable should be the accumulation of all the outputs. + `sum_op`s are added to implement the accumulate. + """ pending_sum_ops = [] var_rename_count = collections.defaultdict(int) renamed_vars = collections.defaultdict(list) @@ -109,6 +132,12 @@ def _addup_repetitive_outputs_(op_descs): def _remove_no_grad_branch_(op_descs, no_grad_set): + """ + Remove unnecessary grad ops + A grad op can be removed in two cases: + 1. all outputs of the grad op are in 'no_grad_set' + 2. (TODO) all grad inputs of the grad op are in 'no_grad_set' + """ # Remove ops whose outputs are all in no_grad_dict op_descs = filter( lambda op_desc: not _all_in_set_(op_desc.output_arg_names(), no_grad_set), @@ -133,6 +162,21 @@ def _append_backward_ops_(target, no_grad_dict, grad_to_var, callback=None): + """ + Create all grad ops, and insert them into given block + + Args: + target(Variable): the target variable of forward pass + block(Block): the block where forward ops are + target_block(Block): the block which is going to hold new generated grad ops + no_grad_dict(dict): + key(int) block index + val(set) a set of varibale names. 
These varibales have no gradient + grad_to_var(dict)(output argument): + key(str): grad variable name + val(str): corresponding forward variable name + """ + # grad_op_descs holds created grad_op, and will be appended to target_block grad_op_descs = [] program = block.program for op in reversed(block.ops): @@ -145,6 +189,7 @@ def _append_backward_ops_(target, no_grad_dict, grad_to_var, callback) grad_sub_block_list.append(grad_sub_block.desc) + # Getting op's corresponding grad_op grad_op_desc, op_grad_to_var = core.get_grad_op_desc( op.desc, no_grad_dict[block.idx], grad_sub_block_list) grad_op_descs.extend(grad_op_desc) @@ -170,6 +215,20 @@ def _append_backward_ops_(target, def _append_backward_vars_(block, start_op_idx, grad_to_var, grad_info_map): + """ + Create new variables required by backward pass. + + Args: + block(Block): the block where new variables will be created + start_op_idx(int): Only variables required by ops in block.ops[start_op_idx : ] will be created + grad_to_var(dict): + key(str): grad variable name + val(str): corresponding forward variable name + In most cases, this dict is generated by _append_backward_ops_() + grad_info_map(dict)(output argument): + key(str): forward variable name + val(tuple): a tuple of (str, int), str is the corresponding grad name, int is the block index + """ for op_idx in range(start_op_idx, block.desc.op_size()): op_desc = block.desc.op(op_idx) if op_desc.has_attr("sub_block"): @@ -197,18 +256,18 @@ def _append_backward_vars_(block, start_op_idx, grad_to_var, grad_info_map): def append_backward(loss, parameter_list=None, no_grad_set=None): """ - Create and add gradient Operators in BlockDesc to compute - gradients of `loss` for parameters in parameter_list - - :param loss: an variable generated by cost function. - :type loss: Variable - :param no_grad_dict: variable that should not create gradient - :type no_grad_dict: set - :param parameter_list: parameters that need to compute gradient and - update to optimize the lost. - :type: list - :return: list of (parameters, gradients) pair. - :rtype: list[Variable] + Append backward part to main_program + + Args: + loss(Variable): The variable generated by cost function. + parameter_list(list): Parameters that need to be updated by optimizer. + If None, it means all parameters need to be updated. + no_grad_set(set): Variables that have no gradients in Block 0. + If None, the set will be generated inside the function and + contains all variables with `step_gradient=True` from all blocks. + + Return: + (list[Variable]): list of (parameters, gradients) pair. """ assert isinstance(loss, framework.Variable) diff --git a/python/paddle/v2/fluid/executor.py b/python/paddle/v2/fluid/executor.py index 2c91afb363bf72f2791e60c6df0d9130ccd698c5..1d6c594b41a2c295e3818fb119362d1daba1de33 100644 --- a/python/paddle/v2/fluid/executor.py +++ b/python/paddle/v2/fluid/executor.py @@ -1,12 +1,31 @@ import numpy as np +import contextlib +from framework import Program, default_main_program from . 
import core -from framework import Program, default_main_program, Parameter, Variable -__all__ = ['Executor', 'g_scope'] +__all__ = ['Executor', 'global_scope', 'scope_guard', 'switch_scope'] g_scope = core.Scope() +def global_scope(): + return g_scope + + +def switch_scope(scope): + global g_scope + ex = g_scope + g_scope = scope + return ex + + +@contextlib.contextmanager +def scope_guard(scope): + ex = switch_scope(scope) + yield + switch_scope(ex) + + def as_numpy(tensor): if isinstance(tensor, list): return [as_numpy(t) for t in tensor] @@ -117,7 +136,7 @@ class Executor(object): raise TypeError() if scope is None: - scope = g_scope + scope = global_scope() program = program.clone() global_block = program.global_block() diff --git a/python/paddle/v2/fluid/tests/book/test_label_semantic_roles.py b/python/paddle/v2/fluid/tests/book/test_label_semantic_roles.py index c3591a613acafb268a5bd70618cd4555450bac29..8acd470c5ed5fa8eeda396f1e9182db4ecdd7016 100644 --- a/python/paddle/v2/fluid/tests/book/test_label_semantic_roles.py +++ b/python/paddle/v2/fluid/tests/book/test_label_semantic_roles.py @@ -170,7 +170,7 @@ def main(): exe.run(fluid.default_startup_program()) - embedding_param = fluid.g_scope.find_var(embedding_name).get_tensor() + embedding_param = fluid.global_scope().find_var(embedding_name).get_tensor() embedding_param.set( load_parameter(conll05.get_embedding(), word_dict_len, word_dim), place) diff --git a/python/paddle/v2/fluid/tests/decorators.py b/python/paddle/v2/fluid/tests/decorators.py new file mode 100644 index 0000000000000000000000000000000000000000..154619b0e93455922700a12d734967b4d20c4f13 --- /dev/null +++ b/python/paddle/v2/fluid/tests/decorators.py @@ -0,0 +1,29 @@ +import paddle.v2.fluid as fluid + +__all__ = ['many_times', 'prog_scope'] + + +def many_times(times): + def __impl__(fn): + def __fn__(*args, **kwargs): + for _ in range(times): + fn(*args, **kwargs) + + return __fn__ + + return __impl__ + + +def prog_scope(): + def __impl__(fn): + def __fn__(*args, **kwargs): + prog = fluid.Program() + startup_prog = fluid.Program() + scope = fluid.core.Scope() + with fluid.scope_guard(scope): + with fluid.program_guard(prog, startup_prog): + fn(*args, **kwargs) + + return __fn__ + + return __impl__ diff --git a/python/paddle/v2/fluid/tests/test_adam_op.py b/python/paddle/v2/fluid/tests/test_adam_op.py index a0d6655d4cbcff8ed3d55df0f4e68fc6591fbb11..7dbc2fa0858a68c5da9e8d48dcb187494357e940 100644 --- a/python/paddle/v2/fluid/tests/test_adam_op.py +++ b/python/paddle/v2/fluid/tests/test_adam_op.py @@ -1,6 +1,8 @@ import unittest import numpy as np from op_test import OpTest +from paddle.v2.fluid import core +from paddle.v2.fluid.op import Operator class TestAdamOp1(OpTest): @@ -176,5 +178,124 @@ def adam_step(inputs, attributes): return param_out, moment1_out, moment2_out +def adam_step_sparse(inputs, attributes, height, rows, row_numel, np_grad): + ''' + Simulate one step of the adam optimizer + :param inputs: dict of inputs + :param attributes: dict of attributes + :return tuple: tuple of output param, moment1, moment2, + beta1 power accumulator and beta2 power accumulator + ''' + param = inputs['Param'] + # grad = inputs['Grad'] + moment1 = inputs['Moment1'] + moment2 = inputs['Moment2'] + lr = inputs['LearningRate'] + beta1_pow = inputs['Beta1Pow'] + beta2_pow = inputs['Beta2Pow'] + + beta1 = attributes['beta1'] + beta2 = attributes['beta2'] + epsilon = attributes['epsilon'] + + moment1_out = np.zeros(shape=[height, row_numel]) + moment2_out = 
np.zeros(shape=[height, row_numel]) + param_out = np.zeros(shape=[height, row_numel]) + + for idx, row_id in enumerate(rows): + moment1_out[row_id] = beta1 * moment1[row_id] + (1 - beta1 + ) * np_grad[idx] + moment2_out[row_id] = beta2 * moment2[row_id] + ( + 1 - beta2) * np.square(np_grad[idx]) + lr_t = lr * np.sqrt(1 - beta2_pow) / (1 - beta1_pow) + param_out[row_id] = param[row_id] - lr_t * (moment1_out[row_id] / ( + np.sqrt(moment2_out[row_id]) + epsilon)) + return param_out, moment1_out, moment2_out + + +class TestSparseAdamOp(unittest.TestCase): + def setup(self, scope, place): + beta1 = 0.78 + beta2 = 0.836 + epsilon = 1e-4 + + height = 10 + rows = [0, 4, 7] + self.rows = rows + row_numel = 12 + self.row_numel = row_numel + self.dense_inputs = { + "Param": np.full((height, row_numel), 5.0).astype("float32"), + "Moment1": np.full((height, row_numel), 5.0).astype("float32"), + "Moment2": np.full((height, row_numel), 5.0).astype("float32"), + 'Beta1Pow': np.array([beta1**10]).astype("float32"), + 'Beta2Pow': np.array([beta2**10]).astype("float32"), + "LearningRate": np.full((1), 2.0).astype("float32") + } + self.attrs = {'epsilon': epsilon, 'beta1': beta1, 'beta2': beta2} + + grad_selected_rows = scope.var('Grad').get_selected_rows() + grad_selected_rows.set_height(height) + grad_selected_rows.set_rows(rows) + np_array = np.ones((len(rows), row_numel)).astype("float32") + np_array[0, 0] = 2.0 + np_array[2, 8] = 4.0 + + grad_tensor = grad_selected_rows.get_tensor() + grad_tensor.set(np_array, place) + + self.sparse_inputs = ["Grad"] + + param_out, mom1, mom2 = adam_step_sparse( + self.dense_inputs, self.attrs, height, rows, row_numel, np_array) + self.outputs = { + "ParamOut": param_out, + "Moment1Out": mom1, + "Moment2Out": mom2 + } + + def check_with_place(self, place): + scope = core.Scope() + self.setup(scope, place) + + op_args = dict() + for key, np_array in self.dense_inputs.iteritems(): + var = scope.var(key).get_tensor() + var.set(np_array, place) + op_args[key] = key + for s in self.sparse_inputs: + op_args[s] = s + for s in self.outputs: + var = scope.var(s).get_tensor() + var.set(self.outputs[s], place) + op_args[s] = s + for k in self.attrs: + op_args[k] = self.attrs[k] + + # create and run sgd operator + adam_op = Operator("adam", **op_args) + adam_op.run(scope, place) + + for key, np_array in self.outputs.iteritems(): + out_var = scope.var(key).get_tensor() + actual = np.array(out_var) + actual = actual.reshape([actual.size]) + np_array = np_array.reshape([np_array.size]) + for idx, row_id in enumerate(self.rows): + j = 0 + while j < self.row_numel: + pos = row_id * self.row_numel + j + self.assertLess((actual[pos] - np_array[pos]) / actual[pos], + 0.00001) + j += 1 + + def test_sparse_sgd(self): + places = [core.CPUPlace()] + if core.is_compile_gpu(): + places.append(core.CUDAPlace(0)) + for place in places: + self.check_with_place(place) + + if __name__ == "__main__": unittest.main() diff --git a/python/paddle/v2/fluid/tests/test_dynrnn_gradient_check.py b/python/paddle/v2/fluid/tests/test_dynrnn_gradient_check.py new file mode 100644 index 0000000000000000000000000000000000000000..c02c59284e1ca2e28ba2f6c5ec13b241c15fc288 --- /dev/null +++ b/python/paddle/v2/fluid/tests/test_dynrnn_gradient_check.py @@ -0,0 +1,347 @@ +import numpy +import random +import collections +import paddle.v2.fluid as fluid +import unittest +from decorators import * + + +class Memory(object): + def __init__(self, shape, dtype='float32'): + self.ex = numpy.zeros(shape=shape, dtype=dtype) + 
self.cur = None + + def update(self, val): + assert val.shape == self.ex.shape + assert val.dtype == self.ex.dtype + self.cur = val + + def ex(self): + return self.ex + + def next(self): + self.ex = self.cur + self.cur = None + + def __next__(self): + self.next() + + def reset(self): + self.ex = numpy.zeros(shape=self.ex.shape, dtype=self.ex.dtype) + self.cur = None + + +class Output(object): + def __init__(self): + self.outs = [] + + def next_sequence(self): + self.outs.append([]) + + def out(self, val): + self.outs[-1].append(val) + + def last(self): + return self.outs[-1][-1] + + +class BaseRNN(object): + def __init__(self, ins, mems, params, outs, num_seq=5, max_seq_len=15): + self.num_seq = num_seq + self.inputs = collections.defaultdict(list) + + for _ in xrange(num_seq): + seq_len = random.randint(1, max_seq_len - 1) + for iname in ins: + ishape = ins[iname].get('shape', None) + idtype = ins[iname].get('dtype', 'float32') + lst = [] + for _ in xrange(seq_len): + lst.append(numpy.random.random(size=ishape).astype(idtype)) + self.inputs[iname].append(lst) + + self.mems = dict() + for mname in mems: + mshape = mems[mname].get('shape', None) + mdtype = mems[mname].get('dtype', 'float32') + self.mems[mname] = Memory(shape=mshape, dtype=mdtype) + + self.params = dict() + for pname in params: + pshape = params[pname].get('shape', None) + pdtype = params[pname].get('dtype', 'float32') + self.params[pname] = numpy.random.random(size=pshape).astype(pdtype) + + self.outputs = dict() + + for oname in outs: + self.outputs[oname] = Output() + + def step(self, **kwargs): + raise NotImplementedError() + + def exe(self): + retv = dict() + for out in self.outputs: + retv[out] = [] + + for seq_id in xrange(self.num_seq): + for mname in self.mems: + self.mems[mname].reset() + for out in self.outputs: + self.outputs[out].next_sequence() + + iname0 = self.inputs.keys()[0] + seq_len = len(self.inputs[iname0][seq_id]) + + for step_id in xrange(seq_len): + xargs = dict() + + for iname in self.inputs: + xargs[iname] = self.inputs[iname][seq_id][step_id] + + for mname in self.mems: + xargs[mname] = self.mems[mname] + + for pname in self.params: + xargs[pname] = self.params[pname] + + for out in self.outputs: + xargs[out] = self.outputs[out] + + self.step(**xargs) + + for mname in self.mems: + next(self.mems[mname]) + + for out in self.outputs: + retv[out].append(self.outputs[out].last()) + + for out in retv: + retv[out] = numpy.array(retv[out]) + return retv + + def to_feed(self, place): + feed_dict = dict() + + for iname in self.inputs: + lod = [0] + np_flatten = [] + for seq_id in xrange(len(self.inputs[iname])): + seq_len = len(self.inputs[iname][seq_id]) + lod.append(lod[-1] + seq_len) + np_flatten.extend(self.inputs[iname][seq_id]) + + t = fluid.Tensor() + t.set(numpy.array(np_flatten), place) + t.set_lod([lod]) + feed_dict[iname] = t + + for pname in self.params: + feed_dict[pname] = self.params[pname] + return feed_dict + + def get_numeric_gradient_of_param(self, param_name, delta=0.001): + p = self.params[param_name] + if len(p.shape) != 2: + raise ValueError("Not support get numeric gradient of an parameter," + " which is not matrix") + g = numpy.zeros(shape=p.shape, dtype=p.dtype) + + for i in xrange(p.shape[0]): + for j in xrange(p.shape[1]): + o = p[i][j] + p[i][j] += delta + pos = self._exe_mean_out_() + p[i][j] -= 2 * delta + neg = self._exe_mean_out_() + p[i][j] = o + g[i][j] = (pos - neg) / (delta * 2) + return g + + def get_numeric_gradient_of_input(self, + input_name, + delta=0.001, + 
return_one_tensor=True): + ipt = self.inputs[input_name] + grad = [] + + for seq in ipt: + seq_grad = [] + for item in seq: + item_grad = numpy.zeros(shape=item.shape, dtype=item.dtype) + if len(item.shape) != 1: + raise ValueError("Not support") + + for i in xrange(len(item)): + o = item[i] + item[i] += delta + pos = self._exe_mean_out_() + item[i] -= 2 * delta + neg = self._exe_mean_out_() + item[i] = o + item_grad[i] = (pos - neg) / (delta * 2) + seq_grad.append(item_grad) + grad.append(seq_grad) + + if not return_one_tensor: + return grad + + for i in xrange(len(grad)): + grad[i] = numpy.concatenate(grad[i]) + grad = numpy.concatenate(grad) + return grad + + def _exe_mean_out_(self): + outs = self.exe() + return numpy.array([o.mean() for o in outs.itervalues()]).mean() + + +class TestSimpleMul(unittest.TestCase): + DATA_NAME = 'X' + DATA_WIDTH = 32 + PARAM_NAME = 'W' + HIDDEN_WIDTH = 10 + OUT_NAME = 'Out' + + class SimpleMul(BaseRNN): + def __init__(self): + base = TestSimpleMul + super(base.SimpleMul, self).__init__({ + base.DATA_NAME: { + 'shape': [base.DATA_WIDTH] + } + }, {}, { + base.PARAM_NAME: { + 'shape': [base.DATA_WIDTH, base.HIDDEN_WIDTH] + } + }, [base.OUT_NAME]) + + def step(self, X, W, Out): + Out.out(numpy.matmul(X, W)) + + # Test many times in local to ensure the random seed cannot breaks CI + # @many_times(10) + @prog_scope() + def test_forward_backward(self): + py_rnn = TestSimpleMul.SimpleMul() + dat = fluid.layers.data( + name=self.DATA_NAME, shape=[self.DATA_WIDTH], lod_level=1) + dat.stop_gradient = False + + rnn = fluid.layers.DynamicRNN() + with rnn.block(): + d = rnn.step_input(dat) + o = fluid.layers.fc(input=d, + param_attr=self.PARAM_NAME, + bias_attr=False, + size=self.HIDDEN_WIDTH, + act=None) + rnn.output(o) + + out = rnn() + out = fluid.layers.sequence_pool(out, pool_type='last') + loss = fluid.layers.mean(x=out) + fluid.backward.append_backward(loss) + + cpu = fluid.CPUPlace() + exe = fluid.Executor(cpu) + out, w_g, i_g = map(numpy.array, + exe.run(feed=py_rnn.to_feed(cpu), + fetch_list=[ + out, self.PARAM_NAME + "@GRAD", + self.DATA_NAME + "@GRAD" + ], + return_numpy=False)) + out_by_python = py_rnn.exe()[self.OUT_NAME] + self.assertTrue(numpy.allclose(out, out_by_python)) + w_g_num = py_rnn.get_numeric_gradient_of_param(self.PARAM_NAME) + self.assertTrue(numpy.allclose(w_g_num, w_g, rtol=0.05)) + i_g_num = py_rnn.get_numeric_gradient_of_input( + input_name=self.DATA_NAME) + i_g_num = i_g_num.reshape(i_g.shape) + self.assertTrue(numpy.allclose(i_g_num, i_g, rtol=0.05)) + + +class TestSimpleMulWithMemory(unittest.TestCase): + DATA_WIDTH = 32 + HIDDEN_WIDTH = 20 + DATA_NAME = 'X' + PARAM_NAME = 'W' + + class SimpleMulWithMemory(BaseRNN): + def __init__(self): + super(TestSimpleMulWithMemory.SimpleMulWithMemory, self).__init__({ + TestSimpleMulWithMemory.DATA_NAME: { + 'shape': [TestSimpleMulWithMemory.DATA_WIDTH] + } + }, {'Mem': { + 'shape': [TestSimpleMulWithMemory.HIDDEN_WIDTH] + }}, { + TestSimpleMulWithMemory.PARAM_NAME: { + 'shape': [ + TestSimpleMulWithMemory.DATA_WIDTH, + TestSimpleMulWithMemory.HIDDEN_WIDTH + ] + } + }, ['Out']) + + def step(self, X, Mem, W, Out): + o = numpy.matmul(X, W) + assert isinstance(Mem, Memory) + o += Mem.ex + Mem.update(o) + assert isinstance(Out, Output) + Out.out(o) + + # many_times used locally for debug. Make sure the calculation is stable. 
+ # @many_times(10) + @prog_scope() + def test_forward_backward(self): + py_rnn = TestSimpleMulWithMemory.SimpleMulWithMemory() + data = fluid.layers.data( + name=self.DATA_NAME, shape=[self.DATA_WIDTH], lod_level=1) + data.stop_gradient = False + rnn = fluid.layers.DynamicRNN() + with rnn.block(): + d = rnn.step_input(data) + mem = rnn.memory(value=0.0, shape=[self.HIDDEN_WIDTH]) + hidden = fluid.layers.fc(input=d, + size=self.HIDDEN_WIDTH, + param_attr=self.PARAM_NAME, + bias_attr=False, + act=None) + o = fluid.layers.elementwise_add(x=hidden, y=mem) + rnn.update_memory(mem, o) + rnn.output(o) + + out = rnn() + last = fluid.layers.sequence_pool(input=out, pool_type='last') + loss = fluid.layers.mean(x=last) + fluid.backward.append_backward(loss) + + cpu = fluid.CPUPlace() + exe = fluid.Executor(cpu) + feed = py_rnn.to_feed(cpu) + last_np, w_g, i_g = map(numpy.array, + exe.run(feed=feed, + fetch_list=[ + last, self.PARAM_NAME + "@GRAD", + self.DATA_NAME + "@GRAD" + ], + return_numpy=False)) + last_by_py, = py_rnn.exe().values() + w_g_num = py_rnn.get_numeric_gradient_of_param(self.PARAM_NAME) + self.assertTrue(numpy.allclose(last_np, last_by_py)) + + self.assertTrue(numpy.allclose(w_g_num, w_g, rtol=0.1)) + i_g_num = py_rnn.get_numeric_gradient_of_input(self.DATA_NAME) + i_g_num = i_g_num.reshape(i_g.shape) + + # Since this RNN has many float add. The number could be not stable. + # rtol = 0.1 + self.assertTrue(numpy.allclose(i_g_num, i_g, rtol=0.1)) + + +if __name__ == '__main__': + unittest.main()