diff --git a/doc/fluid/new_docs/beginners_guide/basics/machine_translation/README.cn.md b/doc/fluid/new_docs/beginners_guide/basics/machine_translation/README.cn.md index fa2b930be0d26d816566599cece8afbedc1157e0..6e5f77fec8a894c390ced8c93ee344fd8d27370e 100644 --- a/doc/fluid/new_docs/beginners_guide/basics/machine_translation/README.cn.md +++ b/doc/fluid/new_docs/beginners_guide/basics/machine_translation/README.cn.md @@ -60,6 +60,7 @@ 图3. 编码器-解码器框架 + #### 编码器 编码阶段分为三步: @@ -81,7 +82,7 @@ 机器翻译任务的训练过程中,解码阶段的目标是最大化下一个正确的目标语言词的概率。思路是: 1. 每一个时刻,根据源语言句子的编码信息(又叫上下文向量,context vector)`$c$`、真实目标语言序列的第`$i$`个词`$u_i$`和`$i$`时刻RNN的隐层状态`$z_i$`,计算出下一个隐层状态`$z_{i+1}$`。计算公式如下: $$z_{i+1}=\phi_{\theta '} \left ( c,u_i,z_i \right )$$ -其中`$\phi _{\theta '}$`是一个非线性激活函数;`$c=q\mathbf{h}$`是源语言句子的上下文向量,在不使用[注意力机制](#注意力机制)时,如果[编码器](#编码器)的输出是源语言句子编码后的最后一个元素,则可以定义`$c=h_T$`;`$u_i$`是目标语言序列的第`$i$`个单词,`$u_0$`是目标语言序列的开始标记``,表示解码开始;`$z_i$`是`$i$`时刻解码RNN的隐层状态,`$z_0$`是一个全零的向量。 +其中`$\phi _{\theta '}$`是一个非线性激活函数;`$c=q\mathbf{h}$`是源语言句子的上下文向量,在不使用注意力机制时,如果[编码器](#编码器)的输出是源语言句子编码后的最后一个元素,则可以定义`$c=h_T$`;`$u_i$`是目标语言序列的第`$i$`个单词,`$u_0$`是目标语言序列的开始标记``,表示解码开始;`$z_i$`是`$i$`时刻解码RNN的隐层状态,`$z_0$`是一个全零的向量。 2. 将`$z_{i+1}$`通过`softmax`归一化,得到目标语言序列的第`$i+1$`个单词的概率分布`$p_{i+1}$`。概率分布公式如下: $$p\left ( u_{i+1}|u_{<i+1},\mathbf{x} \right )=softmax(W_sz_{i+1}+b_z)$$ @@ -93,6 +94,7 @@ $$p\left ( u_{i+1}|u_{<i+1},\mathbf{x} \right )=softmax(W_sz_{i+1}+b_z)$$ 机器翻译任务的生成过程,通俗来讲就是根据预先训练的模型来翻译源语言句子。生成过程中的解码阶段和上述训练过程的有所差异,具体介绍请见[柱搜索算法](#柱搜索算法)。 + ### 柱搜索算法 柱搜索([beam search](http://en.wikipedia.org/wiki/Beam_search))是一种启发式图搜索算法,用于在图或树中搜索有限集合中的最优扩展节点,通常用在解空间非常大的系统(如机器翻译、语音识别)中,原因是内存无法装下图或树中所有展开的解。如在机器翻译任务中希望翻译“`你好`”,就算目标语言字典中只有3个词(``, ``, `hello`),也可能生成无限句话(`hello`循环出现的次数不定),为了找到其中较好的翻译结果,我们可采用柱搜索算法。 diff --git a/doc/fluid/new_docs/beginners_guide/basics/understand_sentiment/README.cn.md b/doc/fluid/new_docs/beginners_guide/basics/understand_sentiment/README.cn.md index 9900dfb9a67dc6f8940bd7dd3abfa15ac8a3488f..8477cf32146c33947ced447c8bdd287a3e1e71f5 100644 --- a/doc/fluid/new_docs/beginners_guide/basics/understand_sentiment/README.cn.md +++ b/doc/fluid/new_docs/beginners_guide/basics/understand_sentiment/README.cn.md @@ -149,6 +149,8 @@ def convolution_net(data, input_dim, class_dim, emb_dim, hid_dim): 网络的输入`input_dim`表示的是词典的大小,`class_dim`表示类别数。这里,我们使用[`sequence_conv_pool`](https://github.com/PaddlePaddle/Paddle/blob/develop/python/paddle/trainer_config_helpers/networks.py) API实现了卷积和池化操作。 + + ### 栈式双向LSTM 栈式双向神经网络`stacked_lstm_net`的代码片段如下: diff --git a/doc/fluid/new_docs/beginners_guide/basics/word2vec/README.cn.md b/doc/fluid/new_docs/beginners_guide/basics/word2vec/README.cn.md index 2c68cdac4f10319359b74bc92569dfd3f65380b5..904d99fe2ffc9ead69a86c9763568a5c098348d5 100644 --- a/doc/fluid/new_docs/beginners_guide/basics/word2vec/README.cn.md +++ b/doc/fluid/new_docs/beginners_guide/basics/word2vec/README.cn.md @@ -50,7 +50,7 @@ similarity: -0.0997506977351 ``` -以上结果可以通过运行`calculate_dis.py`, 加载字典里的单词和对应训练特征结果得到,我们将在[应用模型](#应用模型)中详细描述用法。 +以上结果可以通过运行`calculate_dis.py`, 加载字典里的单词和对应训练特征结果得到,我们将在[模型应用](#模型应用)中详细描述用法。 ## 模型概览 @@ -189,6 +189,7 @@ dream that one day 最后,每个输入会按其单词次在字典里的位置,转化成整数的索引序列,作为PaddlePaddle的输入。 + ## 编程实现 本配置的模型结构如下图所示: @@ -349,6 +350,7 @@ Step 20: Average Cost 5.766995 ... ``` + ## 模型应用 在模型训练后,我们可以用它做一些预测。 diff --git a/doc/fluid/new_docs/beginners_guide/quick_start/recognize_digits/README.cn.md b/doc/fluid/new_docs/beginners_guide/quick_start/recognize_digits/README.cn.md index e6f89b23a95d1a07565f3e0a285e9c3f921930df..ac36c4ecf6b9b716fe5f0dbe2346e64918c22242 100644 --- a/doc/fluid/new_docs/beginners_guide/quick_start/recognize_digits/README.cn.md +++ b/doc/fluid/new_docs/beginners_guide/quick_start/recognize_digits/README.cn.md @@ -102,7 +102,7 @@ Softmax回归模型采用了最简单的两层神经网络,即只有输入层 池化是非线性下采样的一种形式,主要作用是通过减少网络的参数来减小计算量,并且能够在一定程度上控制过拟合。通常在卷积层的后面会加上一个池化层。池化包括最大池化、平均池化等。其中最大池化是用不重叠的矩形框将输入层分成不同的区域,对于每个矩形框的数取最大值作为输出层,如图6所示。 -更详细的关于卷积神经网络的具体知识可以参考[斯坦福大学公开课]( http://cs231n.github.io/convolutional-networks/ )和[图像分类](https://github.com/PaddlePaddle/book/blob/develop/image_classification/README.md)教程。 +更详细的关于卷积神经网络的具体知识可以参考[斯坦福大学公开课]( http://cs231n.github.io/convolutional-networks/ )和[图像分类]( https://github.com/PaddlePaddle/book/tree/develop/03.image_classification )教程。 ### 常见激活函数介绍 - sigmoid激活函数: $ f(x) = sigmoid(x) = \frac{1}{1+e^{-x}} $ diff --git a/doc/fluid/new_docs/user_guides/howto/debug/visualdl.md b/doc/fluid/new_docs/user_guides/howto/debug/visualdl.md index a2f30823a6fcd379f94e6e98d043b0d00681827f..84987ea5daee9abd0fe2fe71bdfde62ea3388ab5 100644 --- a/doc/fluid/new_docs/user_guides/howto/debug/visualdl.md +++ b/doc/fluid/new_docs/user_guides/howto/debug/visualdl.md @@ -149,7 +149,7 @@ python setup.py bdist_wheel pip install --upgrade dist/visualdl-*.whl ``` -如果打包和安装遇到其他问题,不安装只想运行Visual DL可以看[这里](https://github.com/PaddlePaddle/VisualDL/blob/develop/docs/how_to_dev_frontend_en.md) +如果打包和安装遇到其他问题,不安装只想运行Visual DL可以看[这里](https://github.com/PaddlePaddle/VisualDL/blob/develop/docs/develop/how_to_dev_frontend_cn.md) ## SDK diff --git a/paddle/fluid/framework/ir/fc_lstm_fuse_pass.cc b/paddle/fluid/framework/ir/fc_lstm_fuse_pass.cc index 00f5e7fad2ef5d42eb0de9703389e910090d93c1..55153ecc3ed35688d8f861bde0f44ae2bf6a7111 100644 --- a/paddle/fluid/framework/ir/fc_lstm_fuse_pass.cc +++ b/paddle/fluid/framework/ir/fc_lstm_fuse_pass.cc @@ -11,7 +11,6 @@ // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. - #include "paddle/fluid/framework/ir/fc_lstm_fuse_pass.h" #include #include "paddle/fluid/framework/lod_tensor.h" diff --git a/paddle/fluid/inference/api/demo_ci/run.sh b/paddle/fluid/inference/api/demo_ci/run.sh index 7824ef2649af81a2390ff3bc537eb7c93c70e402..0f7d541c5edfc62e80cf50f83b491f06dcb42644 100755 --- a/paddle/fluid/inference/api/demo_ci/run.sh +++ b/paddle/fluid/inference/api/demo_ci/run.sh @@ -14,7 +14,7 @@ else fi PREFIX=inference-vis-demos%2F -URL_ROOT=http://paddlemodels.bj.bcebos.com/${PREFIX} +URL_ROOT=http://paddlemodels.cdn.bcebos.com/${PREFIX} # download vis_demo data function download() { diff --git a/paddle/fluid/operators/distributed/request_handler_impl.cc b/paddle/fluid/operators/distributed/request_handler_impl.cc index 66784f0b5149a7c479a90a407709d993f4a40a8b..31159a02592a2aff75f7ecf5be924989f0f47071 100644 --- a/paddle/fluid/operators/distributed/request_handler_impl.cc +++ b/paddle/fluid/operators/distributed/request_handler_impl.cc @@ -39,19 +39,6 @@ bool RequestSendHandler::Handle(const std::string& varname, const std::string& out_var_name) { VLOG(4) << "RequestSendHandler:" << varname; - // Async - if (!sync_mode_) { - rpc_server_->Profiler().OneStep(); - try { - executor_->RunPreparedContext((*grad_to_prepared_ctx_)[varname].get(), - scope); - } catch (std::exception& e) { - LOG(ERROR) << "async: run sub program error " << e.what(); - return false; - } - return true; - } - // Sync if (varname == BATCH_BARRIER_MESSAGE) { VLOG(3) << "sync: recv BATCH_BARRIER_MESSAGE"; @@ -60,17 +47,31 @@ bool RequestSendHandler::Handle(const std::string& varname, VLOG(3) << "sync: recv complete message"; rpc_server_->Complete(); } else { - VLOG(3) << "sync: received var_name: " << varname; - rpc_server_->WaitCond(kRequestSend); - VLOG(3) << "sync: processing received var: " << varname; - - if (invar == nullptr) { - LOG(FATAL) << "sync: Can not find server side var: " << varname; - return false; - } - if (invar->IsType()) { - std::unique_lock lock(mutex_sparse_vars_); - sparse_vars_.push_back(invar); + // Async + if (!sync_mode_) { + VLOG(3) << "async process var: " << varname; + rpc_server_->Profiler().OneStep(); + try { + executor_->RunPreparedContext((*grad_to_prepared_ctx_)[varname].get(), + scope); + } catch (std::exception& e) { + LOG(ERROR) << "async: run sub program error " << e.what(); + return false; + } + return true; + } else { // sync + rpc_server_->WaitCond(kRequestSend); + VLOG(3) << "sync: processing received var: " << varname; + + if (invar == nullptr) { + LOG(FATAL) << "sync: Can not find server side var: " << varname; + return false; + } + + if (invar->IsType()) { + std::unique_lock lock(mutex_sparse_vars_); + sparse_vars_.push_back(invar); + } } } return true; diff --git a/paddle/fluid/operators/fusion_lstm_op.cc b/paddle/fluid/operators/fusion_lstm_op.cc index f91236975d0cf0c89a464188bd6ea1b5b01e0f6d..104e160e2d7069ec247cc51e927ce8824f1b69e8 100644 --- a/paddle/fluid/operators/fusion_lstm_op.cc +++ b/paddle/fluid/operators/fusion_lstm_op.cc @@ -89,12 +89,12 @@ void FusionLSTMOp::InferShape(framework::InferShapeContext* ctx) const { PADDLE_ENFORCE_EQ(b_dims[0], 1, "The first dimension of Input(Bias) should be 1."); - PADDLE_ENFORCE(!ctx->Attrs().Get("use_peepholes"), - "Do not support peephole yet."); - PADDLE_ENFORCE_EQ(b_dims[1], 4 * frame_size, + auto use_peepholes = ctx->Attrs().Get("use_peepholes"); + PADDLE_ENFORCE_EQ(b_dims[1], (use_peepholes ? 7 : 4) * frame_size, "The second dimension of Input(Bias) should be " - "4 * %d if disable peepholes connection", - frame_size); + "7 * %d if enable peepholes connection or" + "4 * %d if disable peepholes", + frame_size, frame_size); framework::DDim out_dims({x_dims[0], frame_size}); ctx->SetOutputDim("Hidden", out_dims); @@ -232,16 +232,17 @@ class FuisonLSTMKernel : public framework::OpKernel { act_cand = act_functor(act_cand_str); \ } -#define INIT_BASE_INPUT_OUTPUT \ - auto* x = ctx.Input("X"); \ - auto* h0 = ctx.Input("H0"); \ - auto* c0 = ctx.Input("C0"); \ - auto* wx = ctx.Input("WeightX"); \ - auto* wh = ctx.Input("WeightH"); \ - auto* bias = ctx.Input("Bias"); \ - auto* xx = ctx.Output("XX"); \ - auto* hidden_out = ctx.Output("Hidden"); \ - auto* cell_out = ctx.Output("Cell"); \ +#define INIT_BASE_INPUT_OUTPUT \ + auto* x = ctx.Input("X"); \ + auto* h0 = ctx.Input("H0"); \ + auto* c0 = ctx.Input("C0"); \ + auto* wx = ctx.Input("WeightX"); \ + auto* wh = ctx.Input("WeightH"); \ + auto* bias = ctx.Input("Bias"); \ + auto* xx = ctx.Output("XX"); \ + auto* hidden_out = ctx.Output("Hidden"); \ + auto* cell_out = ctx.Output("Cell"); \ + bool use_peepholes = ctx.Attr("use_peepholes"); \ bool is_reverse = ctx.Attr("is_reverse"); #define INIT_BASE_SIZES \ @@ -266,12 +267,21 @@ class FuisonLSTMKernel : public framework::OpKernel { const T* x_data = x->data(); const T* h0_data = h0 ? h0->data() : nullptr; const T* c0_data = c0 ? c0->data() : nullptr; + const T* bias_data = bias->data(); + const T* wc_data = bias_data + D4; // w_ic, w_fc, w_oc const T* wx_data = wx->data(); const T* wh_data = wh->data(); + T* xx_data = xx->mutable_data(ctx.GetPlace()); T* hidden_out_data = hidden_out->mutable_data(ctx.GetPlace()); T* cell_out_data = cell_out->mutable_data(ctx.GetPlace()); + // use local variable + framework::DDim check_dims({3, D}); + Tensor checked_cell; // w_ic * Ct-1, w_fc * Ct-1, w_oc * Ct + auto checked_cell_data = + checked_cell.mutable_data(check_dims, ctx.GetPlace()); + auto blas = math::GetBlas(ctx); math::FCCompute(blas, total_T, D4, M, x_data, wx_data, xx_data, bias->data()); @@ -297,46 +307,86 @@ class FuisonLSTMKernel : public framework::OpKernel { int seq_len = x_lod[0][bid + 1] - x_lod[0][bid]; const T* prev_c_data = nullptr; const T* prev_h_data = nullptr; + int tstart = 0; if (h0_data) { prev_h_data = h0_data + bid * D; prev_c_data = c0_data + bid * D; } else { - // W_ch, W_ih, W_fh, W_oh - act_gate(D3, xx_data + D, xx_data + D); + // If step == 0 and there is no initialized hidden state, that is to say + // the H0 is zeros. Then W_h * H_t-1 can be skipped + + // ~C_t act_cand(D, xx_data, xx_data); - // cell out= input*tilde + if (use_peepholes) { + // I_t, F_t + act_gate(D2, xx_data + D, xx_data + D); + } else { + // I_t, F_t, O_t + act_gate(D3, xx_data + D, xx_data + D); + } + // C_t = I_t * ~C_t blas.VMUL(D, xx_data, xx_data + D, cell_out_data); + + if (use_peepholes) { + // + W_oc * C_t for peephole connection + blas.VMUL(D, wc_data + D2, cell_out_data, checked_cell_data + D2); + blas.VADD(D, xx_data + D3, checked_cell_data + D2, xx_data + D3); + // O_t + act_gate(D, xx_data + D3, xx_data + D3); + } + // hidden out= act_state(cellout) * outgate act_cell(D, cell_out_data, xx_data + D2); + // H_t = O_t * act_state(C_t) blas.VMUL(D, xx_data + D2, xx_data + D3, hidden_out_data); // prev prev_h_data = hidden_out_data; prev_c_data = cell_out_data; - tstart = 1; + tstart = 1; move_step(); } + for (int step = tstart; step < seq_len; ++step) { + // + W_h * H_t-1 blas.GEMM(CblasNoTrans, CblasNoTrans, 1, D4, D, static_cast(1), prev_h_data, D, wh_data, D4, static_cast(1), xx_data, D4); - // W_ch, W_ih, W_fh, W_oh - act_gate(D3, xx_data + D, xx_data + D); + // ~C_t act_cand(D, xx_data, xx_data); - // a = forget * prev_cell + if (use_peepholes) { + // + W_ic|W_fc * C_t-1 for peephole connection + blas.VMUL(D, wc_data, prev_c_data, checked_cell_data); + blas.VMUL(D, wc_data + D, prev_c_data, checked_cell_data + D); + blas.VADD(D2, xx_data + D, checked_cell_data, xx_data + D); + // I_t, F_t + act_gate(D2, xx_data + D, xx_data + D); + } else { + // I_t, F_t, O_t + act_gate(D3, xx_data + D, xx_data + D); + } + + // F_t * C_t-1 blas.VMUL(D, xx_data + D2, prev_c_data, xx_data + D2); - - // b = input * tilde + // I_t * ~C_t blas.VMUL(D, xx_data, xx_data + D, xx_data + D); - - // cell out= a+b + // C_t = F_t * C_t-1 + I_t * ~C_t blas.VADD(D, xx_data + D, xx_data + D2, cell_out_data); + if (use_peepholes) { + // + W_oc * C_t for peephole connection + blas.VMUL(D, wc_data + D2, cell_out_data, checked_cell_data + D2); + blas.VADD(D, xx_data + D3, checked_cell_data + D2, xx_data + D3); + // O_t + act_gate(D, xx_data + D3, xx_data + D3); + } + // hidden out= act_state(cellout) * outgate act_cell(D, cell_out_data, xx_data + D2); + // H_t = O_t * act_state(C_t) blas.VMUL(D, xx_data + D2, xx_data + D3, hidden_out_data); // prev @@ -344,14 +394,14 @@ class FuisonLSTMKernel : public framework::OpKernel { prev_c_data = cell_out_data; move_step(); - } - } + } // for each step in batch + } // for each batch } void BatchCompute(const framework::ExecutionContext& ctx) const { using DeviceContext = platform::CPUDeviceContext; INIT_BASE_INPUT_OUTPUT - if (x->lod()[0].size() == 2) { + if (x->lod()[0].size() == 2) { // batch size == 1 SeqCompute(ctx); return; } @@ -367,6 +417,8 @@ class FuisonLSTMKernel : public framework::OpKernel { const T* x_data = x->data(); const T* wx_data = wx->data(); const T* wh_data = wh->data(); + const T* bias_data = bias->data(); + const T* wc_data = bias_data + D4; // w_ic, w_fc, w_oc auto place = ctx.GetPlace(); T* xx_data = xx->mutable_data(place); T* batched_input_data = batched_input->mutable_data(place); @@ -375,6 +427,12 @@ class FuisonLSTMKernel : public framework::OpKernel { hidden_out->mutable_data(place); cell_out->mutable_data(place); + // use local variable + framework::DDim check_dims({3, D}); + Tensor checked_cell; // w_ic * Ct-1, w_fc * Ct-1, w_oc * Ct + auto checked_cell_data = + checked_cell.mutable_data(check_dims, ctx.GetPlace()); + math::LoDTensor2BatchFunctor to_batch; auto& dev_ctx = ctx.template device_context(); auto blas = math::GetBlas(dev_ctx); @@ -396,17 +454,27 @@ class FuisonLSTMKernel : public framework::OpKernel { reordered_h0->Resize({max_bs, D}); reordered_c0->Resize({max_bs, D}); + T* prev_batch_h_data = nullptr; + T* prev_batch_c_data = nullptr; + T* cur_batch_in_data = batched_input_data; + T* cur_batch_h_out_data = batched_h_out_data; + T* cur_batch_c_out_data = batched_c_out_data; + + auto move_step = [&](int bs) { + cur_batch_in_data += bs * D4; + cur_batch_c_out_data += bs * D; + cur_batch_h_out_data += bs * D; + }; + int tstart = 0; - T* prev_h_data = nullptr; - T* prev_c_data = nullptr; if (h0) { // reorder h0, c0 T* reordered_h0_data = reordered_h0->mutable_data(place); T* reordered_c0_data = reordered_c0->mutable_data(place); const T* h0_data = h0->data(); const T* c0_data = c0->data(); - prev_h_data = reordered_h0_data; - prev_c_data = reordered_c0_data; + prev_batch_h_data = reordered_h0_data; + prev_batch_c_data = reordered_c0_data; size_t sz = sizeof(T) * D; for (int i = 0; i < max_bs; ++i) { std::memcpy(reordered_h0_data, h0_data + seq_order[i] * D, sz); @@ -415,71 +483,122 @@ class FuisonLSTMKernel : public framework::OpKernel { reordered_c0_data += D; } } else { - // compute without h0, c0 - T* cur_in_data = batched_input_data; - T* cur_h_out_data = batched_h_out_data; - T* cur_c_out_data = batched_c_out_data; - // W_ch, W_ih, W_fh, W_oh - for (int i = 0; i < max_bs; ++i) { - act_gate(D3, cur_in_data + D, cur_in_data + D); + // Compute with no H0/C0 + T* cur_in_data = cur_batch_in_data; + T* cur_c_out_data = cur_batch_c_out_data; + T* cur_h_out_data = cur_batch_h_out_data; + + // If step == 0 and there is no initialized hidden state, that is to say + // the H0 is zeros. Then W_h * H_t-1 can be skiped + + for (int i = 0; i < max_bs; ++i) { // iterate each data in 1st batch + // ~C_t act_cand(D, cur_in_data, cur_in_data); - // cell out= input*tilde + + if (use_peepholes) { + // I_t, F_t + act_gate(D2, cur_in_data + D, cur_in_data + D); + } else { + // I_t, F_t, O_t + act_gate(D3, cur_in_data + D, cur_in_data + D); + } + + // C_t = I_t * ~C_t blas.VMUL(D, cur_in_data, cur_in_data + D, cur_c_out_data); + + if (use_peepholes) { + // + W_oc * C_t for peephole connection + blas.VMUL(D, wc_data + D2, cur_c_out_data, checked_cell_data + D2); + blas.VADD(D, cur_in_data + D3, checked_cell_data + D2, + cur_in_data + D3); + // O_t + act_gate(D, cur_in_data + D3, cur_in_data + D3); + } + // hidden out= act_state(cellout) * outgate act_cell(D, cur_c_out_data, cur_in_data + D2); + // H_t = O_t * act_state(C_t) blas.VMUL(D, cur_in_data + D2, cur_in_data + D3, cur_h_out_data); - // add offset + // move to next data in the same batch cur_in_data += D4; cur_c_out_data += D; cur_h_out_data += D; } + + // move to data for next timestep + prev_batch_h_data = cur_batch_h_out_data; + prev_batch_c_data = cur_batch_c_out_data; + move_step(max_bs); tstart = 1; - prev_h_data = batched_h_out_data; - prev_c_data = batched_c_out_data; } - // Then start from next + const auto& batch_starts = batched_lod[0]; const int max_seq_len = batch_starts.size() - 1; - const int offset = tstart * max_bs * D; - batched_input_data = batched_input_data + offset * 4; - batched_h_out_data = batched_h_out_data + offset; - batched_c_out_data = batched_c_out_data + offset; for (int step = tstart; step < max_seq_len; ++step) { const int cur_bs = batch_starts[step + 1] - batch_starts[step]; + // + W_h * H_t-1 blas.GEMM(CblasNoTrans, CblasNoTrans, cur_bs, D4, D, static_cast(1), - prev_h_data, D, wh_data, D4, static_cast(1), - batched_input_data, D4); - - T* cur_in_data = batched_input_data; - T* cur_prev_c_data = prev_c_data; - T* cur_c_out_data = batched_c_out_data; - T* cur_h_out_data = batched_h_out_data; - for (int i = 0; i < cur_bs; ++i) { - // W_ch, W_ih, W_fh, W_oh - act_gate(D3, cur_in_data + D, cur_in_data + D); + prev_batch_h_data, D, wh_data, D4, static_cast(1), + cur_batch_in_data, D4); + + T* cur_in_data = cur_batch_in_data; + T* cur_c_out_data = cur_batch_c_out_data; + T* cur_h_out_data = cur_batch_h_out_data; + T* prev_c_data = prev_batch_c_data; // NULL if no C0 in step0 + T* prev_h_data = prev_batch_h_data; // NULL if no H0 in step0 + auto next_data_in_batch = [&]() { + cur_in_data += D4; + cur_c_out_data += D; + cur_h_out_data += D; + prev_c_data = prev_c_data ? prev_c_data + D : nullptr; + prev_h_data = prev_h_data ? prev_h_data + D : nullptr; + }; + + for (int i = 0; i < cur_bs; ++i) { // iterate each data in same batch + // ~C_t act_cand(D, cur_in_data, cur_in_data); - // a = forget * prev_cell - blas.VMUL(D, cur_in_data + D2, cur_prev_c_data, cur_in_data + D2); - // b = input * tilde + + if (use_peepholes) { + // + W_ic|W_fc * C_t-1 for peephole connection + blas.VMUL(D, wc_data, prev_c_data, checked_cell_data); + blas.VMUL(D, wc_data + D, prev_c_data, checked_cell_data + D); + blas.VADD(D2, cur_in_data + D, checked_cell_data, cur_in_data + D); + // I_t, F_t + act_gate(D2, cur_in_data + D, cur_in_data + D); + } else { + // I_t, F_t, O_t + act_gate(D3, cur_in_data + D, cur_in_data + D); + } + + // F_t * C_t-1 + blas.VMUL(D, cur_in_data + D2, prev_c_data, cur_in_data + D2); + // I_t * ~C_t blas.VMUL(D, cur_in_data, cur_in_data + D, cur_in_data + D); - // cell out= a+b + // C_t = F_t * C_t-1 + I_t * ~C_t blas.VADD(D, cur_in_data + D, cur_in_data + D2, cur_c_out_data); + + if (use_peepholes) { + // + W_oc * C_t for peephole connection + blas.VMUL(D, wc_data + D2, cur_c_out_data, checked_cell_data + D2); + blas.VADD(D, cur_in_data + D3, checked_cell_data + D2, + cur_in_data + D3); + // O_t + act_gate(D, cur_in_data + D3, cur_in_data + D3); + } + // hidden out= act_state(cellout) * outgate act_cell(D, cur_c_out_data, cur_in_data + D2); + // H_t = O_t * act_state(C_t) blas.VMUL(D, cur_in_data + D2, cur_in_data + D3, cur_h_out_data); - cur_in_data += D4; - cur_prev_c_data += D; - cur_c_out_data += D; - cur_h_out_data += D; + // move to next data in same batch + next_data_in_batch(); } - - prev_c_data = batched_c_out_data; - prev_h_data = batched_h_out_data; - batched_c_out_data = cur_c_out_data; - batched_h_out_data = cur_h_out_data; - batched_input_data = cur_in_data; + // move to data for next timestep + prev_batch_h_data = cur_batch_h_out_data; + prev_batch_c_data = cur_batch_c_out_data; + move_step(cur_bs); } math::Batch2LoDTensorFunctor to_seq; diff --git a/python/paddle/fluid/layers/nn.py b/python/paddle/fluid/layers/nn.py index a0d92fd1462acb18cdb2463b51138c9ff33b08a8..d8c7cc08b652f91456f557b0296e85b9aebc9dd0 100644 --- a/python/paddle/fluid/layers/nn.py +++ b/python/paddle/fluid/layers/nn.py @@ -3546,11 +3546,6 @@ def topk(input, k, name=None): top5_values, top5_indices = layers.topk(input, k=5) """ - shape = input.shape - if k < 1 or k >= shape[-1]: - raise ValueError("k must be greater than 0 and less than %d." % - (shape[-1])) - helper = LayerHelper("top_k", **locals()) values = helper.create_tmp_variable(dtype=input.dtype) indices = helper.create_tmp_variable(dtype="int64") diff --git a/python/paddle/fluid/tests/unittests/test_fusion_lstm_op.py b/python/paddle/fluid/tests/unittests/test_fusion_lstm_op.py index 1f1eb37667e304351a6a85edde09e7da32cf1630..4767e9433ea74d5da83867d646f2a63c9a092668 100644 --- a/python/paddle/fluid/tests/unittests/test_fusion_lstm_op.py +++ b/python/paddle/fluid/tests/unittests/test_fusion_lstm_op.py @@ -58,6 +58,7 @@ class TestFusionLSTMOp(OpTest): self.act_cell = 'tanh' self.act_cand = 'tanh' self.use_peepholes = False + self.use_seq = False self.set_conf() T = sum(self.lod[0]) @@ -107,6 +108,7 @@ class TestFusionLSTMOp(OpTest): } self.attrs = { 'use_peepholes': self.use_peepholes, + 'use_seq': self.use_seq, 'is_reverse': self.is_reverse, 'gate_activation': self.act_gate, 'cell_activation': self.act_cell, @@ -159,5 +161,68 @@ class TestFusionLSTMOpBS1(TestFusionLSTMOp): self.D = 16 +class TestFusionLSTMOpPeepholes(TestFusionLSTMOp): + def set_conf(self): + self.use_peepholes = True + + +class TestFusionLSTMOpPeepholesInit(TestFusionLSTMOp): + def set_conf(self): + self.use_peepholes = True + self.has_initial_state = True + + +class TestFusionLSTMOpPeepholesReverse(TestFusionLSTMOp): + def set_conf(self): + self.use_peepholes = True + self.is_reverse = True + + +class TestFusionLSTMOpPoopholesBS1(TestFusionLSTMOp): + def set_conf(self): + self.use_peepholes = True + self.lod = [[3]] + self.D = 16 + + +class TestFusionLSTMOpSeqInit(TestFusionLSTMOp): + def set_conf(self): + self.use_seq = True + self.has_initial_state = True + + +class TestFusionLSTMOpSeqReverse(TestFusionLSTMOp): + def set_conf(self): + self.use_seq = True + self.is_reverse = True + + +class TestFusionLSTMOpSeqInitReverse(TestFusionLSTMOp): + def set_conf(self): + self.use_seq = True + self.has_initial_state = True + self.is_reverse = True + + +class TestFusionLSTMOpSeqPeepholes(TestFusionLSTMOp): + def set_conf(self): + self.use_seq = True + self.use_peepholes = True + + +class TestFusionLSTMOpSeqPeepholesInit(TestFusionLSTMOp): + def set_conf(self): + self.use_seq = True + self.use_peepholes = True + self.has_initial_state = True + + +class TestFusionLSTMOpSeqPeepholesReverse(TestFusionLSTMOp): + def set_conf(self): + self.use_seq = True + self.use_peepholes = True + self.is_reverse = True + + if __name__ == '__main__': unittest.main()