diff --git a/doc/fluid/new_docs/beginners_guide/basics/machine_translation/README.cn.md b/doc/fluid/new_docs/beginners_guide/basics/machine_translation/README.cn.md
index fa2b930be0d26d816566599cece8afbedc1157e0..6e5f77fec8a894c390ced8c93ee344fd8d27370e 100644
--- a/doc/fluid/new_docs/beginners_guide/basics/machine_translation/README.cn.md
+++ b/doc/fluid/new_docs/beginners_guide/basics/machine_translation/README.cn.md
@@ -60,6 +60,7 @@
Figure 3. Encoder-decoder framework
+
#### Encoder
The encoding stage consists of three steps:
@@ -81,7 +82,7 @@
During training for a machine translation task, the goal of the decoding stage is to maximize the probability of the next correct target-language word. The idea is:
1. At each time step, compute the next hidden state `$z_{i+1}$` from the encoded representation of the source sentence (also called the context vector) `$c$`, the `$i$`-th word `$u_i$` of the ground-truth target-language sequence, and the hidden state `$z_i$` of the RNN at time `$i$`. The formula is:
$$z_{i+1}=\phi_{\theta '} \left ( c,u_i,z_i \right )$$
-where `$\phi _{\theta '}$` is a nonlinear activation function; `$c=q\mathbf{h}$` is the context vector of the source sentence, and when the [attention mechanism](#注意力机制) is not used, if the output of the [encoder](#编码器) is the last element of the encoded source sentence, we can define `$c=h_T$`; `$u_i$` is the `$i$`-th word of the target-language sequence, and `$u_0$` is its start-of-sequence marker `<s>`, which signals the start of decoding; `$z_i$` is the hidden state of the decoder RNN at time `$i$`, and `$z_0$` is an all-zero vector.
+where `$\phi _{\theta '}$` is a nonlinear activation function; `$c=q\mathbf{h}$` is the context vector of the source sentence, and when the attention mechanism is not used, if the output of the [encoder](#编码器) is the last element of the encoded source sentence, we can define `$c=h_T$`; `$u_i$` is the `$i$`-th word of the target-language sequence, and `$u_0$` is its start-of-sequence marker `<s>`, which signals the start of decoding; `$z_i$` is the hidden state of the decoder RNN at time `$i$`, and `$z_0$` is an all-zero vector.
2. Normalize `$z_{i+1}$` with `softmax` to obtain the probability distribution `$p_{i+1}$` over the `$i+1$`-th word of the target-language sequence. The distribution is:
$$p\left ( u_{i+1}|u_{<i+1},\mathbf{x} \right )=softmax(W_sz_{i+1}+b_z)$$
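To make the two steps above concrete, here is a minimal NumPy sketch of one decoder step, assuming a plain tanh RNN cell; the parameter names `W_c`, `W_u`, `W_z`, `b`, `W_s`, `b_z` are illustrative, not taken from the tutorial.

```python
# A minimal sketch of steps 1-2 above, assuming a simple tanh RNN cell.
import numpy as np

def softmax(x):
    e = np.exp(x - x.max())
    return e / e.sum()

def decoder_step(c, u_i, z_i, W_c, W_u, W_z, b, W_s, b_z):
    # z_{i+1} = phi_{theta'}(c, u_i, z_i)
    z_next = np.tanh(W_c @ c + W_u @ u_i + W_z @ z_i + b)
    # p(u_{i+1} | u_{<i+1}, x) = softmax(W_s z_{i+1} + b_z)
    p_next = softmax(W_s @ z_next + b_z)
    return z_next, p_next
```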
@@ -93,6 +94,7 @@ $$p\left ( u_{i+1}|u_{<i+1},\mathbf{x} \right )=softmax(W_sz_{i+1}+b_z)$$
The generation process of a machine translation task is, informally, using a pre-trained model to translate source-language sentences. The decoding stage during generation differs from the training process described above; see the [beam search algorithm](#柱搜索算法) section for details.
+
### Beam Search Algorithm
[Beam search](http://en.wikipedia.org/wiki/Beam_search) is a heuristic graph-search algorithm that expands only the most promising nodes from a limited set when searching a graph or tree. It is typically used in systems whose solution space is very large (such as machine translation and speech recognition), because memory cannot hold all expanded solutions of the graph or tree. For example, to translate "`你好`" in a machine translation task, even if the target-language dictionary contains only 3 words (`<s>`, `<e>`, `hello`), an unbounded number of sentences can still be generated (the number of times `hello` repeats is not fixed); to find a good translation among them, we can use beam search.
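As an illustration of the idea (not the tutorial's implementation), a toy beam search can be sketched as below; `step_probs` is an assumed callable that returns the next-word distribution for a given prefix, and `<s>`/`<e>` are the start/end markers.

```python
# Toy beam search: keep the `beam_size` best-scoring partial translations at
# each step; a hypothesis is finished once it emits the end marker <e>.
import math

def beam_search(step_probs, beam_size=2, eos='<e>', max_len=10):
    beams = [(['<s>'], 0.0)]          # (word list, log-probability)
    finished = []
    for _ in range(max_len):
        candidates = []
        for words, score in beams:
            for w, p in step_probs(words).items():
                candidates.append((words + [w], score + math.log(p)))
        candidates.sort(key=lambda c: c[1], reverse=True)
        beams = []
        for words, score in candidates[:beam_size]:
            (finished if words[-1] == eos else beams).append((words, score))
        if not beams:
            break
    return sorted(finished + beams, key=lambda c: c[1], reverse=True)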
diff --git a/doc/fluid/new_docs/beginners_guide/basics/understand_sentiment/README.cn.md b/doc/fluid/new_docs/beginners_guide/basics/understand_sentiment/README.cn.md
index 9900dfb9a67dc6f8940bd7dd3abfa15ac8a3488f..8477cf32146c33947ced447c8bdd287a3e1e71f5 100644
--- a/doc/fluid/new_docs/beginners_guide/basics/understand_sentiment/README.cn.md
+++ b/doc/fluid/new_docs/beginners_guide/basics/understand_sentiment/README.cn.md
@@ -149,6 +149,8 @@ def convolution_net(data, input_dim, class_dim, emb_dim, hid_dim):
The network input `input_dim` is the size of the dictionary, and `class_dim` is the number of classes. Here we use the [`sequence_conv_pool`](https://github.com/PaddlePaddle/Paddle/blob/develop/python/paddle/trainer_config_helpers/networks.py) API to implement the convolution and pooling operations.
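Conceptually, "convolution over a sequence followed by max pooling" can be sketched as below; this is only an illustrative NumPy version under assumed shapes, not the `sequence_conv_pool` implementation.

```python
# Conceptual sketch: slide a window of `context` words over the embedded
# sequence, project each window, then max-pool over time.
import numpy as np

def seq_conv_max_pool(emb, W, context=3):
    """emb: (T, emb_dim) word embeddings; W: (context*emb_dim, hid_dim)."""
    T, emb_dim = emb.shape
    pad = np.zeros(((context - 1) // 2, emb_dim))
    padded = np.vstack([pad, emb, pad])
    windows = np.stack([padded[t:t + context].ravel() for t in range(T)])
    conv = np.tanh(windows @ W)       # (T, hid_dim) per-position features
    return conv.max(axis=0)           # max pooling over time -> (hid_dim,)
```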
+
+
### Stacked Bidirectional LSTM
The code snippet for the stacked bidirectional LSTM network `stacked_lstm_net` is as follows:
diff --git a/doc/fluid/new_docs/beginners_guide/basics/word2vec/README.cn.md b/doc/fluid/new_docs/beginners_guide/basics/word2vec/README.cn.md
index 2c68cdac4f10319359b74bc92569dfd3f65380b5..904d99fe2ffc9ead69a86c9763568a5c098348d5 100644
--- a/doc/fluid/new_docs/beginners_guide/basics/word2vec/README.cn.md
+++ b/doc/fluid/new_docs/beginners_guide/basics/word2vec/README.cn.md
@@ -50,7 +50,7 @@ similarity: -0.0997506977351
```
-The results above can be obtained by running `calculate_dis.py`, which loads the words in the dictionary together with their trained embedding vectors; we describe the usage in detail in [Applying the Model](#应用模型).
+The results above can be obtained by running `calculate_dis.py`, which loads the words in the dictionary together with their trained embedding vectors; we describe the usage in detail in [Model Application](#模型应用).
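For reference, the similarity score shown above is a cosine similarity between trained word vectors; a small sketch follows (the identifiers are illustrative, and `calculate_dis.py` may differ in its details).

```python
# Cosine similarity between the embeddings of two words from the dictionary.
import numpy as np

def similarity(word_a, word_b, word_dict, embeddings):
    a = embeddings[word_dict[word_a]]
    b = embeddings[word_dict[word_b]]
    return float(a @ b / (np.linalg.norm(a) * np.linalg.norm(b)))
```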
## Model Overview
@@ -189,6 +189,7 @@ dream that one day
Finally, each input is converted into a sequence of integer indices according to each word's position in the dictionary, and this sequence is fed to PaddlePaddle as input.
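For example (a hedged illustration; the dictionary contents below are made up), the conversion looks like:

```python
# Each word is replaced by its index in the dictionary.
word_dict = {'<s>': 0, '<e>': 1, 'i': 2, 'have': 3, 'a': 4, 'dream': 5}
sentence = ['i', 'have', 'a', 'dream']
ids = [word_dict[w] for w in sentence]   # -> [2, 3, 4, 5]
```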
+
## Implementation
The model structure of this configuration is shown in the figure below:
@@ -349,6 +350,7 @@ Step 20: Average Cost 5.766995
...
```
+
## Model Application
After the model has been trained, we can use it to make some predictions.
diff --git a/doc/fluid/new_docs/beginners_guide/quick_start/recognize_digits/README.cn.md b/doc/fluid/new_docs/beginners_guide/quick_start/recognize_digits/README.cn.md
index e6f89b23a95d1a07565f3e0a285e9c3f921930df..ac36c4ecf6b9b716fe5f0dbe2346e64918c22242 100644
--- a/doc/fluid/new_docs/beginners_guide/quick_start/recognize_digits/README.cn.md
+++ b/doc/fluid/new_docs/beginners_guide/quick_start/recognize_digits/README.cn.md
@@ -102,7 +102,7 @@ Softmax回归模型采用了最简单的两层神经网络,即只有输入层
Pooling is a form of nonlinear downsampling. Its main purpose is to reduce the amount of computation by reducing the number of network parameters, and it also controls overfitting to some extent. A pooling layer is usually added after a convolutional layer. Pooling includes max pooling, average pooling, and so on. Max pooling partitions the input into regions with non-overlapping rectangular windows and takes the maximum value within each window as the output, as shown in Figure 6.
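As a quick illustration of non-overlapping max pooling (a sketch, not the tutorial's code):

```python
# 2x2 non-overlapping max pooling: each output entry is the max of one block.
import numpy as np

def max_pool2d(x, size=2):
    h, w = x.shape[0] // size, x.shape[1] // size
    return x[:h * size, :w * size].reshape(h, size, w, size).max(axis=(1, 3))

x = np.arange(16, dtype=float).reshape(4, 4)
print(max_pool2d(x))   # [[ 5.  7.] [13. 15.]]
```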
-For a more detailed introduction to convolutional neural networks, please refer to the [Stanford open course](http://cs231n.github.io/convolutional-networks/) and the [Image Classification](https://github.com/PaddlePaddle/book/blob/develop/image_classification/README.md) tutorial.
+For a more detailed introduction to convolutional neural networks, please refer to the [Stanford open course](http://cs231n.github.io/convolutional-networks/) and the [Image Classification](https://github.com/PaddlePaddle/book/tree/develop/03.image_classification) tutorial.
### Common Activation Functions
- Sigmoid activation function: $ f(x) = sigmoid(x) = \frac{1}{1+e^{-x}} $
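For reference, a small NumPy sketch of this activation (tanh and ReLU are added as other common choices; their inclusion here is an assumption about the rest of this list):

```python
import numpy as np

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))   # f(x) = 1 / (1 + e^-x)

def relu(x):
    return np.maximum(0.0, x)

x = np.array([-2.0, 0.0, 2.0])
print(sigmoid(x), np.tanh(x), relu(x))
```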
diff --git a/doc/fluid/new_docs/user_guides/howto/debug/visualdl.md b/doc/fluid/new_docs/user_guides/howto/debug/visualdl.md
index a2f30823a6fcd379f94e6e98d043b0d00681827f..84987ea5daee9abd0fe2fe71bdfde62ea3388ab5 100644
--- a/doc/fluid/new_docs/user_guides/howto/debug/visualdl.md
+++ b/doc/fluid/new_docs/user_guides/howto/debug/visualdl.md
@@ -149,7 +149,7 @@ python setup.py bdist_wheel
pip install --upgrade dist/visualdl-*.whl
```
-If you run into other problems while packaging and installing, or if you only want to run Visual DL without installing it, see [here](https://github.com/PaddlePaddle/VisualDL/blob/develop/docs/how_to_dev_frontend_en.md)
+If you run into other problems while packaging and installing, or if you only want to run Visual DL without installing it, see [here](https://github.com/PaddlePaddle/VisualDL/blob/develop/docs/develop/how_to_dev_frontend_cn.md)
## SDK
diff --git a/paddle/fluid/framework/ir/fc_lstm_fuse_pass.cc b/paddle/fluid/framework/ir/fc_lstm_fuse_pass.cc
index 00f5e7fad2ef5d42eb0de9703389e910090d93c1..55153ecc3ed35688d8f861bde0f44ae2bf6a7111 100644
--- a/paddle/fluid/framework/ir/fc_lstm_fuse_pass.cc
+++ b/paddle/fluid/framework/ir/fc_lstm_fuse_pass.cc
@@ -11,7 +11,6 @@
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
-
#include "paddle/fluid/framework/ir/fc_lstm_fuse_pass.h"
#include <string>
#include "paddle/fluid/framework/lod_tensor.h"
diff --git a/paddle/fluid/inference/api/demo_ci/run.sh b/paddle/fluid/inference/api/demo_ci/run.sh
index 7824ef2649af81a2390ff3bc537eb7c93c70e402..0f7d541c5edfc62e80cf50f83b491f06dcb42644 100755
--- a/paddle/fluid/inference/api/demo_ci/run.sh
+++ b/paddle/fluid/inference/api/demo_ci/run.sh
@@ -14,7 +14,7 @@ else
fi
PREFIX=inference-vis-demos%2F
-URL_ROOT=http://paddlemodels.bj.bcebos.com/${PREFIX}
+URL_ROOT=http://paddlemodels.cdn.bcebos.com/${PREFIX}
# download vis_demo data
function download() {
diff --git a/paddle/fluid/operators/distributed/request_handler_impl.cc b/paddle/fluid/operators/distributed/request_handler_impl.cc
index 66784f0b5149a7c479a90a407709d993f4a40a8b..31159a02592a2aff75f7ecf5be924989f0f47071 100644
--- a/paddle/fluid/operators/distributed/request_handler_impl.cc
+++ b/paddle/fluid/operators/distributed/request_handler_impl.cc
@@ -39,19 +39,6 @@ bool RequestSendHandler::Handle(const std::string& varname,
const std::string& out_var_name) {
VLOG(4) << "RequestSendHandler:" << varname;
- // Async
- if (!sync_mode_) {
- rpc_server_->Profiler().OneStep();
- try {
- executor_->RunPreparedContext((*grad_to_prepared_ctx_)[varname].get(),
- scope);
- } catch (std::exception& e) {
- LOG(ERROR) << "async: run sub program error " << e.what();
- return false;
- }
- return true;
- }
-
// Sync
if (varname == BATCH_BARRIER_MESSAGE) {
VLOG(3) << "sync: recv BATCH_BARRIER_MESSAGE";
@@ -60,17 +47,31 @@ bool RequestSendHandler::Handle(const std::string& varname,
VLOG(3) << "sync: recv complete message";
rpc_server_->Complete();
} else {
- VLOG(3) << "sync: received var_name: " << varname;
- rpc_server_->WaitCond(kRequestSend);
- VLOG(3) << "sync: processing received var: " << varname;
-
- if (invar == nullptr) {
- LOG(FATAL) << "sync: Can not find server side var: " << varname;
- return false;
- }
-      if (invar->IsType<framework::SelectedRows>()) {
-        std::unique_lock<std::mutex> lock(mutex_sparse_vars_);
- sparse_vars_.push_back(invar);
+ // Async
+ if (!sync_mode_) {
+ VLOG(3) << "async process var: " << varname;
+ rpc_server_->Profiler().OneStep();
+ try {
+ executor_->RunPreparedContext((*grad_to_prepared_ctx_)[varname].get(),
+ scope);
+ } catch (std::exception& e) {
+ LOG(ERROR) << "async: run sub program error " << e.what();
+ return false;
+ }
+ return true;
+ } else { // sync
+ rpc_server_->WaitCond(kRequestSend);
+ VLOG(3) << "sync: processing received var: " << varname;
+
+ if (invar == nullptr) {
+ LOG(FATAL) << "sync: Can not find server side var: " << varname;
+ return false;
+ }
+
+      if (invar->IsType<framework::SelectedRows>()) {
+        std::unique_lock<std::mutex> lock(mutex_sparse_vars_);
+ sparse_vars_.push_back(invar);
+ }
}
}
return true;
diff --git a/paddle/fluid/operators/fusion_lstm_op.cc b/paddle/fluid/operators/fusion_lstm_op.cc
index f91236975d0cf0c89a464188bd6ea1b5b01e0f6d..104e160e2d7069ec247cc51e927ce8824f1b69e8 100644
--- a/paddle/fluid/operators/fusion_lstm_op.cc
+++ b/paddle/fluid/operators/fusion_lstm_op.cc
@@ -89,12 +89,12 @@ void FusionLSTMOp::InferShape(framework::InferShapeContext* ctx) const {
PADDLE_ENFORCE_EQ(b_dims[0], 1,
"The first dimension of Input(Bias) should be 1.");
-  PADDLE_ENFORCE(!ctx->Attrs().Get<bool>("use_peepholes"),
- "Do not support peephole yet.");
- PADDLE_ENFORCE_EQ(b_dims[1], 4 * frame_size,
+  auto use_peepholes = ctx->Attrs().Get<bool>("use_peepholes");
+ PADDLE_ENFORCE_EQ(b_dims[1], (use_peepholes ? 7 : 4) * frame_size,
"The second dimension of Input(Bias) should be "
- "4 * %d if disable peepholes connection",
- frame_size);
+ "7 * %d if enable peepholes connection or"
+ "4 * %d if disable peepholes",
+ frame_size, frame_size);
framework::DDim out_dims({x_dims[0], frame_size});
ctx->SetOutputDim("Hidden", out_dims);
@@ -232,16 +232,17 @@ class FuisonLSTMKernel : public framework::OpKernel {
act_cand = act_functor(act_cand_str); \
}
-#define INIT_BASE_INPUT_OUTPUT                        \
-  auto* x = ctx.Input<LoDTensor>("X");                \
-  auto* h0 = ctx.Input<Tensor>("H0");                 \
-  auto* c0 = ctx.Input<Tensor>("C0");                 \
-  auto* wx = ctx.Input<Tensor>("WeightX");            \
-  auto* wh = ctx.Input<Tensor>("WeightH");            \
-  auto* bias = ctx.Input<Tensor>("Bias");             \
-  auto* xx = ctx.Output<LoDTensor>("XX");             \
-  auto* hidden_out = ctx.Output<LoDTensor>("Hidden"); \
-  auto* cell_out = ctx.Output<LoDTensor>("Cell");     \
+#define INIT_BASE_INPUT_OUTPUT                          \
+  auto* x = ctx.Input<LoDTensor>("X");                  \
+  auto* h0 = ctx.Input<Tensor>("H0");                   \
+  auto* c0 = ctx.Input<Tensor>("C0");                   \
+  auto* wx = ctx.Input<Tensor>("WeightX");              \
+  auto* wh = ctx.Input<Tensor>("WeightH");              \
+  auto* bias = ctx.Input<Tensor>("Bias");               \
+  auto* xx = ctx.Output<LoDTensor>("XX");               \
+  auto* hidden_out = ctx.Output<LoDTensor>("Hidden");   \
+  auto* cell_out = ctx.Output<LoDTensor>("Cell");       \
+  bool use_peepholes = ctx.Attr<bool>("use_peepholes"); \
bool is_reverse = ctx.Attr<bool>("is_reverse");
#define INIT_BASE_SIZES \
@@ -266,12 +267,21 @@ class FuisonLSTMKernel : public framework::OpKernel {
const T* x_data = x->data<T>();
const T* h0_data = h0 ? h0->data<T>() : nullptr;
const T* c0_data = c0 ? c0->data<T>() : nullptr;
+    const T* bias_data = bias->data<T>();
+    const T* wc_data = bias_data + D4;  // w_ic, w_fc, w_oc
const T* wx_data = wx->data<T>();
const T* wh_data = wh->data<T>();
+
T* xx_data = xx->mutable_data<T>(ctx.GetPlace());
T* hidden_out_data = hidden_out->mutable_data<T>(ctx.GetPlace());
T* cell_out_data = cell_out->mutable_data<T>(ctx.GetPlace());
+ // use local variable
+ framework::DDim check_dims({3, D});
+ Tensor checked_cell; // w_ic * Ct-1, w_fc * Ct-1, w_oc * Ct
+ auto checked_cell_data =
+      checked_cell.mutable_data<T>(check_dims, ctx.GetPlace());
+
auto blas = math::GetBlas<DeviceContext, T>(ctx);
math::FCCompute<DeviceContext, T>(blas, total_T, D4, M, x_data, wx_data,
xx_data, bias->data<T>());
@@ -297,46 +307,86 @@ class FuisonLSTMKernel : public framework::OpKernel {
int seq_len = x_lod[0][bid + 1] - x_lod[0][bid];
const T* prev_c_data = nullptr;
const T* prev_h_data = nullptr;
+
int tstart = 0;
if (h0_data) {
prev_h_data = h0_data + bid * D;
prev_c_data = c0_data + bid * D;
} else {
- // W_ch, W_ih, W_fh, W_oh
- act_gate(D3, xx_data + D, xx_data + D);
+        // If step == 0 and there is no initialized hidden state, i.e. H0 is
+        // all zeros, then W_h * H_t-1 can be skipped.
+
+ // ~C_t
act_cand(D, xx_data, xx_data);
- // cell out= input*tilde
+ if (use_peepholes) {
+ // I_t, F_t
+ act_gate(D2, xx_data + D, xx_data + D);
+ } else {
+ // I_t, F_t, O_t
+ act_gate(D3, xx_data + D, xx_data + D);
+ }
+ // C_t = I_t * ~C_t
blas.VMUL(D, xx_data, xx_data + D, cell_out_data);
+
+ if (use_peepholes) {
+ // + W_oc * C_t for peephole connection
+ blas.VMUL(D, wc_data + D2, cell_out_data, checked_cell_data + D2);
+ blas.VADD(D, xx_data + D3, checked_cell_data + D2, xx_data + D3);
+ // O_t
+ act_gate(D, xx_data + D3, xx_data + D3);
+ }
+
// hidden out= act_state(cellout) * outgate
act_cell(D, cell_out_data, xx_data + D2);
+ // H_t = O_t * act_state(C_t)
blas.VMUL(D, xx_data + D2, xx_data + D3, hidden_out_data);
// prev
prev_h_data = hidden_out_data;
prev_c_data = cell_out_data;
- tstart = 1;
+ tstart = 1;
move_step();
}
+
for (int step = tstart; step < seq_len; ++step) {
+ // + W_h * H_t-1
blas.GEMM(CblasNoTrans, CblasNoTrans, 1, D4, D, static_cast<T>(1),
prev_h_data, D, wh_data, D4, static_cast<T>(1), xx_data, D4);
- // W_ch, W_ih, W_fh, W_oh
- act_gate(D3, xx_data + D, xx_data + D);
+ // ~C_t
act_cand(D, xx_data, xx_data);
- // a = forget * prev_cell
+ if (use_peepholes) {
+ // + W_ic|W_fc * C_t-1 for peephole connection
+ blas.VMUL(D, wc_data, prev_c_data, checked_cell_data);
+ blas.VMUL(D, wc_data + D, prev_c_data, checked_cell_data + D);
+ blas.VADD(D2, xx_data + D, checked_cell_data, xx_data + D);
+ // I_t, F_t
+ act_gate(D2, xx_data + D, xx_data + D);
+ } else {
+ // I_t, F_t, O_t
+ act_gate(D3, xx_data + D, xx_data + D);
+ }
+
+ // F_t * C_t-1
blas.VMUL(D, xx_data + D2, prev_c_data, xx_data + D2);
-
- // b = input * tilde
+ // I_t * ~C_t
blas.VMUL(D, xx_data, xx_data + D, xx_data + D);
-
- // cell out= a+b
+ // C_t = F_t * C_t-1 + I_t * ~C_t
blas.VADD(D, xx_data + D, xx_data + D2, cell_out_data);
+ if (use_peepholes) {
+ // + W_oc * C_t for peephole connection
+ blas.VMUL(D, wc_data + D2, cell_out_data, checked_cell_data + D2);
+ blas.VADD(D, xx_data + D3, checked_cell_data + D2, xx_data + D3);
+ // O_t
+ act_gate(D, xx_data + D3, xx_data + D3);
+ }
+
// hidden out= act_state(cellout) * outgate
act_cell(D, cell_out_data, xx_data + D2);
+ // H_t = O_t * act_state(C_t)
blas.VMUL(D, xx_data + D2, xx_data + D3, hidden_out_data);
// prev
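To summarize the per-timestep arithmetic implemented above, here is a hedged NumPy sketch that mirrors the kernel's gate layout `xx = [~C | I | F | O]` (each block of width D) and the peephole weights `wc = [w_ic | w_fc | w_oc]`; it assumes the default sigmoid gate activation and tanh cell/candidate activations used by the tests.

```python
# One LSTM timestep with optional peephole connections, mirroring the blocks
# used in the kernel: xx holds the pre-activations [c_hat, i, f, o].
import numpy as np

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

def lstm_step(xx, prev_c, wc=None):
    D = prev_c.shape[0]
    c_hat = np.tanh(xx[:D])                          # ~C_t
    i, f, o = xx[D:2*D].copy(), xx[2*D:3*D].copy(), xx[3*D:4*D].copy()
    if wc is not None:                               # peephole connections
        i += wc[:D] * prev_c                         # + w_ic * C_{t-1}
        f += wc[D:2*D] * prev_c                      # + w_fc * C_{t-1}
    i, f = sigmoid(i), sigmoid(f)
    c = f * prev_c + i * c_hat                       # C_t = F_t*C_{t-1} + I_t*~C_t
    if wc is not None:
        o += wc[2*D:] * c                            # + w_oc * C_t
    o = sigmoid(o)
    h = o * np.tanh(c)                               # H_t = O_t * act_state(C_t)
    return h, c
```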
@@ -344,14 +394,14 @@ class FuisonLSTMKernel : public framework::OpKernel {
prev_c_data = cell_out_data;
move_step();
- }
- }
+ } // for each step in batch
+ } // for each batch
}
void BatchCompute(const framework::ExecutionContext& ctx) const {
using DeviceContext = platform::CPUDeviceContext;
INIT_BASE_INPUT_OUTPUT
- if (x->lod()[0].size() == 2) {
+ if (x->lod()[0].size() == 2) { // batch size == 1
SeqCompute(ctx);
return;
}
@@ -367,6 +417,8 @@ class FuisonLSTMKernel : public framework::OpKernel {
const T* x_data = x->data<T>();
const T* wx_data = wx->data<T>();
const T* wh_data = wh->data<T>();
+    const T* bias_data = bias->data<T>();
+    const T* wc_data = bias_data + D4;  // w_ic, w_fc, w_oc
auto place = ctx.GetPlace();
T* xx_data = xx->mutable_data<T>(place);
T* batched_input_data = batched_input->mutable_data<T>(place);
@@ -375,6 +427,12 @@ class FuisonLSTMKernel : public framework::OpKernel {
hidden_out->mutable_data<T>(place);
cell_out->mutable_data<T>(place);
+ // use local variable
+ framework::DDim check_dims({3, D});
+ Tensor checked_cell; // w_ic * Ct-1, w_fc * Ct-1, w_oc * Ct
+ auto checked_cell_data =
+        checked_cell.mutable_data<T>(check_dims, ctx.GetPlace());
+
math::LoDTensor2BatchFunctor<DeviceContext, T> to_batch;
auto& dev_ctx = ctx.template device_context<DeviceContext>();
auto blas = math::GetBlas<DeviceContext, T>(dev_ctx);
@@ -396,17 +454,27 @@ class FuisonLSTMKernel : public framework::OpKernel {
reordered_h0->Resize({max_bs, D});
reordered_c0->Resize({max_bs, D});
+ T* prev_batch_h_data = nullptr;
+ T* prev_batch_c_data = nullptr;
+ T* cur_batch_in_data = batched_input_data;
+ T* cur_batch_h_out_data = batched_h_out_data;
+ T* cur_batch_c_out_data = batched_c_out_data;
+
+ auto move_step = [&](int bs) {
+ cur_batch_in_data += bs * D4;
+ cur_batch_c_out_data += bs * D;
+ cur_batch_h_out_data += bs * D;
+ };
+
int tstart = 0;
- T* prev_h_data = nullptr;
- T* prev_c_data = nullptr;
if (h0) {
// reorder h0, c0
T* reordered_h0_data = reordered_h0->mutable_data<T>(place);
T* reordered_c0_data = reordered_c0->mutable_data<T>(place);
const T* h0_data = h0->data<T>();
const T* c0_data = c0->data<T>();
- prev_h_data = reordered_h0_data;
- prev_c_data = reordered_c0_data;
+ prev_batch_h_data = reordered_h0_data;
+ prev_batch_c_data = reordered_c0_data;
size_t sz = sizeof(T) * D;
for (int i = 0; i < max_bs; ++i) {
std::memcpy(reordered_h0_data, h0_data + seq_order[i] * D, sz);
@@ -415,71 +483,122 @@ class FuisonLSTMKernel : public framework::OpKernel {
reordered_c0_data += D;
}
} else {
- // compute without h0, c0
- T* cur_in_data = batched_input_data;
- T* cur_h_out_data = batched_h_out_data;
- T* cur_c_out_data = batched_c_out_data;
- // W_ch, W_ih, W_fh, W_oh
- for (int i = 0; i < max_bs; ++i) {
- act_gate(D3, cur_in_data + D, cur_in_data + D);
+ // Compute with no H0/C0
+ T* cur_in_data = cur_batch_in_data;
+ T* cur_c_out_data = cur_batch_c_out_data;
+ T* cur_h_out_data = cur_batch_h_out_data;
+
+      // If step == 0 and there is no initialized hidden state, i.e. H0 is
+      // all zeros, then W_h * H_t-1 can be skipped.
+
+ for (int i = 0; i < max_bs; ++i) { // iterate each data in 1st batch
+ // ~C_t
act_cand(D, cur_in_data, cur_in_data);
- // cell out= input*tilde
+
+ if (use_peepholes) {
+ // I_t, F_t
+ act_gate(D2, cur_in_data + D, cur_in_data + D);
+ } else {
+ // I_t, F_t, O_t
+ act_gate(D3, cur_in_data + D, cur_in_data + D);
+ }
+
+ // C_t = I_t * ~C_t
blas.VMUL(D, cur_in_data, cur_in_data + D, cur_c_out_data);
+
+ if (use_peepholes) {
+ // + W_oc * C_t for peephole connection
+ blas.VMUL(D, wc_data + D2, cur_c_out_data, checked_cell_data + D2);
+ blas.VADD(D, cur_in_data + D3, checked_cell_data + D2,
+ cur_in_data + D3);
+ // O_t
+ act_gate(D, cur_in_data + D3, cur_in_data + D3);
+ }
+
// hidden out= act_state(cellout) * outgate
act_cell(D, cur_c_out_data, cur_in_data + D2);
+ // H_t = O_t * act_state(C_t)
blas.VMUL(D, cur_in_data + D2, cur_in_data + D3, cur_h_out_data);
- // add offset
+ // move to next data in the same batch
cur_in_data += D4;
cur_c_out_data += D;
cur_h_out_data += D;
}
+
+ // move to data for next timestep
+ prev_batch_h_data = cur_batch_h_out_data;
+ prev_batch_c_data = cur_batch_c_out_data;
+ move_step(max_bs);
tstart = 1;
- prev_h_data = batched_h_out_data;
- prev_c_data = batched_c_out_data;
}
- // Then start from next
+
const auto& batch_starts = batched_lod[0];
const int max_seq_len = batch_starts.size() - 1;
- const int offset = tstart * max_bs * D;
- batched_input_data = batched_input_data + offset * 4;
- batched_h_out_data = batched_h_out_data + offset;
- batched_c_out_data = batched_c_out_data + offset;
for (int step = tstart; step < max_seq_len; ++step) {
const int cur_bs = batch_starts[step + 1] - batch_starts[step];
+ // + W_h * H_t-1
blas.GEMM(CblasNoTrans, CblasNoTrans, cur_bs, D4, D, static_cast<T>(1),
-                prev_h_data, D, wh_data, D4, static_cast<T>(1),
- batched_input_data, D4);
-
- T* cur_in_data = batched_input_data;
- T* cur_prev_c_data = prev_c_data;
- T* cur_c_out_data = batched_c_out_data;
- T* cur_h_out_data = batched_h_out_data;
- for (int i = 0; i < cur_bs; ++i) {
- // W_ch, W_ih, W_fh, W_oh
- act_gate(D3, cur_in_data + D, cur_in_data + D);
+                prev_batch_h_data, D, wh_data, D4, static_cast<T>(1),
+ cur_batch_in_data, D4);
+
+ T* cur_in_data = cur_batch_in_data;
+ T* cur_c_out_data = cur_batch_c_out_data;
+ T* cur_h_out_data = cur_batch_h_out_data;
+ T* prev_c_data = prev_batch_c_data; // NULL if no C0 in step0
+ T* prev_h_data = prev_batch_h_data; // NULL if no H0 in step0
+ auto next_data_in_batch = [&]() {
+ cur_in_data += D4;
+ cur_c_out_data += D;
+ cur_h_out_data += D;
+ prev_c_data = prev_c_data ? prev_c_data + D : nullptr;
+ prev_h_data = prev_h_data ? prev_h_data + D : nullptr;
+ };
+
+ for (int i = 0; i < cur_bs; ++i) { // iterate each data in same batch
+ // ~C_t
act_cand(D, cur_in_data, cur_in_data);
- // a = forget * prev_cell
- blas.VMUL(D, cur_in_data + D2, cur_prev_c_data, cur_in_data + D2);
- // b = input * tilde
+
+ if (use_peepholes) {
+ // + W_ic|W_fc * C_t-1 for peephole connection
+ blas.VMUL(D, wc_data, prev_c_data, checked_cell_data);
+ blas.VMUL(D, wc_data + D, prev_c_data, checked_cell_data + D);
+ blas.VADD(D2, cur_in_data + D, checked_cell_data, cur_in_data + D);
+ // I_t, F_t
+ act_gate(D2, cur_in_data + D, cur_in_data + D);
+ } else {
+ // I_t, F_t, O_t
+ act_gate(D3, cur_in_data + D, cur_in_data + D);
+ }
+
+ // F_t * C_t-1
+ blas.VMUL(D, cur_in_data + D2, prev_c_data, cur_in_data + D2);
+ // I_t * ~C_t
blas.VMUL(D, cur_in_data, cur_in_data + D, cur_in_data + D);
- // cell out= a+b
+ // C_t = F_t * C_t-1 + I_t * ~C_t
blas.VADD(D, cur_in_data + D, cur_in_data + D2, cur_c_out_data);
+
+ if (use_peepholes) {
+ // + W_oc * C_t for peephole connection
+ blas.VMUL(D, wc_data + D2, cur_c_out_data, checked_cell_data + D2);
+ blas.VADD(D, cur_in_data + D3, checked_cell_data + D2,
+ cur_in_data + D3);
+ // O_t
+ act_gate(D, cur_in_data + D3, cur_in_data + D3);
+ }
+
// hidden out= act_state(cellout) * outgate
act_cell(D, cur_c_out_data, cur_in_data + D2);
+ // H_t = O_t * act_state(C_t)
blas.VMUL(D, cur_in_data + D2, cur_in_data + D3, cur_h_out_data);
- cur_in_data += D4;
- cur_prev_c_data += D;
- cur_c_out_data += D;
- cur_h_out_data += D;
+ // move to next data in same batch
+ next_data_in_batch();
}
-
- prev_c_data = batched_c_out_data;
- prev_h_data = batched_h_out_data;
- batched_c_out_data = cur_c_out_data;
- batched_h_out_data = cur_h_out_data;
- batched_input_data = cur_in_data;
+ // move to data for next timestep
+ prev_batch_h_data = cur_batch_h_out_data;
+ prev_batch_c_data = cur_batch_c_out_data;
+ move_step(cur_bs);
}
math::Batch2LoDTensorFunctor<DeviceContext, T> to_seq;
diff --git a/python/paddle/fluid/layers/nn.py b/python/paddle/fluid/layers/nn.py
index a0d92fd1462acb18cdb2463b51138c9ff33b08a8..d8c7cc08b652f91456f557b0296e85b9aebc9dd0 100644
--- a/python/paddle/fluid/layers/nn.py
+++ b/python/paddle/fluid/layers/nn.py
@@ -3546,11 +3546,6 @@ def topk(input, k, name=None):
top5_values, top5_indices = layers.topk(input, k=5)
"""
- shape = input.shape
- if k < 1 or k >= shape[-1]:
- raise ValueError("k must be greater than 0 and less than %d." %
- (shape[-1]))
-
helper = LayerHelper("top_k", **locals())
values = helper.create_tmp_variable(dtype=input.dtype)
indices = helper.create_tmp_variable(dtype="int64")
diff --git a/python/paddle/fluid/tests/unittests/test_fusion_lstm_op.py b/python/paddle/fluid/tests/unittests/test_fusion_lstm_op.py
index 1f1eb37667e304351a6a85edde09e7da32cf1630..4767e9433ea74d5da83867d646f2a63c9a092668 100644
--- a/python/paddle/fluid/tests/unittests/test_fusion_lstm_op.py
+++ b/python/paddle/fluid/tests/unittests/test_fusion_lstm_op.py
@@ -58,6 +58,7 @@ class TestFusionLSTMOp(OpTest):
self.act_cell = 'tanh'
self.act_cand = 'tanh'
self.use_peepholes = False
+ self.use_seq = False
self.set_conf()
T = sum(self.lod[0])
@@ -107,6 +108,7 @@ class TestFusionLSTMOp(OpTest):
}
self.attrs = {
'use_peepholes': self.use_peepholes,
+ 'use_seq': self.use_seq,
'is_reverse': self.is_reverse,
'gate_activation': self.act_gate,
'cell_activation': self.act_cell,
@@ -159,5 +161,68 @@ class TestFusionLSTMOpBS1(TestFusionLSTMOp):
self.D = 16
+class TestFusionLSTMOpPeepholes(TestFusionLSTMOp):
+ def set_conf(self):
+ self.use_peepholes = True
+
+
+class TestFusionLSTMOpPeepholesInit(TestFusionLSTMOp):
+ def set_conf(self):
+ self.use_peepholes = True
+ self.has_initial_state = True
+
+
+class TestFusionLSTMOpPeepholesReverse(TestFusionLSTMOp):
+ def set_conf(self):
+ self.use_peepholes = True
+ self.is_reverse = True
+
+
+class TestFusionLSTMOpPeepholesBS1(TestFusionLSTMOp):
+ def set_conf(self):
+ self.use_peepholes = True
+ self.lod = [[3]]
+ self.D = 16
+
+
+class TestFusionLSTMOpSeqInit(TestFusionLSTMOp):
+ def set_conf(self):
+ self.use_seq = True
+ self.has_initial_state = True
+
+
+class TestFusionLSTMOpSeqReverse(TestFusionLSTMOp):
+ def set_conf(self):
+ self.use_seq = True
+ self.is_reverse = True
+
+
+class TestFusionLSTMOpSeqInitReverse(TestFusionLSTMOp):
+ def set_conf(self):
+ self.use_seq = True
+ self.has_initial_state = True
+ self.is_reverse = True
+
+
+class TestFusionLSTMOpSeqPeepholes(TestFusionLSTMOp):
+ def set_conf(self):
+ self.use_seq = True
+ self.use_peepholes = True
+
+
+class TestFusionLSTMOpSeqPeepholesInit(TestFusionLSTMOp):
+ def set_conf(self):
+ self.use_seq = True
+ self.use_peepholes = True
+ self.has_initial_state = True
+
+
+class TestFusionLSTMOpSeqPeepholesReverse(TestFusionLSTMOp):
+ def set_conf(self):
+ self.use_seq = True
+ self.use_peepholes = True
+ self.is_reverse = True
+
+
if __name__ == '__main__':
unittest.main()