From 4132702362700fec4686e377d4c2707c063c18be Mon Sep 17 00:00:00 2001 From: wangjiang03 Date: Wed, 31 Aug 2016 05:04:46 +0000 Subject: [PATCH] Add RNN Tutorial. ISSUE=4589755 git-svn-id: https://svn.baidu.com/idl/trunk/paddle@1442 1ad973e4-5ce8-4261-8a94-b56d1f490c56 --- demo/seqToseq/seqToseq_net.py | 10 +- doc/demo/index.md | 2 +- doc/demo/new_layer/{index.md => index.rst} | 468 ++++++++++-------- doc/ui/api/rnn/bi_lstm.jpg | 1 + .../rnn/encoder-decoder-attention-model.png | 1 + doc/ui/api/rnn/index.rst | 252 ++++++++++ doc/ui/index.md | 3 +- 7 files changed, 515 insertions(+), 222 deletions(-) rename doc/demo/new_layer/{index.md => index.rst} (51%) create mode 120000 doc/ui/api/rnn/bi_lstm.jpg create mode 120000 doc/ui/api/rnn/encoder-decoder-attention-model.png create mode 100644 doc/ui/api/rnn/index.rst diff --git a/demo/seqToseq/seqToseq_net.py b/demo/seqToseq/seqToseq_net.py index 08183e1f75d..479a64fa00d 100644 --- a/demo/seqToseq/seqToseq_net.py +++ b/demo/seqToseq/seqToseq_net.py @@ -58,7 +58,7 @@ def seq_to_seq_data(data_dir, args = {"src_dict": src_dict, "trg_dict": trg_dict}) - return {"src_dict_path": src_lang_dict, "trg_dict_path": trg_lang_dict, + return {"src_dict_path": src_lang_dict, "trg_dict_path": trg_lang_dict, "gen_result": gen_result} @@ -88,11 +88,11 @@ def gru_encoder_decoder(data_conf, src_embedding = embedding_layer( input=src_word_id, size=word_vector_dim, - param_attr=ParamAttr(name='_source_language_embedding'), ) - src_forward = simple_gru(input=src_embedding, size=encoder_size, ) + param_attr=ParamAttr(name='_source_language_embedding')) + src_forward = simple_gru(input=src_embedding, size=encoder_size) src_backward = simple_gru(input=src_embedding, size=encoder_size, - reverse=True, ) + reverse=True) encoded_vector = concat_layer(input=[src_forward, src_backward]) with mixed_layer(size=decoder_size) as encoded_proj: @@ -147,7 +147,7 @@ def gru_encoder_decoder(data_conf, is_seq=True), StaticInput(input=encoded_proj, is_seq=True), 
trg_embedding - ], ) + ]) lbl = data_layer(name='target_language_next_word', size=target_dict_dim) diff --git a/doc/demo/index.md b/doc/demo/index.md index 5ca0e56dae0..4d0e4554cb4 100644 --- a/doc/demo/index.md +++ b/doc/demo/index.md @@ -21,4 +21,4 @@ There are serveral examples and demos here. * [Embedding: Chinese Word](embedding_model/index.md) ## Customization -* [Writing New Layers](new_layer/index.md) +* [Writing New Layers](new_layer/index.rst) diff --git a/doc/demo/new_layer/index.md b/doc/demo/new_layer/index.rst similarity index 51% rename from doc/demo/new_layer/index.md rename to doc/demo/new_layer/index.rst index 0ab74264da9..0f12e952ad4 100644 --- a/doc/demo/new_layer/index.md +++ b/doc/demo/new_layer/index.rst @@ -1,38 +1,63 @@ Writing New Layers ----------- +======================= + This tutorial will guide you to write customized layers in PaddlePaddle. We will utilize fully connected layer as an example to guide you through the following steps for writing a new layer. + - Derive equations for the forward and backward part of the layer. - Implement C++ class for the layer. - Write gradient check unit test to make sure the gradients are correctly computed. - Implement Python wrapper for the layer. -## Derive Equations +================= +Derive Equations +================= + First we need to derive equations of the *forward* and *backward* part of the layer. The forward part computes the output given an input. The backward part computes the gradients of the input and the parameters given the the gradients of the output. The illustration of a fully connected layer is shown in the following figure. In a fully connected layer, all output nodes are connected to all the input nodes. -
![](./FullyConnected.jpg)
+ +.. image:: ./FullyConnected.jpg The *forward part* of a layer transforms an input into the corresponding output. -Fully connected layer takes a dense input vector with dimension $D_i$. It uses a transformation matrix $W$ with size $D_i \times D_o$ to project x into a $D_o$ dimensional vector, and add a bias vector $b$ with dimension $D_o$ to the vector. -\[y = f(W^T x + b) \] -where $f(.)$ is an nonlinear *activation* function, such as sigmoid, tanh, and Relu. +Fully connected layer takes a dense input vector with dimension :math:`D_i`. It uses a transformation matrix :math:`W` with size :math:`D_i \times D_o` to project x into a :math:`D_o` dimensional vector, and adds a bias vector :math:`b` with dimension :math:`D_o` to the vector. + +.. math:: + + y = f(W^T x + b) + +where :math:`f(.)` is a nonlinear *activation* function, such as sigmoid, tanh, and Relu. + +The transformation matrix :math:`W` and bias vector :math:`b` are the *parameters* of the layer. The *parameters* of a layer are learned during training in the *backward pass*. The backward pass computes the gradients of the output function with respect to all parameters and inputs. The optimizer can use chain rule to compute the gradients of the loss function with respect to each parameter. Suppose our loss function is :math:`c(y)`, then + +.. math:: + + \frac{\partial c(y)}{\partial x} = \frac{\partial c(y)}{\partial y} \frac{\partial y}{\partial x} + +Suppose :math:`z = W^T x + b`, so that :math:`y = f(z)`, then + +.. math:: + + \frac{\partial y}{\partial z} = \frac{\partial f(z)}{\partial z} + + -The transformation matrix $W$ and bias vector $b$ are the *parameters* of the layer. The *parameters* of a layer are learned during training in the *backward pass*. The backward pass computes the gradients of the output function with respect to all parameters and inputs. The optimizer can use chain rule to compute the gradients of the loss function with respect to each parameter. 
Suppose our loss function is $c(y)$, then -\[\frac{\partial c(y)}{\partial x} = \frac{\partial c(y)}{\partial y} \frac{\partial y}{\partial x} \] +This derivative can be automatically computed by our base layer class. -Suppose $z = f(W^T x + b)$, then -\[ \frac{\partial y}{\partial z} = \frac{\partial f(z)}{\partial z}\] - This derivative can be automatically computed by our base layer class. +Then, for fully connected layer, we need to compute :math:`\frac{\partial z}{\partial x}`, and :math:`\frac{\partial z}{\partial W}`, and :math:`\frac{\partial z}{\partial b}` +. -Then, for fully connected layer, we need to compute $\frac{\partial z}{\partial x}$, and $\frac{\partial z}{\partial W}$, and $\frac{\partial z}{\partial b}$. -\[ \frac{\partial z}{\partial x} = W \] -\[ \frac{\partial z_j}{\partial W_{ij}} = x_i \] -\[ \frac{\partial z}{\partial b} = \mathbf 1 \] -where $\mathbf 1$ is an all one vector, $W_{ij}$ is the number at the i-th row and j-th column of the matrix $W$, $z_j$ is the j-th component of the vector $z$, and $x_i$ is the i-th component of the vector $x$. +.. math:: + \frac{\partial z}{\partial x} = W \\ + \frac{\partial z_j}{\partial W_{ij}} = x_i \\ + \frac{\partial z}{\partial b} = \mathbf 1 \\ -Then we can use chain rule to calculate $\frac{\partial z}{\partial x}$, and $\frac{\partial z}{\partial W}$. The details of the computation will be given in the next section. +where :math:`\mathbf 1` is an all one vector, :math:`W_{ij}` is the number at the i-th row and j-th column of the matrix :math:`W`, :math:`z_j` is the j-th component of the vector :math:`z`, and :math:`x_i` is the i-th component of the vector :math:`x`. + +Then we can use chain rule to calculate :math:`\frac{\partial z}{\partial x}`, and :math:`\frac{\partial z}{\partial W}`. The details of the computation will be given in the next section. 
+ +================= +Implement C++ Class +================= -## Implement C++ Class The C++ class of the layer implements the initialization, forward, and backward part of the layer. The fully connected layer is at `paddle/gserver/layers/FullyConnectedLayer.h` and `paddle/gserver/layers/FullyConnectedLayer.cpp`. We list simplified version of the code below. It needs to derive the base class `paddle::BaseLayer`, and it needs to override the following functions: @@ -44,38 +69,37 @@ It needs to derive the base class `paddle::BaseLayer`, and it needs to override - `prefetch`. It is utilized to determine the rows corresponding parameter matrix to prefetch from parameter server. You do not need to override this function if your layer does not need remote sparse update. (most layers do not need to support remote sparse update) -The header file is listed below: +The header file is listed below:: -```C -namespace paddle { -/** - * A layer has full connections to all neurons in the previous layer. - * It computes an inner product with a set of learned weights, and - * (optionally) adds biases. - * - * The config file api is fc_layer. - */ + namespace paddle { + /** + * A layer has full connections to all neurons in the previous layer. + * It computes an inner product with a set of learned weights, and + * (optionally) adds biases. + * + * The config file api is fc_layer. 
+ */ -class FullyConnectedLayer : public Layer { -protected: - WeightList weights_; - std::unique_ptr biases_; + class FullyConnectedLayer : public Layer { + protected: + WeightList weights_; + std::unique_ptr biases_; -public: - explicit FullyConnectedLayer(const LayerConfig& config) - : Layer(config) {} - ~FullyConnectedLayer() {} + public: + explicit FullyConnectedLayer(const LayerConfig& config) + : Layer(config) {} + ~FullyConnectedLayer() {} - bool init(const LayerMap& layerMap, const ParameterMap& parameterMap); + bool init(const LayerMap& layerMap, const ParameterMap& parameterMap); - Weight& getWeight(int idx) { return *weights_[idx]; } + Weight& getWeight(int idx) { return *weights_[idx]; } + + void prefetch(); + void forward(PassType passType); + void backward(const UpdateCallback& callback = nullptr); + }; + } // namespace paddle - void prefetch(); - void forward(PassType passType); - void backward(const UpdateCallback& callback = nullptr); -}; -} // namespace paddle -``` It defines the parameters as class variables. We use `Weight` class as abstraction of parameters. It supports multi-thread update. The details of this class will be described in details in the implementations. - `weights_` is a list of weights for the transformation matrices. The current implementation can have more than one inputs. Thus, it has a list of weights. One weight corresponds to an input. @@ -85,86 +109,84 @@ The fully connected layer does not have layer configuration hyper-parameters. If The following code snippet implements the `init` function. - First, every `init` function must call the `init` function of the base class `Layer::init(layerMap, parameterMap);`. This statement will initialize the required variables and connections for each layer. -- The it initializes all the weights matrices $W$. The current implementation can have more than one inputs. Thus, it has a list of weights. +- The it initializes all the weights matrices :math:`W`. 
The current implementation can have more than one inputs. Thus, it has a list of weights. - Finally, it initializes the bias. -```C -bool FullyConnectedLayer::init(const LayerMap& layerMap, - const ParameterMap& parameterMap) { - /* Initialize the basic parent class */ - Layer::init(layerMap, parameterMap); - - /* initialize the weightList */ - CHECK(inputLayers_.size() == parameters_.size()); - for (size_t i = 0; i < inputLayers_.size(); i++) { - // Option the parameters - size_t height = inputLayers_[i]->getSize(); - size_t width = getSize(); - - // create a new weight - if (parameters_[i]->isSparse()) { - CHECK_LE(parameters_[i]->getSize(), width * height); - } else { - CHECK_EQ(parameters_[i]->getSize(), width * height); - } - Weight* w = new Weight(height, width, parameters_[i]); +The code is listed below:: - // append the new weight to the list - weights_.emplace_back(w); - } + bool FullyConnectedLayer::init(const LayerMap& layerMap, + const ParameterMap& parameterMap) { + /* Initialize the basic parent class */ + Layer::init(layerMap, parameterMap); - /* initialize biases_ */ - if (biasParameter_.get() != NULL) { - biases_ = std::unique_ptr(new Weight(1, getSize(), biasParameter_)); - } + /* initialize the weightList */ + CHECK(inputLayers_.size() == parameters_.size()); + for (size_t i = 0; i < inputLayers_.size(); i++) { + // Option the parameters + size_t height = inputLayers_[i]->getSize(); + size_t width = getSize(); - return true; -} + // create a new weight + if (parameters_[i]->isSparse()) { + CHECK_LE(parameters_[i]->getSize(), width * height); + } else { + CHECK_EQ(parameters_[i]->getSize(), width * height); + } + Weight* w = new Weight(height, width, parameters_[i]); -``` + // append the new weight to the list + weights_.emplace_back(w); + } + + /* initialize biases_ */ + if (biasParameter_.get() != NULL) { + biases_ = std::unique_ptr(new Weight(1, getSize(), biasParameter_)); + } + return true; + } The implementation of the forward part has the 
following steps. - Every layer must call `Layer::forward(passType);` at the beginning of its `forward` function. - Then it allocates memory for the output using `reserveOutput(batchSize, size);`. This step is necessary because we support the batches to have different batch sizes. `reserveOutput` will change the size of the output accordingly. For the sake of efficiency, we will allocate new memory if we want to expand the matrix, but we will reuse the existing memory block if we want to shrink the matrix. -- Then it computes $\sum_i W_i x + b$ using Matrix operations. `getInput(i).value` retrieve the matrix of the i-th input. Each input is a $batchSize \times dim$ matrix, where each row represents an single input in a batch. For a complete lists of supported matrix operations, please refer to `paddle/math/Matrix.h` and `paddle/math/BaseMatrix.h`. +- Then it computes :math:`\sum_i W_i x + b` using Matrix operations. `getInput(i).value` retrieve the matrix of the i-th input. Each input is a :math:`batchSize \times dim` matrix, where each row represents an single input in a batch. For a complete lists of supported matrix operations, please refer to `paddle/math/Matrix.h` and `paddle/math/BaseMatrix.h`. - Finally it applies the activation function using `forwardActivation();`. It will automatically applies the corresponding activation function specifies in the network configuration. +The code is listed below:: -```C -void FullyConnectedLayer::forward(PassType passType) { - Layer::forward(passType); + void FullyConnectedLayer::forward(PassType passType) { + Layer::forward(passType); - /* malloc memory for the output_ if necessary */ - int batchSize = getInput(0).getBatchSize(); - int size = getSize(); + /* malloc memory for the output_ if necessary */ + int batchSize = getInput(0).getBatchSize(); + int size = getSize(); - { - // Settup the size of the output. - reserveOutput(batchSize, size); - } + { + // Settup the size of the output. 
+ reserveOutput(batchSize, size); + } - MatrixPtr outV = getOutputValue(); + MatrixPtr outV = getOutputValue(); - // Apply the the transformation matrix to each input. - for (size_t i = 0; i != inputLayers_.size(); ++i) { - auto input = getInput(i); - CHECK(input.value) << "The input of 'fc' layer must be matrix"; - i == 0 ? outV->mul(input.value, weights_[i]->getW(), 1, 0) - : outV->mul(input.value, weights_[i]->getW(), 1, 1); - } + // Apply the the transformation matrix to each input. + for (size_t i = 0; i != inputLayers_.size(); ++i) { + auto input = getInput(i); + CHECK(input.value) << "The input of 'fc' layer must be matrix"; + i == 0 ? outV->mul(input.value, weights_[i]->getW(), 1, 0) + : outV->mul(input.value, weights_[i]->getW(), 1, 1); + } - /* add the bias-vector */ - if (biases_.get() != NULL) { - outV->addBias(*(biases_->getW()), 1); - } + /* add the bias-vector */ + if (biases_.get() != NULL) { + outV->addBias(*(biases_->getW()), 1); + } + + /* activation */ { + forwardActivation(); + } + } - /* activation */ { - forwardActivation(); - } -} -``` The implementation of the backward part has the following steps. - ` backwardActivation();` computes the gradients of the activation. The gradients will be multiplies in place to the gradients of the output, which can be retrieved using `getOutputGrad()`. @@ -172,88 +194,93 @@ The implementation of the backward part has the following steps. - Then it computes the gradients of the transformation matrices and inputs, and it calls `incUpdate` for the corresponding parameter. 
This gives the framework the chance to know whether it has gathered all the gradient to one parameter so that it can do some overlapping work (e.g., network communication) -```C -void FullyConnectedLayer::backward(const UpdateCallback& callback) { - /* Do derivation for activations.*/ { - backwardActivation(); - } +The code is listed below:: - if (biases_ && biases_->getWGrad()) { - biases_->getWGrad()->collectBias(*getOutputGrad(), 1); + void FullyConnectedLayer::backward(const UpdateCallback& callback) { + /* Do derivation for activations.*/ { + backwardActivation(); + } - /* Increasing the number of gradient */ - biases_->getParameterPtr()->incUpdate(callback); - } + if (biases_ && biases_->getWGrad()) { + biases_->getWGrad()->collectBias(*getOutputGrad(), 1); - bool syncFlag = hl_get_sync_flag(); + /* Increasing the number of gradient */ + biases_->getParameterPtr()->incUpdate(callback); + } - for (size_t i = 0; i != inputLayers_.size(); ++i) { - /* Calculate the W-gradient for the current layer */ - if (weights_[i]->getWGrad()) { - MatrixPtr input_T = getInputValue(i)->getTranspose(); - MatrixPtr oGrad = getOutputGrad(); - { - weights_[i]->getWGrad()->mul(input_T, oGrad, 1, 1); + bool syncFlag = hl_get_sync_flag(); + + for (size_t i = 0; i != inputLayers_.size(); ++i) { + /* Calculate the W-gradient for the current layer */ + if (weights_[i]->getWGrad()) { + MatrixPtr input_T = getInputValue(i)->getTranspose(); + MatrixPtr oGrad = getOutputGrad(); + { + weights_[i]->getWGrad()->mul(input_T, oGrad, 1, 1); + } + } + + + /* Calculate the input layers error */ + MatrixPtr preGrad = getInputGrad(i); + if (NULL != preGrad) { + MatrixPtr weights_T = weights_[i]->getW()->getTranspose(); + preGrad->mul(getOutputGrad(), weights_T, 1, 1); + } + + { + weights_[i]->getParameterPtr()->incUpdate(callback); + } } } - /* Calculate the input layers error */ - MatrixPtr preGrad = getInputGrad(i); - if (NULL != preGrad) { - MatrixPtr weights_T = 
weights_[i]->getW()->getTranspose(); - preGrad->mul(getOutputGrad(), weights_T, 1, 1); - } +The `prefetch` function specifies the rows that need to be fetched from parameter server during training. It is only useful for remote sparse training. In remote sparse training, the full parameter matrix is stored distributedly at the parameter server. When the layer uses a batch for training, only a subset of locations of the input is non-zero in this batch. Thus, this layer only needs the rows of the transformation matrix corresponding to the locations of these non-zero entries. The `prefetch` function specifies the ids of these rows. + +Most of the layers do not need remote sparse training function. You do not need to override this function in this case:: - { - weights_[i]->getParameterPtr()->incUpdate(callback); + void FullyConnectedLayer::prefetch() { + for (size_t i = 0; i != inputLayers_.size(); ++i) { + auto* sparseParam = + dynamic_cast(weights_[i]->getW().get()); + if (sparseParam) { + MatrixPtr input = getInputValue(i); + sparseParam->addRows(input); + } + } } - } -} -``` -The `prefetch` function specifies the rows that need to be fetched from parameter server during training. It is only useful for remote sparse training. In remote sparse training, the full parameter matrix is stored distributedly at the parameter server. When the layer uses a batch for training, only a subset of locations of the input is non-zero in this batch. Thus, this layer only needs the rows of the transformation matrix corresponding to the locations of these non-zero entries. The `prefetch` function specifies the ids of these rows. -Most of the layers do not need remote sparse training function. You do not need to override this function in this case. +Finally, you can use `REGISTER_LAYER(fc, FullyConnectedLayer);` to register the layer. 
`fc` is the identifier of the layer, and `FullyConnectedLayer` is the class name of the layer:: -```C -void FullyConnectedLayer::prefetch() { - for (size_t i = 0; i != inputLayers_.size(); ++i) { - auto* sparseParam = - dynamic_cast(weights_[i]->getW().get()); - if (sparseParam) { - MatrixPtr input = getInputValue(i); - sparseParam->addRows(input); + namespace paddle { + REGISTER_LAYER(fc, FullyConnectedLayer); } - } -} -``` -Finally, you can use `REGISTER_LAYER(fc, FullyConnectedLayer);` to register the layer. `fc` is the identifier of the layer, and `FullyConnectedLayer` is the class name of the layer. - -```C -namespace paddle { -REGISTER_LAYER(fc, FullyConnectedLayer); -} -``` If the `cpp` file is put into `paddle/gserver/layers`, it will be automatically added to the compilation list. +================= +Write Gradient Check Unit Test +================= -## Write Gradient Check Unit Test -An easy way to verify the correctness of new layer's implementation is to write a gradient check unit test. Gradient check unit test utilizes finite difference method to verify the gradient of a layer. It modifies the input with a small perturbation $\Delta x$ and observes the changes of output $\Delta y$, the gradient can be computed as $\frac{\Delta y}{\Delta x }$. This gradient can be compared with the gradient computed by the `backward` function of the layer to ensure the correctness of the gradient computation. Notice that the gradient check only tests the correctness of the gradient computation, it does not necessarily guarantee the correctness of the implementation of the `forward` and `backward` function. You need to write more sophisticated unit tests to make sure your layer is implemented correctly. +An easy way to verify the correctness of new layer's implementation is to write a gradient check unit test. Gradient check unit test utilizes finite difference method to verify the gradient of a layer. 
It modifies the input with a small perturbation :math:`\Delta x` and observes the changes of output :math:`\Delta y`, the gradient can be computed as :math:`\frac{\Delta y}{\Delta x }`. This gradient can be compared with the gradient computed by the `backward` function of the layer to ensure the correctness of the gradient computation. Notice that the gradient check only tests the correctness of the gradient computation, it does not necessarily guarantee the correctness of the implementation of the `forward` and `backward` function. You need to write more sophisticated unit tests to make sure your layer is implemented correctly. All the gradient check unit tests are located in `paddle/gserver/tests/test_LayerGrad.cpp`. You are recommended to put your test into a new test file if you are planning to write a new layer. The gradient test of the gradient check unit test of the fully connected layer is listed below. It has the following steps. + Create layer configuration. A layer configuration can include the following attributes: + - size of the bias parameter. (4096 in our example) - type of the layer. (fc in our example) - size of the layer. (4096 in our example) - activation type. (softmax in our example) - dropout rate. (0.1 in our example) + + configure the input of the layer. In our example, we have only one input. + - type of the input (`INPUT_DATA`) in our example. It can be one of the following types + - `INPUT_DATA`: dense vector. - `INPUT_LABEL`: integer. - `INPUT_DATA_TARGET`: dense vector, but it does not used to compute gradient. @@ -262,81 +289,91 @@ All the gradient check unit tests are located in `paddle/gserver/tests/test_Laye - `INPUT_SEQUENCE_LABEL`: integer with sequence information. - `INPUT_SPARSE_NON_VALUE_DATA`: 0-1 sparse data. - `INPUT_SPARSE_FLOAT_VALUE_DATA`: float sparse data. + - name of the input. (`layer_0` in our example) - size of the input. (8192 in our example) - number of non-zeros, only useful for sparse inputs. 
- format of sparse data, only useful for sparse inputs. + + each inputs needs to call `config.layerConfig.add_inputs();` once. + call `testLayerGrad` to perform gradient checks. It has the following arguments. + - layer and input configurations. (`config` in our example) - type of the input. (`fc` in our example) - batch size of the gradient check. (100 in our example) - whether the input is transpose. Most layers need to set it to `false`. (`false` in our example) - whether to use weights. Some layers or activations perform normalization so that the sum of their output is a constant. For example, the sum of output of a softmax activation is one. In this case, we cannot correctly compute the gradients using regular gradient check techniques. A weighted sum of the output, which is not a constant, is utilized to compute the gradients. (`true` in our example, because the activation of a fully connected layer can be softmax) -```C -void testFcLayer(string format, size_t nnz) { - // Create layer configuration. - TestConfig config; - config.biasSize = 4096; - config.layerConfig.set_type("fc"); - config.layerConfig.set_size(4096); - config.layerConfig.set_active_type("sigmoid"); - config.layerConfig.set_drop_rate(0.1); - // Setup inputs. - config.inputDefs.push_back( - {INPUT_DATA, "layer_0", 8192, nnz, ParaSparse(format)}); - config.layerConfig.add_inputs(); - LOG(INFO) << config.inputDefs[0].sparse.sparse << " " - << config.inputDefs[0].sparse.format; - for (auto useGpu : {false, true}) { - testLayerGrad(config, "fc", 100, /* trans */ false, useGpu, - /* weight */ true); - } -} -``` + +The code is listed below:: + + void testFcLayer(string format, size_t nnz) { + // Create layer configuration. + TestConfig config; + config.biasSize = 4096; + config.layerConfig.set_type("fc"); + config.layerConfig.set_size(4096); + config.layerConfig.set_active_type("sigmoid"); + config.layerConfig.set_drop_rate(0.1); + // Setup inputs. 
+ config.inputDefs.push_back( + {INPUT_DATA, "layer_0", 8192, nnz, ParaSparse(format)}); + config.layerConfig.add_inputs(); + LOG(INFO) << config.inputDefs[0].sparse.sparse << " " + << config.inputDefs[0].sparse.format; + for (auto useGpu : {false, true}) { + testLayerGrad(config, "fc", 100, /* trans */ false, useGpu, + /* weight */ true); + } + } + If you are creating a new file for the test, such as `paddle/gserver/tests/testFCGrad.cpp`, you need to add the file to `paddle/gserver/tests/CMakeLists.txt`. An example is given below. All the unit tests will run when you execute the command `make tests`. Notice that some layers might need high accuracy for the gradient check unit tests to work well. You need to configure `WITH_DOUBLE` to `ON` when configuring cmake. -``` -add_unittest_without_exec(test_FCGrad - test_FCGrad.cpp - LayerGradUtil.cpp - TestUtil.cpp) +The code is listed below:: -add_test(NAME test_FCGrad - COMMAND test_FCGrad) -``` + add_unittest_without_exec(test_FCGrad + test_FCGrad.cpp + LayerGradUtil.cpp + TestUtil.cpp) -## Implement Python Wrapper + add_test(NAME test_FCGrad + COMMAND test_FCGrad) + + +================= +Implement Python Wrapper +================= Implementing Python wrapper allows us to use the added layer in configuration files. All the Python wrappers are in file `python/paddle/trainer/config_parser.py`. An example of the Python wrapper for fully connected layer is listed below. It has the following steps: + - Use `@config_layer('fc’)` at the decorator for all the Python wrapper class. `fc` is the identifier of the layer. - Implements `__init__` constructor function. + - It first call `super(FCLayer, self).__init__(name, 'fc', size, inputs=inputs, **xargs)` base constructor function. `FCLayer` is the Python wrapper class name, and `fc` is the layer identifier name. They must be correct in order for the wrapper to work. - Then it computes the size and format (whether sparse) of each transformation matrix as well as the size. 
+The code is listed below:: + + @config_layer('fc') + class FCLayer(LayerBase): + def __init__( + self, + name, + size, + inputs, + bias=True, + **xargs): + super(FCLayer, self).__init__(name, 'fc', size, inputs=inputs, **xargs) + for input_index in xrange(len(self.inputs)): + input_layer = self.get_input_layer(input_index) + psize = self.config.size * input_layer.size + dims = [input_layer.size, self.config.size] + format = self.inputs[input_index].format + sparse = format == "csr" or format == "csc" + if sparse: + psize = self.inputs[input_index].nnz + self.create_input_parameter(input_index, psize, dims, sparse, format) + self.create_bias_parameter(bias, self.config.size) -```python -@config_layer('fc') -class FCLayer(LayerBase): - def __init__( - self, - name, - size, - inputs, - bias=True, - **xargs): - super(FCLayer, self).__init__(name, 'fc', size, inputs=inputs, **xargs) - for input_index in xrange(len(self.inputs)): - input_layer = self.get_input_layer(input_index) - psize = self.config.size * input_layer.size - dims = [input_layer.size, self.config.size] - format = self.inputs[input_index].format - sparse = format == "csr" or format == "csc" - if sparse: - psize = self.inputs[input_index].nnz - self.create_input_parameter(input_index, psize, dims, sparse, format) - self.create_bias_parameter(bias, self.config.size) -``` In network configuration, the layer can be specifies using the following code snippets. The arguments of this class are: - `name` is the name identifier of the layer instance. @@ -345,14 +382,15 @@ In network configuration, the layer can be specifies using the following code sn - `bias` specifies whether this layer instance has bias. - `inputs` specifies a list of layer instance names as inputs. 
-```python -Layer( - name = "fc1", - type = "fc", - size = 64, - bias = True, - inputs = [Input("pool3")] -) -``` +The code is listed below:: + + Layer( + name = "fc1", + type = "fc", + size = 64, + bias = True, + inputs = [Input("pool3")] + ) + You are also recommended to implement a helper for the Python wrapper, which makes it easier to write models. You can refer to `python/paddle/trainer_config_helpers/layers.py` for examples. diff --git a/doc/ui/api/rnn/bi_lstm.jpg b/doc/ui/api/rnn/bi_lstm.jpg new file mode 120000 index 00000000000..ae94882e70e --- /dev/null +++ b/doc/ui/api/rnn/bi_lstm.jpg @@ -0,0 +1 @@ +../../../demo/sentiment_analysis/bi_lstm.jpg \ No newline at end of file diff --git a/doc/ui/api/rnn/encoder-decoder-attention-model.png b/doc/ui/api/rnn/encoder-decoder-attention-model.png new file mode 120000 index 00000000000..f83c45bda88 --- /dev/null +++ b/doc/ui/api/rnn/encoder-decoder-attention-model.png @@ -0,0 +1 @@ +../../../demo/text_generation/encoder-decoder-attention-model.png \ No newline at end of file diff --git a/doc/ui/api/rnn/index.rst b/doc/ui/api/rnn/index.rst new file mode 100644 index 00000000000..4f636a431d2 --- /dev/null +++ b/doc/ui/api/rnn/index.rst @@ -0,0 +1,252 @@ +Recurrent Neural Network Configuration +===================================== +This tutorial will guide you how to configure recurrent neural network in PaddlePaddle. PaddlePaddle supports highly flexible and efficient recurrent neural network configuration. In this tutorial, you will learn how to: +- prepare sequence data for learning recurrent neural networks. +- configure recurrent neural network architecture. +- generate sequence with learned recurrent neural network models. + + + + +We will use vanilla recurrent neural network, and sequence to sequence model to guide you through these steps. The code of sequence to sequence model can be found at `/demo/seqToseq`. 
+ + +===================== +Prepare Sequence Data +===================== +PaddlePaddle does not need any preprocessing to sequence data, such as padding. The only thing that needs to be done is to set the corresponding input type for each input. For example, the following code snippet defines three inputs. All of them are sequences, and the size of them are `src_dict`, `trg_dict`, and `trg_dict`:: + + settings.slots = [ + integer_value_sequence(len(settings.src_dict)), + integer_value_sequence(len(settings.trg_dict)), + integer_value_sequence(len(settings.trg_dict)) + ] + + +Then at the `process` function, each `yield` function will return three integer lists. Each integer list is treated as a sequence of integers:: + + + yield src_ids, trg_ids, trg_ids_next + + +For a more detailed description of how to write a data provider, please refer to :doc:`Python Data Provider <../py_data_provider_wrapper>`. The full data provider file is located at `./demo/seqToseq/dataprovider.py`. + +=============================================== +Configure Recurrent Neural Network Architecture +=============================================== + +------------------------------------- +Simple Gated Recurrent Neural Network +------------------------------------- +Recurrent neural network processes a sequence at each time step sequentially. An example of the architecture of LSTM is listed below. + +.. image:: ./bi_lstm.jpg + :align: center + + +Generally speaking, a recurrent network performs the following operations from t=1 to t=T, or reversely from t=T to t=1. + +.. math:: + + x_{t+1} = f_x(x_t), y_t = f_y(x_t) + + +where :math:`f_x(.)` is called *step function*, and :math:`f_y(.)` is called *output function*. In vanilla recurrent neural network, both of the step function and output function are very simple. However, PaddlePaddle supports the configuration of very complex architectures by modifying these two functions. We will use the sequence to sequence model with attention as an example to demonstrate how you can configure complex recurrent neural network models. 
In this section, we will use a simple vanilla recurrent neural network as an example of configuring a simple recurrent neural network using `recurrent_group`. Notice that if you only need to use simple RNN, GRU, or LSTM, then `grumemory` and `lstmemory` are recommended because they are more computationally efficient than `recurrent_group`. + + +For vanilla RNN, at each time step, the *step function* is: + + +.. math:: + + x_{t+1} = W_x x_t + W_i I_t + b + +where :math:`x_t` is the RNN state, and :math:`I_t` is the input, :math:`W_x` and :math:`W_i` are transformation matrices for RNN states and inputs, respectively. :math:`b` is the bias. +Its *output function* simply takes :math:`x_t` as the output. + + +`recurrent_group` is the most important tool for constructing recurrent neural networks. It defines the *step function*, *output function* and the inputs of the recurrent neural network. Notice that the `step` argument of this function implements both the `step function` and the `output function`:: + + + def simple_rnn(input, + size=None, + name=None, + reverse=False, + rnn_bias_attr=None, + act=None, + rnn_layer_attr=None): + def __rnn_step__(ipt): + out_mem = memory(name=name, size=size) + rnn_out = mixed_layer(input = [full_matrix_projection(ipt), + full_matrix_projection(out_mem)], + name = name, + bias_attr = rnn_bias_attr, + act = act, + layer_attr = rnn_layer_attr, + size = size) + return rnn_out + return recurrent_group(name='%s_recurrent_group' % name, + step=__rnn_step__, + reverse=reverse, + input=input) + + +PaddlePaddle uses memory to construct the step function. *Memory* is the most important concept when constructing recurrent neural networks in PaddlePaddle. A memory is a state that is used recurrently in step functions, such as :math:`x_{t+1} = f_x(x_t)`. One memory contains an *output* and an *input*. The output of memory at the current time step is utilized as the input of the memory at the next time step. 
A memory can also have a *boot layer*, whose output is utilized as the initial value of the memory. In our case, the output of the gated recurrent unit is employed as the output memory. Notice that the name of the layer `rnn_out` is the same as the name of `out_mem`. This means the output of the layer `rnn_out` (:math:`x_{t+1}`) is utilized as the *output* of the `out_mem` memory. + +A memory can also be a sequence. In this case, at each time step, we have a sequence as the state of the recurrent neural network. This can be useful when constructing very complex recurrent neural networks. Other advanced functions include defining multiple memories, and defining hierarchical recurrent neural network architectures using sub-sequences. + + +We return `rnn_out` at the end of the function. It means that the output of the layer `rnn_out` is utilized as the *output* function of the gated recurrent neural network. + + + +----------------------------------------- +Sequence to Sequence Model with Attention +----------------------------------------- +We will use the sequence to sequence model with attention as an example to demonstrate how you can configure complex recurrent neural network models. An illustration of the sequence to sequence model with attention is shown in the following figure. + +.. image:: ./encoder-decoder-attention-model.png + :align: center + + +In this model, the source sequence :math:`S = \{s_1, \dots, s_T\}` is encoded with a bidirectional gated recurrent neural network. The hidden states of the bidirectional gated recurrent neural network :math:`H_S = \{H_1, \dots, H_T\}` are called the *encoder vector*. The decoder is a gated recurrent neural network. When decoding each token :math:`y_t`, the gated recurrent neural network generates a set of weights :math:`W_S^t = \{W_1^t, \dots, W_T^t\}`, which are used to compute a weighted sum of the encoder vector. The weighted sum of the encoder vector is utilized to condition the generation of the token :math:`y_t`. + +The encoder part of the model is listed below. 
It calls `grumemory` to represent a gated recurrent neural network. It is the recommended way of using recurrent neural network if the network architecture is simple, because it is faster than `recurrent_group`. We have implemented most of the commonly used recurrent neural network architectures; you can refer to :doc:`Layers <../trainer_config_helpers/layers>` for more details. + +We also project the encoder vector to `decoder_size` dimensional space, get the first instance of the backward recurrent network, and project it to `decoder_size` dimensional space:: + + + # Define the data layer of the source sentence. + src_word_id = data_layer(name='source_language_word', size=source_dict_dim) + # Calculate the word embedding of each word. + src_embedding = embedding_layer( + input=src_word_id, + size=word_vector_dim, + param_attr=ParamAttr(name='_source_language_embedding')) + # Apply forward recurrent neural network. + src_forward = grumemory(input=src_embedding, size=encoder_size) + # Apply backward recurrent neural network. reverse=True means backward recurrent neural network. + src_backward = grumemory(input=src_embedding, + size=encoder_size, + reverse=True) + # Mix the forward and backward parts of the recurrent neural network together. + encoded_vector = concat_layer(input=[src_forward, src_backward]) + + # Project encoding vector to decoder_size. + encoded_proj = mixed_layer(input = [full_matrix_projection(encoded_vector)], + size = decoder_size) + + # Compute the first instance of the backward RNN. + backward_first = first_seq(input=src_backward) + + # Project the first instance of backward RNN to decoder size. + decoder_boot = mixed_layer(input=[full_matrix_projection(backward_first)], size=decoder_size, act=TanhActivation()) + + +The decoder uses `recurrent_group` to define the recurrent neural network. 
The step and output functions are defined in `gru_decoder_with_attention`:: + + + trg_embedding = embedding_layer( + input=data_layer(name='target_language_word', + size=target_dict_dim), + size=word_vector_dim, + param_attr=ParamAttr(name='_target_language_embedding')) + # For decoder equipped with attention mechanism, in training, + # target embedding (the ground truth) is the data input, + # while encoded source sequence is accessed as an unbounded memory. + # StaticInput means the same value is utilized at different time steps. + # Otherwise, it is a sequence input. Inputs at different time steps are different. + # All sequence inputs should have the same length. + decoder = recurrent_group(name=decoder_group_name, + step=gru_decoder_with_attention, + input=[ + StaticInput(input=encoded_vector, + is_seq=True), + StaticInput(input=encoded_proj, + is_seq=True), + trg_embedding + ]) + + +The implementation of the step function is listed below. First, it defines the *memory* of the decoder network. Then it defines attention, gated recurrent unit step function, and the output function:: + + + def gru_decoder_with_attention(enc_vec, enc_proj, current_word): + # Defines the memory of the decoder. + # The output of this memory is defined in gru_step. + # Notice that the name of gru_step should be the same as the name of this memory. + decoder_mem = memory(name='gru_decoder', + size=decoder_size, + boot_layer=decoder_boot) + # Compute attention weighted encoder vector. + context = simple_attention(encoded_sequence=enc_vec, + encoded_proj=enc_proj, + decoder_state=decoder_mem) + # Mix the current word embedding and the attention weighted encoder vector. + decoder_inputs = mixed_layer(input = [full_matrix_projection(context), + full_matrix_projection(current_word)], + size = decoder_size * 3) + # Define Gated recurrent unit recurrent neural network step function. 
+ gru_step = gru_step_layer(name='gru_decoder', + input=decoder_inputs, + output_mem=decoder_mem, + size=decoder_size) + # Defines the output function. + out = mixed_layer(input=[full_matrix_projection(input=gru_step)], + size=target_dict_dim, + bias_attr=True, + act=SoftmaxActivation()) + return out + + +================= +Generate Sequence +================= +After training the model, we can use it to generate sequences. A common practice is to use *beam search* to generate sequences. The following code snippet defines a beam search algorithm. Notice that the `beam_search` function assumes that the output function of `step` returns a softmax normalized probability vector of the next token. We made the following changes to the model. + +* use `GeneratedInput` for trg_embedding. `GeneratedInput` computes the embedding of the generated token at the last time step for the input at the current time step. +* use `beam_search` function. This function needs to set: + + - `id_input`: the integer ID of the data, used to identify the corresponding output in the generated files. + - `dict_file`: the dictionary file for converting word id to word. + - `bos_id`: the start token. Every sentence starts with the start token. + - `eos_id`: the end token. Every sentence ends with the end token. + - `beam_size`: the beam size used in beam search. + - `max_length`: the maximum length of the generated sentences. + - `result_file`: the path of the generation result file. + + +The code is listed below:: + + gen_inputs = [StaticInput(input=encoded_vector, + is_seq=True), + StaticInput(input=encoded_proj, + is_seq=True), ] + # In generation, decoder predicts a next target word based on + # the encoded source sequence and the last generated target word. + # The encoded source sequence (encoder's output) must be specified by + # StaticInput which is a read-only memory. + # Here, GeneratedInputs automatically fetches the last generated word, + # which is initialized by a start mark, such as <s>. 
+ trg_embedding = GeneratedInput( + size=target_dict_dim, + embedding_name='_target_language_embedding', + embedding_size=word_vector_dim) + gen_inputs.append(trg_embedding) + beam_gen = beam_search(name=decoder_group_name, + step=gru_decoder_with_attention, + input=gen_inputs, + id_input=data_layer(name="sent_id", + size=1), + dict_file=trg_dict_path, + bos_id=0, # Beginning token. + eos_id=1, # End of sentence token. + beam_size=beam_size, + max_length=max_length, + result_file=gen_trans_file) + outputs(beam_gen) + + +Notice that this generation technique is only useful for decoder-like generation processes. If you are working on sequence tagging tasks, please refer to :doc:`Semantic Role Labeling Demo <../../../demo/semantic_role_labeling>` for more details. + +The full configuration file is located at `./demo/seqToseq/seqToseq_net.py`. diff --git a/doc/ui/index.md b/doc/ui/index.md index 976d3382444..52b3dd46882 100644 --- a/doc/ui/index.md +++ b/doc/ui/index.md @@ -9,12 +9,13 @@ * [PyDataProviderWrapper](api/py_data_provider_wrapper.rst) * [Trainer Config Helpers](api/trainer_config_helpers/index.md) +* [Recurrent Neural Network Configuration](api/rnn/index.rst) ## Command Line Argument * [Use Case](cmd_argument/use_case.md) * [Argument Outline](cmd_argument/argument_outline.md) -* [Detail Description](cmd_argument/detail_introduction.md) +* [Detailed Descriptions of All Command Line Arguments](cmd_argument/detail_introduction.md) ## Predict -- GitLab