Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
PaddleDetection
提交
2dc23ffa
P
PaddleDetection
项目概览
PaddlePaddle
/
PaddleDetection
大约 1 年 前同步成功
通知
695
Star
11112
Fork
2696
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
184
列表
看板
标记
里程碑
合并请求
40
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
PaddleDetection
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
184
Issue
184
列表
看板
标记
里程碑
合并请求
40
合并请求
40
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
2dc23ffa
编写于
9月 05, 2018
作者:
L
luotao1
浏览文件
操作
浏览文件
下载
差异文件
Merge branch 'develop' into multi-thread2
上级
39ed1487
04272c0d
变更
11
显示空白变更内容
内联
并排
Showing
11 changed file
with
293 addition
and
108 deletion
+293
-108
doc/fluid/new_docs/beginners_guide/basics/machine_translation/README.cn.md
...s/beginners_guide/basics/machine_translation/README.cn.md
+3
-1
doc/fluid/new_docs/beginners_guide/basics/understand_sentiment/README.cn.md
.../beginners_guide/basics/understand_sentiment/README.cn.md
+2
-0
doc/fluid/new_docs/beginners_guide/basics/word2vec/README.cn.md
...uid/new_docs/beginners_guide/basics/word2vec/README.cn.md
+3
-1
doc/fluid/new_docs/beginners_guide/quick_start/recognize_digits/README.cn.md
...beginners_guide/quick_start/recognize_digits/README.cn.md
+1
-1
doc/fluid/new_docs/user_guides/howto/debug/visualdl.md
doc/fluid/new_docs/user_guides/howto/debug/visualdl.md
+1
-1
paddle/fluid/framework/ir/fc_lstm_fuse_pass.cc
paddle/fluid/framework/ir/fc_lstm_fuse_pass.cc
+0
-1
paddle/fluid/inference/api/demo_ci/run.sh
paddle/fluid/inference/api/demo_ci/run.sh
+1
-1
paddle/fluid/operators/distributed/request_handler_impl.cc
paddle/fluid/operators/distributed/request_handler_impl.cc
+25
-24
paddle/fluid/operators/fusion_lstm_op.cc
paddle/fluid/operators/fusion_lstm_op.cc
+192
-73
python/paddle/fluid/layers/nn.py
python/paddle/fluid/layers/nn.py
+0
-5
python/paddle/fluid/tests/unittests/test_fusion_lstm_op.py
python/paddle/fluid/tests/unittests/test_fusion_lstm_op.py
+65
-0
未找到文件。
doc/fluid/new_docs/beginners_guide/basics/machine_translation/README.cn.md
浏览文件 @
2dc23ffa
...
@@ -60,6 +60,7 @@
...
@@ -60,6 +60,7 @@
图3. 编码器-解码器框架
图3. 编码器-解码器框架
</div>
</div>
<a
name=
"编码器"
></a>
#### 编码器
#### 编码器
编码阶段分为三步:
编码阶段分为三步:
...
@@ -81,7 +82,7 @@
...
@@ -81,7 +82,7 @@
机器翻译任务的训练过程中,解码阶段的目标是最大化下一个正确的目标语言词的概率。思路是:
机器翻译任务的训练过程中,解码阶段的目标是最大化下一个正确的目标语言词的概率。思路是:
1.
每一个时刻,根据源语言句子的编码信息(又叫上下文向量,context vector)
`$c$`
、真实目标语言序列的第
`$i$`
个词
`$u_i$`
和
`$i$`
时刻RNN的隐层状态
`$z_i$`
,计算出下一个隐层状态
`$z_{i+1}$`
。计算公式如下:
1.
每一个时刻,根据源语言句子的编码信息(又叫上下文向量,context vector)
`$c$`
、真实目标语言序列的第
`$i$`
个词
`$u_i$`
和
`$i$`
时刻RNN的隐层状态
`$z_i$`
,计算出下一个隐层状态
`$z_{i+1}$`
。计算公式如下:
$$z_{i+1}=
\p
hi_{
\t
heta '}
\l
eft ( c,u_i,z_i
\r
ight )$$
$$z_{i+1}=
\p
hi_{
\t
heta '}
\l
eft ( c,u_i,z_i
\r
ight )$$
其中
`$\phi _{\theta '}$`
是一个非线性激活函数;
`$c=q\mathbf{h}$`
是源语言句子的上下文向量,在不使用
[
注意力机制
](
#注意力机制
)
时,如果
[
编码器
](
#编码器
)
的输出是源语言句子编码后的最后一个元素,则可以定义
`$c=h_T$`
;
`$u_i$`
是目标语言序列的第
`$i$`
个单词,
`$u_0$`
是目标语言序列的开始标记
`<s>`
,表示解码开始;
`$z_i$`
是
`$i$`
时刻解码RNN的隐层状态,
`$z_0$`
是一个全零的向量。
其中
`$\phi _{\theta '}$`
是一个非线性激活函数;
`$c=q\mathbf{h}$`
是源语言句子的上下文向量,在不使用
注意力机制
时,如果
[
编码器
](
#编码器
)
的输出是源语言句子编码后的最后一个元素,则可以定义
`$c=h_T$`
;
`$u_i$`
是目标语言序列的第
`$i$`
个单词,
`$u_0$`
是目标语言序列的开始标记
`<s>`
,表示解码开始;
`$z_i$`
是
`$i$`
时刻解码RNN的隐层状态,
`$z_0$`
是一个全零的向量。
2.
将
`$z_{i+1}$`
通过
`softmax`
归一化,得到目标语言序列的第
`$i+1$`
个单词的概率分布
`$p_{i+1}$`
。概率分布公式如下:
2.
将
`$z_{i+1}$`
通过
`softmax`
归一化,得到目标语言序列的第
`$i+1$`
个单词的概率分布
`$p_{i+1}$`
。概率分布公式如下:
$$p
\l
eft ( u_{i+1}|u_{
<
i+1},
\m
athbf{x}
\r
ight )=softmax(W_sz_{i+1}+b_z)$$
$$p
\l
eft ( u_{i+1}|u_{
<
i+1},
\m
athbf{x}
\r
ight )=softmax(W_sz_{i+1}+b_z)$$
...
@@ -93,6 +94,7 @@ $$p\left ( u_{i+1}|u_{<i+1},\mathbf{x} \right )=softmax(W_sz_{i+1}+b_z)$$
...
@@ -93,6 +94,7 @@ $$p\left ( u_{i+1}|u_{<i+1},\mathbf{x} \right )=softmax(W_sz_{i+1}+b_z)$$
机器翻译任务的生成过程,通俗来讲就是根据预先训练的模型来翻译源语言句子。生成过程中的解码阶段和上述训练过程的有所差异,具体介绍请见
[
柱搜索算法
](
#柱搜索算法
)
。
机器翻译任务的生成过程,通俗来讲就是根据预先训练的模型来翻译源语言句子。生成过程中的解码阶段和上述训练过程的有所差异,具体介绍请见
[
柱搜索算法
](
#柱搜索算法
)
。
<a
name=
"柱搜索算法"
></a>
### 柱搜索算法
### 柱搜索算法
柱搜索(
[
beam search
](
http://en.wikipedia.org/wiki/Beam_search
)
)是一种启发式图搜索算法,用于在图或树中搜索有限集合中的最优扩展节点,通常用在解空间非常大的系统(如机器翻译、语音识别)中,原因是内存无法装下图或树中所有展开的解。如在机器翻译任务中希望翻译“
`<s>你好<e>`
”,就算目标语言字典中只有3个词(
`<s>`
,
`<e>`
,
`hello`
),也可能生成无限句话(
`hello`
循环出现的次数不定),为了找到其中较好的翻译结果,我们可采用柱搜索算法。
柱搜索(
[
beam search
](
http://en.wikipedia.org/wiki/Beam_search
)
)是一种启发式图搜索算法,用于在图或树中搜索有限集合中的最优扩展节点,通常用在解空间非常大的系统(如机器翻译、语音识别)中,原因是内存无法装下图或树中所有展开的解。如在机器翻译任务中希望翻译“
`<s>你好<e>`
”,就算目标语言字典中只有3个词(
`<s>`
,
`<e>`
,
`hello`
),也可能生成无限句话(
`hello`
循环出现的次数不定),为了找到其中较好的翻译结果,我们可采用柱搜索算法。
...
...
doc/fluid/new_docs/beginners_guide/basics/understand_sentiment/README.cn.md
浏览文件 @
2dc23ffa
...
@@ -149,6 +149,8 @@ def convolution_net(data, input_dim, class_dim, emb_dim, hid_dim):
...
@@ -149,6 +149,8 @@ def convolution_net(data, input_dim, class_dim, emb_dim, hid_dim):
网络的输入
`input_dim`
表示的是词典的大小,
`class_dim`
表示类别数。这里,我们使用
[
`sequence_conv_pool`
](
https://github.com/PaddlePaddle/Paddle/blob/develop/python/paddle/trainer_config_helpers/networks.py
)
API实现了卷积和池化操作。
网络的输入
`input_dim`
表示的是词典的大小,
`class_dim`
表示类别数。这里,我们使用
[
`sequence_conv_pool`
](
https://github.com/PaddlePaddle/Paddle/blob/develop/python/paddle/trainer_config_helpers/networks.py
)
API实现了卷积和池化操作。
<a
name=
"栈值双向LSTM"
></a>
### 栈式双向LSTM
### 栈式双向LSTM
栈式双向神经网络
`stacked_lstm_net`
的代码片段如下:
栈式双向神经网络
`stacked_lstm_net`
的代码片段如下:
...
...
doc/fluid/new_docs/beginners_guide/basics/word2vec/README.cn.md
浏览文件 @
2dc23ffa
...
@@ -50,7 +50,7 @@ similarity: -0.0997506977351
...
@@ -50,7 +50,7 @@ similarity: -0.0997506977351
```
```
以上结果可以通过运行
`calculate_dis.py`
, 加载字典里的单词和对应训练特征结果得到,我们将在
[
应用模型
](
#应用模型
)
中详细描述用法。
以上结果可以通过运行
`calculate_dis.py`
, 加载字典里的单词和对应训练特征结果得到,我们将在
[
模型应用
](
#模型应用
)
中详细描述用法。
## 模型概览
## 模型概览
...
@@ -189,6 +189,7 @@ dream that one day <e>
...
@@ -189,6 +189,7 @@ dream that one day <e>
最后,每个输入会按其单词次在字典里的位置,转化成整数的索引序列,作为PaddlePaddle的输入。
最后,每个输入会按其单词次在字典里的位置,转化成整数的索引序列,作为PaddlePaddle的输入。
<a
name=
"训练模型"
></a>
## 编程实现
## 编程实现
本配置的模型结构如下图所示:
本配置的模型结构如下图所示:
...
@@ -349,6 +350,7 @@ Step 20: Average Cost 5.766995
...
@@ -349,6 +350,7 @@ Step 20: Average Cost 5.766995
...
...
```
```
<a
name=
"模型应用"
></a>
## 模型应用
## 模型应用
在模型训练后,我们可以用它做一些预测。
在模型训练后,我们可以用它做一些预测。
...
...
doc/fluid/new_docs/beginners_guide/quick_start/recognize_digits/README.cn.md
浏览文件 @
2dc23ffa
...
@@ -102,7 +102,7 @@ Softmax回归模型采用了最简单的两层神经网络,即只有输入层
...
@@ -102,7 +102,7 @@ Softmax回归模型采用了最简单的两层神经网络,即只有输入层
池化是非线性下采样的一种形式,主要作用是通过减少网络的参数来减小计算量,并且能够在一定程度上控制过拟合。通常在卷积层的后面会加上一个池化层。池化包括最大池化、平均池化等。其中最大池化是用不重叠的矩形框将输入层分成不同的区域,对于每个矩形框的数取最大值作为输出层,如图6所示。
池化是非线性下采样的一种形式,主要作用是通过减少网络的参数来减小计算量,并且能够在一定程度上控制过拟合。通常在卷积层的后面会加上一个池化层。池化包括最大池化、平均池化等。其中最大池化是用不重叠的矩形框将输入层分成不同的区域,对于每个矩形框的数取最大值作为输出层,如图6所示。
更详细的关于卷积神经网络的具体知识可以参考
[
斯坦福大学公开课
](
http://cs231n.github.io/convolutional-networks/
)
和
[
图像分类
](
https://github.com/PaddlePaddle/book/blob/develop/image_classification/README.md
)
教程。
更详细的关于卷积神经网络的具体知识可以参考
[
斯坦福大学公开课
](
http://cs231n.github.io/convolutional-networks/
)
和
[
图像分类
](
https://github.com/PaddlePaddle/book/tree/develop/03.image_classification
)
教程。
### 常见激活函数介绍
### 常见激活函数介绍
-
sigmoid激活函数: $ f(x) = sigmoid(x) =
\f
rac{1}{1+e^{-x}} $
-
sigmoid激活函数: $ f(x) = sigmoid(x) =
\f
rac{1}{1+e^{-x}} $
...
...
doc/fluid/new_docs/user_guides/howto/debug/visualdl.md
浏览文件 @
2dc23ffa
...
@@ -149,7 +149,7 @@ python setup.py bdist_wheel
...
@@ -149,7 +149,7 @@ python setup.py bdist_wheel
pip install --upgrade dist/visualdl-
*
.whl
pip install --upgrade dist/visualdl-
*
.whl
```
```
如果打包和安装遇到其他问题,不安装只想运行Visual DL可以看[这里](https://github.com/PaddlePaddle/VisualDL/blob/develop/docs/
how_to_dev_frontend_e
n.md)
如果打包和安装遇到其他问题,不安装只想运行Visual DL可以看[这里](https://github.com/PaddlePaddle/VisualDL/blob/develop/docs/
develop/how_to_dev_frontend_c
n.md)
## SDK
## SDK
...
...
paddle/fluid/framework/ir/fc_lstm_fuse_pass.cc
浏览文件 @
2dc23ffa
...
@@ -11,7 +11,6 @@
...
@@ -11,7 +11,6 @@
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// See the License for the specific language governing permissions and
// limitations under the License.
// limitations under the License.
#include "paddle/fluid/framework/ir/fc_lstm_fuse_pass.h"
#include "paddle/fluid/framework/ir/fc_lstm_fuse_pass.h"
#include <string>
#include <string>
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/lod_tensor.h"
...
...
paddle/fluid/inference/api/demo_ci/run.sh
浏览文件 @
2dc23ffa
...
@@ -14,7 +14,7 @@ else
...
@@ -14,7 +14,7 @@ else
fi
fi
PREFIX
=
inference-vis-demos%2F
PREFIX
=
inference-vis-demos%2F
URL_ROOT
=
http://paddlemodels.
bj
.bcebos.com/
${
PREFIX
}
URL_ROOT
=
http://paddlemodels.
cdn
.bcebos.com/
${
PREFIX
}
# download vis_demo data
# download vis_demo data
function
download
()
{
function
download
()
{
...
...
paddle/fluid/operators/distributed/request_handler_impl.cc
浏览文件 @
2dc23ffa
...
@@ -39,8 +39,17 @@ bool RequestSendHandler::Handle(const std::string& varname,
...
@@ -39,8 +39,17 @@ bool RequestSendHandler::Handle(const std::string& varname,
const
std
::
string
&
out_var_name
)
{
const
std
::
string
&
out_var_name
)
{
VLOG
(
4
)
<<
"RequestSendHandler:"
<<
varname
;
VLOG
(
4
)
<<
"RequestSendHandler:"
<<
varname
;
// Sync
if
(
varname
==
BATCH_BARRIER_MESSAGE
)
{
VLOG
(
3
)
<<
"sync: recv BATCH_BARRIER_MESSAGE"
;
rpc_server_
->
IncreaseBatchBarrier
(
kRequestSend
);
}
else
if
(
varname
==
COMPLETE_MESSAGE
)
{
VLOG
(
3
)
<<
"sync: recv complete message"
;
rpc_server_
->
Complete
();
}
else
{
// Async
// Async
if
(
!
sync_mode_
)
{
if
(
!
sync_mode_
)
{
VLOG
(
3
)
<<
"async process var: "
<<
varname
;
rpc_server_
->
Profiler
().
OneStep
();
rpc_server_
->
Profiler
().
OneStep
();
try
{
try
{
executor_
->
RunPreparedContext
((
*
grad_to_prepared_ctx_
)[
varname
].
get
(),
executor_
->
RunPreparedContext
((
*
grad_to_prepared_ctx_
)[
varname
].
get
(),
...
@@ -50,17 +59,7 @@ bool RequestSendHandler::Handle(const std::string& varname,
...
@@ -50,17 +59,7 @@ bool RequestSendHandler::Handle(const std::string& varname,
return
false
;
return
false
;
}
}
return
true
;
return
true
;
}
}
else
{
// sync
// Sync
if
(
varname
==
BATCH_BARRIER_MESSAGE
)
{
VLOG
(
3
)
<<
"sync: recv BATCH_BARRIER_MESSAGE"
;
rpc_server_
->
IncreaseBatchBarrier
(
kRequestSend
);
}
else
if
(
varname
==
COMPLETE_MESSAGE
)
{
VLOG
(
3
)
<<
"sync: recv complete message"
;
rpc_server_
->
Complete
();
}
else
{
VLOG
(
3
)
<<
"sync: received var_name: "
<<
varname
;
rpc_server_
->
WaitCond
(
kRequestSend
);
rpc_server_
->
WaitCond
(
kRequestSend
);
VLOG
(
3
)
<<
"sync: processing received var: "
<<
varname
;
VLOG
(
3
)
<<
"sync: processing received var: "
<<
varname
;
...
@@ -68,11 +67,13 @@ bool RequestSendHandler::Handle(const std::string& varname,
...
@@ -68,11 +67,13 @@ bool RequestSendHandler::Handle(const std::string& varname,
LOG
(
FATAL
)
<<
"sync: Can not find server side var: "
<<
varname
;
LOG
(
FATAL
)
<<
"sync: Can not find server side var: "
<<
varname
;
return
false
;
return
false
;
}
}
if
(
invar
->
IsType
<
framework
::
SelectedRows
>
())
{
if
(
invar
->
IsType
<
framework
::
SelectedRows
>
())
{
std
::
unique_lock
<
std
::
mutex
>
lock
(
mutex_sparse_vars_
);
std
::
unique_lock
<
std
::
mutex
>
lock
(
mutex_sparse_vars_
);
sparse_vars_
.
push_back
(
invar
);
sparse_vars_
.
push_back
(
invar
);
}
}
}
}
}
return
true
;
return
true
;
}
}
...
...
paddle/fluid/operators/fusion_lstm_op.cc
浏览文件 @
2dc23ffa
...
@@ -89,12 +89,12 @@ void FusionLSTMOp::InferShape(framework::InferShapeContext* ctx) const {
...
@@ -89,12 +89,12 @@ void FusionLSTMOp::InferShape(framework::InferShapeContext* ctx) const {
PADDLE_ENFORCE_EQ
(
b_dims
[
0
],
1
,
PADDLE_ENFORCE_EQ
(
b_dims
[
0
],
1
,
"The first dimension of Input(Bias) should be 1."
);
"The first dimension of Input(Bias) should be 1."
);
PADDLE_ENFORCE
(
!
ctx
->
Attrs
().
Get
<
bool
>
(
"use_peepholes"
),
auto
use_peepholes
=
ctx
->
Attrs
().
Get
<
bool
>
(
"use_peepholes"
);
"Do not support peephole yet."
);
PADDLE_ENFORCE_EQ
(
b_dims
[
1
],
(
use_peepholes
?
7
:
4
)
*
frame_size
,
PADDLE_ENFORCE_EQ
(
b_dims
[
1
],
4
*
frame_size
,
"The second dimension of Input(Bias) should be "
"The second dimension of Input(Bias) should be "
"4 * %d if disable peepholes connection"
,
"7 * %d if enable peepholes connection or"
frame_size
);
"4 * %d if disable peepholes"
,
frame_size
,
frame_size
);
framework
::
DDim
out_dims
({
x_dims
[
0
],
frame_size
});
framework
::
DDim
out_dims
({
x_dims
[
0
],
frame_size
});
ctx
->
SetOutputDim
(
"Hidden"
,
out_dims
);
ctx
->
SetOutputDim
(
"Hidden"
,
out_dims
);
...
@@ -242,6 +242,7 @@ class FuisonLSTMKernel : public framework::OpKernel<T> {
...
@@ -242,6 +242,7 @@ class FuisonLSTMKernel : public framework::OpKernel<T> {
auto* xx = ctx.Output<LoDTensor>("XX"); \
auto* xx = ctx.Output<LoDTensor>("XX"); \
auto* hidden_out = ctx.Output<LoDTensor>("Hidden"); \
auto* hidden_out = ctx.Output<LoDTensor>("Hidden"); \
auto* cell_out = ctx.Output<LoDTensor>("Cell"); \
auto* cell_out = ctx.Output<LoDTensor>("Cell"); \
bool use_peepholes = ctx.Attr<bool>("use_peepholes"); \
bool is_reverse = ctx.Attr<bool>("is_reverse");
bool is_reverse = ctx.Attr<bool>("is_reverse");
#define INIT_BASE_SIZES \
#define INIT_BASE_SIZES \
...
@@ -266,12 +267,21 @@ class FuisonLSTMKernel : public framework::OpKernel<T> {
...
@@ -266,12 +267,21 @@ class FuisonLSTMKernel : public framework::OpKernel<T> {
const
T
*
x_data
=
x
->
data
<
T
>
();
const
T
*
x_data
=
x
->
data
<
T
>
();
const
T
*
h0_data
=
h0
?
h0
->
data
<
T
>
()
:
nullptr
;
const
T
*
h0_data
=
h0
?
h0
->
data
<
T
>
()
:
nullptr
;
const
T
*
c0_data
=
c0
?
c0
->
data
<
T
>
()
:
nullptr
;
const
T
*
c0_data
=
c0
?
c0
->
data
<
T
>
()
:
nullptr
;
const
T
*
bias_data
=
bias
->
data
<
T
>
();
const
T
*
wc_data
=
bias_data
+
D4
;
// w_ic, w_fc, w_oc
const
T
*
wx_data
=
wx
->
data
<
T
>
();
const
T
*
wx_data
=
wx
->
data
<
T
>
();
const
T
*
wh_data
=
wh
->
data
<
T
>
();
const
T
*
wh_data
=
wh
->
data
<
T
>
();
T
*
xx_data
=
xx
->
mutable_data
<
T
>
(
ctx
.
GetPlace
());
T
*
xx_data
=
xx
->
mutable_data
<
T
>
(
ctx
.
GetPlace
());
T
*
hidden_out_data
=
hidden_out
->
mutable_data
<
T
>
(
ctx
.
GetPlace
());
T
*
hidden_out_data
=
hidden_out
->
mutable_data
<
T
>
(
ctx
.
GetPlace
());
T
*
cell_out_data
=
cell_out
->
mutable_data
<
T
>
(
ctx
.
GetPlace
());
T
*
cell_out_data
=
cell_out
->
mutable_data
<
T
>
(
ctx
.
GetPlace
());
// use local variable
framework
::
DDim
check_dims
({
3
,
D
});
Tensor
checked_cell
;
// w_ic * Ct-1, w_fc * Ct-1, w_oc * Ct
auto
checked_cell_data
=
checked_cell
.
mutable_data
<
T
>
(
check_dims
,
ctx
.
GetPlace
());
auto
blas
=
math
::
GetBlas
<
DeviceContext
,
T
>
(
ctx
);
auto
blas
=
math
::
GetBlas
<
DeviceContext
,
T
>
(
ctx
);
math
::
FCCompute
<
DeviceContext
,
T
>
(
blas
,
total_T
,
D4
,
M
,
x_data
,
wx_data
,
math
::
FCCompute
<
DeviceContext
,
T
>
(
blas
,
total_T
,
D4
,
M
,
x_data
,
wx_data
,
xx_data
,
bias
->
data
<
T
>
());
xx_data
,
bias
->
data
<
T
>
());
...
@@ -297,46 +307,86 @@ class FuisonLSTMKernel : public framework::OpKernel<T> {
...
@@ -297,46 +307,86 @@ class FuisonLSTMKernel : public framework::OpKernel<T> {
int
seq_len
=
x_lod
[
0
][
bid
+
1
]
-
x_lod
[
0
][
bid
];
int
seq_len
=
x_lod
[
0
][
bid
+
1
]
-
x_lod
[
0
][
bid
];
const
T
*
prev_c_data
=
nullptr
;
const
T
*
prev_c_data
=
nullptr
;
const
T
*
prev_h_data
=
nullptr
;
const
T
*
prev_h_data
=
nullptr
;
int
tstart
=
0
;
int
tstart
=
0
;
if
(
h0_data
)
{
if
(
h0_data
)
{
prev_h_data
=
h0_data
+
bid
*
D
;
prev_h_data
=
h0_data
+
bid
*
D
;
prev_c_data
=
c0_data
+
bid
*
D
;
prev_c_data
=
c0_data
+
bid
*
D
;
}
else
{
}
else
{
// W_ch, W_ih, W_fh, W_oh
// If step == 0 and there is no initialized hidden state, that is to say
act_gate
(
D3
,
xx_data
+
D
,
xx_data
+
D
);
// the H0 is zeros. Then W_h * H_t-1 can be skipped
// ~C_t
act_cand
(
D
,
xx_data
,
xx_data
);
act_cand
(
D
,
xx_data
,
xx_data
);
// cell out= input*tilde
if
(
use_peepholes
)
{
// I_t, F_t
act_gate
(
D2
,
xx_data
+
D
,
xx_data
+
D
);
}
else
{
// I_t, F_t, O_t
act_gate
(
D3
,
xx_data
+
D
,
xx_data
+
D
);
}
// C_t = I_t * ~C_t
blas
.
VMUL
(
D
,
xx_data
,
xx_data
+
D
,
cell_out_data
);
blas
.
VMUL
(
D
,
xx_data
,
xx_data
+
D
,
cell_out_data
);
if
(
use_peepholes
)
{
// + W_oc * C_t for peephole connection
blas
.
VMUL
(
D
,
wc_data
+
D2
,
cell_out_data
,
checked_cell_data
+
D2
);
blas
.
VADD
(
D
,
xx_data
+
D3
,
checked_cell_data
+
D2
,
xx_data
+
D3
);
// O_t
act_gate
(
D
,
xx_data
+
D3
,
xx_data
+
D3
);
}
// hidden out= act_state(cellout) * outgate
// hidden out= act_state(cellout) * outgate
act_cell
(
D
,
cell_out_data
,
xx_data
+
D2
);
act_cell
(
D
,
cell_out_data
,
xx_data
+
D2
);
// H_t = O_t * act_state(C_t)
blas
.
VMUL
(
D
,
xx_data
+
D2
,
xx_data
+
D3
,
hidden_out_data
);
blas
.
VMUL
(
D
,
xx_data
+
D2
,
xx_data
+
D3
,
hidden_out_data
);
// prev
// prev
prev_h_data
=
hidden_out_data
;
prev_h_data
=
hidden_out_data
;
prev_c_data
=
cell_out_data
;
prev_c_data
=
cell_out_data
;
tstart
=
1
;
tstart
=
1
;
move_step
();
move_step
();
}
}
for
(
int
step
=
tstart
;
step
<
seq_len
;
++
step
)
{
for
(
int
step
=
tstart
;
step
<
seq_len
;
++
step
)
{
// + W_h * H_t-1
blas
.
GEMM
(
CblasNoTrans
,
CblasNoTrans
,
1
,
D4
,
D
,
static_cast
<
T
>
(
1
),
blas
.
GEMM
(
CblasNoTrans
,
CblasNoTrans
,
1
,
D4
,
D
,
static_cast
<
T
>
(
1
),
prev_h_data
,
D
,
wh_data
,
D4
,
static_cast
<
T
>
(
1
),
xx_data
,
D4
);
prev_h_data
,
D
,
wh_data
,
D4
,
static_cast
<
T
>
(
1
),
xx_data
,
D4
);
// W_ch, W_ih, W_fh, W_oh
// ~C_t
act_gate
(
D3
,
xx_data
+
D
,
xx_data
+
D
);
act_cand
(
D
,
xx_data
,
xx_data
);
act_cand
(
D
,
xx_data
,
xx_data
);
// a = forget * prev_cell
if
(
use_peepholes
)
{
blas
.
VMUL
(
D
,
xx_data
+
D2
,
prev_c_data
,
xx_data
+
D2
);
// + W_ic|W_fc * C_t-1 for peephole connection
blas
.
VMUL
(
D
,
wc_data
,
prev_c_data
,
checked_cell_data
);
blas
.
VMUL
(
D
,
wc_data
+
D
,
prev_c_data
,
checked_cell_data
+
D
);
blas
.
VADD
(
D2
,
xx_data
+
D
,
checked_cell_data
,
xx_data
+
D
);
// I_t, F_t
act_gate
(
D2
,
xx_data
+
D
,
xx_data
+
D
);
}
else
{
// I_t, F_t, O_t
act_gate
(
D3
,
xx_data
+
D
,
xx_data
+
D
);
}
// b = input * tilde
// F_t * C_t-1
blas
.
VMUL
(
D
,
xx_data
+
D2
,
prev_c_data
,
xx_data
+
D2
);
// I_t * ~C_t
blas
.
VMUL
(
D
,
xx_data
,
xx_data
+
D
,
xx_data
+
D
);
blas
.
VMUL
(
D
,
xx_data
,
xx_data
+
D
,
xx_data
+
D
);
// C_t = F_t * C_t-1 + I_t * ~C_t
// cell out= a+b
blas
.
VADD
(
D
,
xx_data
+
D
,
xx_data
+
D2
,
cell_out_data
);
blas
.
VADD
(
D
,
xx_data
+
D
,
xx_data
+
D2
,
cell_out_data
);
if
(
use_peepholes
)
{
// + W_oc * C_t for peephole connection
blas
.
VMUL
(
D
,
wc_data
+
D2
,
cell_out_data
,
checked_cell_data
+
D2
);
blas
.
VADD
(
D
,
xx_data
+
D3
,
checked_cell_data
+
D2
,
xx_data
+
D3
);
// O_t
act_gate
(
D
,
xx_data
+
D3
,
xx_data
+
D3
);
}
// hidden out= act_state(cellout) * outgate
// hidden out= act_state(cellout) * outgate
act_cell
(
D
,
cell_out_data
,
xx_data
+
D2
);
act_cell
(
D
,
cell_out_data
,
xx_data
+
D2
);
// H_t = O_t * act_state(C_t)
blas
.
VMUL
(
D
,
xx_data
+
D2
,
xx_data
+
D3
,
hidden_out_data
);
blas
.
VMUL
(
D
,
xx_data
+
D2
,
xx_data
+
D3
,
hidden_out_data
);
// prev
// prev
...
@@ -344,14 +394,14 @@ class FuisonLSTMKernel : public framework::OpKernel<T> {
...
@@ -344,14 +394,14 @@ class FuisonLSTMKernel : public framework::OpKernel<T> {
prev_c_data
=
cell_out_data
;
prev_c_data
=
cell_out_data
;
move_step
();
move_step
();
}
}
// for each step in batch
}
}
// for each batch
}
}
void
BatchCompute
(
const
framework
::
ExecutionContext
&
ctx
)
const
{
void
BatchCompute
(
const
framework
::
ExecutionContext
&
ctx
)
const
{
using
DeviceContext
=
platform
::
CPUDeviceContext
;
using
DeviceContext
=
platform
::
CPUDeviceContext
;
INIT_BASE_INPUT_OUTPUT
INIT_BASE_INPUT_OUTPUT
if
(
x
->
lod
()[
0
].
size
()
==
2
)
{
if
(
x
->
lod
()[
0
].
size
()
==
2
)
{
// batch size == 1
SeqCompute
(
ctx
);
SeqCompute
(
ctx
);
return
;
return
;
}
}
...
@@ -367,6 +417,8 @@ class FuisonLSTMKernel : public framework::OpKernel<T> {
...
@@ -367,6 +417,8 @@ class FuisonLSTMKernel : public framework::OpKernel<T> {
const
T
*
x_data
=
x
->
data
<
T
>
();
const
T
*
x_data
=
x
->
data
<
T
>
();
const
T
*
wx_data
=
wx
->
data
<
T
>
();
const
T
*
wx_data
=
wx
->
data
<
T
>
();
const
T
*
wh_data
=
wh
->
data
<
T
>
();
const
T
*
wh_data
=
wh
->
data
<
T
>
();
const
T
*
bias_data
=
bias
->
data
<
T
>
();
const
T
*
wc_data
=
bias_data
+
D4
;
// w_ic, w_fc, w_oc
auto
place
=
ctx
.
GetPlace
();
auto
place
=
ctx
.
GetPlace
();
T
*
xx_data
=
xx
->
mutable_data
<
T
>
(
place
);
T
*
xx_data
=
xx
->
mutable_data
<
T
>
(
place
);
T
*
batched_input_data
=
batched_input
->
mutable_data
<
T
>
(
place
);
T
*
batched_input_data
=
batched_input
->
mutable_data
<
T
>
(
place
);
...
@@ -375,6 +427,12 @@ class FuisonLSTMKernel : public framework::OpKernel<T> {
...
@@ -375,6 +427,12 @@ class FuisonLSTMKernel : public framework::OpKernel<T> {
hidden_out
->
mutable_data
<
T
>
(
place
);
hidden_out
->
mutable_data
<
T
>
(
place
);
cell_out
->
mutable_data
<
T
>
(
place
);
cell_out
->
mutable_data
<
T
>
(
place
);
// use local variable
framework
::
DDim
check_dims
({
3
,
D
});
Tensor
checked_cell
;
// w_ic * Ct-1, w_fc * Ct-1, w_oc * Ct
auto
checked_cell_data
=
checked_cell
.
mutable_data
<
T
>
(
check_dims
,
ctx
.
GetPlace
());
math
::
LoDTensor2BatchFunctor
<
DeviceContext
,
T
>
to_batch
;
math
::
LoDTensor2BatchFunctor
<
DeviceContext
,
T
>
to_batch
;
auto
&
dev_ctx
=
ctx
.
template
device_context
<
DeviceContext
>();
auto
&
dev_ctx
=
ctx
.
template
device_context
<
DeviceContext
>();
auto
blas
=
math
::
GetBlas
<
DeviceContext
,
T
>
(
dev_ctx
);
auto
blas
=
math
::
GetBlas
<
DeviceContext
,
T
>
(
dev_ctx
);
...
@@ -396,17 +454,27 @@ class FuisonLSTMKernel : public framework::OpKernel<T> {
...
@@ -396,17 +454,27 @@ class FuisonLSTMKernel : public framework::OpKernel<T> {
reordered_h0
->
Resize
({
max_bs
,
D
});
reordered_h0
->
Resize
({
max_bs
,
D
});
reordered_c0
->
Resize
({
max_bs
,
D
});
reordered_c0
->
Resize
({
max_bs
,
D
});
T
*
prev_batch_h_data
=
nullptr
;
T
*
prev_batch_c_data
=
nullptr
;
T
*
cur_batch_in_data
=
batched_input_data
;
T
*
cur_batch_h_out_data
=
batched_h_out_data
;
T
*
cur_batch_c_out_data
=
batched_c_out_data
;
auto
move_step
=
[
&
](
int
bs
)
{
cur_batch_in_data
+=
bs
*
D4
;
cur_batch_c_out_data
+=
bs
*
D
;
cur_batch_h_out_data
+=
bs
*
D
;
};
int
tstart
=
0
;
int
tstart
=
0
;
T
*
prev_h_data
=
nullptr
;
T
*
prev_c_data
=
nullptr
;
if
(
h0
)
{
if
(
h0
)
{
// reorder h0, c0
// reorder h0, c0
T
*
reordered_h0_data
=
reordered_h0
->
mutable_data
<
T
>
(
place
);
T
*
reordered_h0_data
=
reordered_h0
->
mutable_data
<
T
>
(
place
);
T
*
reordered_c0_data
=
reordered_c0
->
mutable_data
<
T
>
(
place
);
T
*
reordered_c0_data
=
reordered_c0
->
mutable_data
<
T
>
(
place
);
const
T
*
h0_data
=
h0
->
data
<
T
>
();
const
T
*
h0_data
=
h0
->
data
<
T
>
();
const
T
*
c0_data
=
c0
->
data
<
T
>
();
const
T
*
c0_data
=
c0
->
data
<
T
>
();
prev_h_data
=
reordered_h0_data
;
prev_
batch_
h_data
=
reordered_h0_data
;
prev_c_data
=
reordered_c0_data
;
prev_
batch_
c_data
=
reordered_c0_data
;
size_t
sz
=
sizeof
(
T
)
*
D
;
size_t
sz
=
sizeof
(
T
)
*
D
;
for
(
int
i
=
0
;
i
<
max_bs
;
++
i
)
{
for
(
int
i
=
0
;
i
<
max_bs
;
++
i
)
{
std
::
memcpy
(
reordered_h0_data
,
h0_data
+
seq_order
[
i
]
*
D
,
sz
);
std
::
memcpy
(
reordered_h0_data
,
h0_data
+
seq_order
[
i
]
*
D
,
sz
);
...
@@ -415,71 +483,122 @@ class FuisonLSTMKernel : public framework::OpKernel<T> {
...
@@ -415,71 +483,122 @@ class FuisonLSTMKernel : public framework::OpKernel<T> {
reordered_c0_data
+=
D
;
reordered_c0_data
+=
D
;
}
}
}
else
{
}
else
{
// compute without h0, c0
// Compute with no H0/C0
T
*
cur_in_data
=
batched_input_data
;
T
*
cur_in_data
=
cur_batch_in_data
;
T
*
cur_h_out_data
=
batched_h_out_data
;
T
*
cur_c_out_data
=
cur_batch_c_out_data
;
T
*
cur_c_out_data
=
batched_c_out_data
;
T
*
cur_h_out_data
=
cur_batch_h_out_data
;
// W_ch, W_ih, W_fh, W_oh
for
(
int
i
=
0
;
i
<
max_bs
;
++
i
)
{
// If step == 0 and there is no initialized hidden state, that is to say
act_gate
(
D3
,
cur_in_data
+
D
,
cur_in_data
+
D
);
// the H0 is zeros. Then W_h * H_t-1 can be skiped
for
(
int
i
=
0
;
i
<
max_bs
;
++
i
)
{
// iterate each data in 1st batch
// ~C_t
act_cand
(
D
,
cur_in_data
,
cur_in_data
);
act_cand
(
D
,
cur_in_data
,
cur_in_data
);
// cell out= input*tilde
if
(
use_peepholes
)
{
// I_t, F_t
act_gate
(
D2
,
cur_in_data
+
D
,
cur_in_data
+
D
);
}
else
{
// I_t, F_t, O_t
act_gate
(
D3
,
cur_in_data
+
D
,
cur_in_data
+
D
);
}
// C_t = I_t * ~C_t
blas
.
VMUL
(
D
,
cur_in_data
,
cur_in_data
+
D
,
cur_c_out_data
);
blas
.
VMUL
(
D
,
cur_in_data
,
cur_in_data
+
D
,
cur_c_out_data
);
if
(
use_peepholes
)
{
// + W_oc * C_t for peephole connection
blas
.
VMUL
(
D
,
wc_data
+
D2
,
cur_c_out_data
,
checked_cell_data
+
D2
);
blas
.
VADD
(
D
,
cur_in_data
+
D3
,
checked_cell_data
+
D2
,
cur_in_data
+
D3
);
// O_t
act_gate
(
D
,
cur_in_data
+
D3
,
cur_in_data
+
D3
);
}
// hidden out= act_state(cellout) * outgate
// hidden out= act_state(cellout) * outgate
act_cell
(
D
,
cur_c_out_data
,
cur_in_data
+
D2
);
act_cell
(
D
,
cur_c_out_data
,
cur_in_data
+
D2
);
// H_t = O_t * act_state(C_t)
blas
.
VMUL
(
D
,
cur_in_data
+
D2
,
cur_in_data
+
D3
,
cur_h_out_data
);
blas
.
VMUL
(
D
,
cur_in_data
+
D2
,
cur_in_data
+
D3
,
cur_h_out_data
);
//
add offset
//
move to next data in the same batch
cur_in_data
+=
D4
;
cur_in_data
+=
D4
;
cur_c_out_data
+=
D
;
cur_c_out_data
+=
D
;
cur_h_out_data
+=
D
;
cur_h_out_data
+=
D
;
}
}
// move to data for next timestep
prev_batch_h_data
=
cur_batch_h_out_data
;
prev_batch_c_data
=
cur_batch_c_out_data
;
move_step
(
max_bs
);
tstart
=
1
;
tstart
=
1
;
prev_h_data
=
batched_h_out_data
;
prev_c_data
=
batched_c_out_data
;
}
}
// Then start from next
const
auto
&
batch_starts
=
batched_lod
[
0
];
const
auto
&
batch_starts
=
batched_lod
[
0
];
const
int
max_seq_len
=
batch_starts
.
size
()
-
1
;
const
int
max_seq_len
=
batch_starts
.
size
()
-
1
;
const
int
offset
=
tstart
*
max_bs
*
D
;
batched_input_data
=
batched_input_data
+
offset
*
4
;
batched_h_out_data
=
batched_h_out_data
+
offset
;
batched_c_out_data
=
batched_c_out_data
+
offset
;
for
(
int
step
=
tstart
;
step
<
max_seq_len
;
++
step
)
{
for
(
int
step
=
tstart
;
step
<
max_seq_len
;
++
step
)
{
const
int
cur_bs
=
batch_starts
[
step
+
1
]
-
batch_starts
[
step
];
const
int
cur_bs
=
batch_starts
[
step
+
1
]
-
batch_starts
[
step
];
// + W_h * H_t-1
blas
.
GEMM
(
CblasNoTrans
,
CblasNoTrans
,
cur_bs
,
D4
,
D
,
static_cast
<
T
>
(
1
),
blas
.
GEMM
(
CblasNoTrans
,
CblasNoTrans
,
cur_bs
,
D4
,
D
,
static_cast
<
T
>
(
1
),
prev_h_data
,
D
,
wh_data
,
D4
,
static_cast
<
T
>
(
1
),
prev_batch_h_data
,
D
,
wh_data
,
D4
,
static_cast
<
T
>
(
1
),
batched_input_data
,
D4
);
cur_batch_in_data
,
D4
);
T
*
cur_in_data
=
batched_input_data
;
T
*
cur_in_data
=
cur_batch_in_data
;
T
*
cur_prev_c_data
=
prev_c_data
;
T
*
cur_c_out_data
=
cur_batch_c_out_data
;
T
*
cur_c_out_data
=
batched_c_out_data
;
T
*
cur_h_out_data
=
cur_batch_h_out_data
;
T
*
cur_h_out_data
=
batched_h_out_data
;
T
*
prev_c_data
=
prev_batch_c_data
;
// NULL if no C0 in step0
for
(
int
i
=
0
;
i
<
cur_bs
;
++
i
)
{
T
*
prev_h_data
=
prev_batch_h_data
;
// NULL if no H0 in step0
// W_ch, W_ih, W_fh, W_oh
auto
next_data_in_batch
=
[
&
]()
{
act_gate
(
D3
,
cur_in_data
+
D
,
cur_in_data
+
D
);
cur_in_data
+=
D4
;
cur_c_out_data
+=
D
;
cur_h_out_data
+=
D
;
prev_c_data
=
prev_c_data
?
prev_c_data
+
D
:
nullptr
;
prev_h_data
=
prev_h_data
?
prev_h_data
+
D
:
nullptr
;
};
for
(
int
i
=
0
;
i
<
cur_bs
;
++
i
)
{
// iterate each data in same batch
// ~C_t
act_cand
(
D
,
cur_in_data
,
cur_in_data
);
act_cand
(
D
,
cur_in_data
,
cur_in_data
);
// a = forget * prev_cell
blas
.
VMUL
(
D
,
cur_in_data
+
D2
,
cur_prev_c_data
,
cur_in_data
+
D2
);
if
(
use_peepholes
)
{
// b = input * tilde
// + W_ic|W_fc * C_t-1 for peephole connection
blas
.
VMUL
(
D
,
wc_data
,
prev_c_data
,
checked_cell_data
);
blas
.
VMUL
(
D
,
wc_data
+
D
,
prev_c_data
,
checked_cell_data
+
D
);
blas
.
VADD
(
D2
,
cur_in_data
+
D
,
checked_cell_data
,
cur_in_data
+
D
);
// I_t, F_t
act_gate
(
D2
,
cur_in_data
+
D
,
cur_in_data
+
D
);
}
else
{
// I_t, F_t, O_t
act_gate
(
D3
,
cur_in_data
+
D
,
cur_in_data
+
D
);
}
// F_t * C_t-1
blas
.
VMUL
(
D
,
cur_in_data
+
D2
,
prev_c_data
,
cur_in_data
+
D2
);
// I_t * ~C_t
blas
.
VMUL
(
D
,
cur_in_data
,
cur_in_data
+
D
,
cur_in_data
+
D
);
blas
.
VMUL
(
D
,
cur_in_data
,
cur_in_data
+
D
,
cur_in_data
+
D
);
//
cell out= a+b
//
C_t = F_t * C_t-1 + I_t * ~C_t
blas
.
VADD
(
D
,
cur_in_data
+
D
,
cur_in_data
+
D2
,
cur_c_out_data
);
blas
.
VADD
(
D
,
cur_in_data
+
D
,
cur_in_data
+
D2
,
cur_c_out_data
);
if
(
use_peepholes
)
{
// + W_oc * C_t for peephole connection
blas
.
VMUL
(
D
,
wc_data
+
D2
,
cur_c_out_data
,
checked_cell_data
+
D2
);
blas
.
VADD
(
D
,
cur_in_data
+
D3
,
checked_cell_data
+
D2
,
cur_in_data
+
D3
);
// O_t
act_gate
(
D
,
cur_in_data
+
D3
,
cur_in_data
+
D3
);
}
// hidden out= act_state(cellout) * outgate
// hidden out= act_state(cellout) * outgate
act_cell
(
D
,
cur_c_out_data
,
cur_in_data
+
D2
);
act_cell
(
D
,
cur_c_out_data
,
cur_in_data
+
D2
);
// H_t = O_t * act_state(C_t)
blas
.
VMUL
(
D
,
cur_in_data
+
D2
,
cur_in_data
+
D3
,
cur_h_out_data
);
blas
.
VMUL
(
D
,
cur_in_data
+
D2
,
cur_in_data
+
D3
,
cur_h_out_data
);
cur_in_data
+=
D4
;
// move to next data in same batch
cur_prev_c_data
+=
D
;
next_data_in_batch
();
cur_c_out_data
+=
D
;
cur_h_out_data
+=
D
;
}
}
// move to data for next timestep
prev_c_data
=
batched_c_out_data
;
prev_batch_h_data
=
cur_batch_h_out_data
;
prev_h_data
=
batched_h_out_data
;
prev_batch_c_data
=
cur_batch_c_out_data
;
batched_c_out_data
=
cur_c_out_data
;
move_step
(
cur_bs
);
batched_h_out_data
=
cur_h_out_data
;
batched_input_data
=
cur_in_data
;
}
}
math
::
Batch2LoDTensorFunctor
<
DeviceContext
,
T
>
to_seq
;
math
::
Batch2LoDTensorFunctor
<
DeviceContext
,
T
>
to_seq
;
...
...
python/paddle/fluid/layers/nn.py
浏览文件 @
2dc23ffa
...
@@ -3546,11 +3546,6 @@ def topk(input, k, name=None):
...
@@ -3546,11 +3546,6 @@ def topk(input, k, name=None):
top5_values, top5_indices = layers.topk(input, k=5)
top5_values, top5_indices = layers.topk(input, k=5)
"""
"""
shape
=
input
.
shape
if
k
<
1
or
k
>=
shape
[
-
1
]:
raise
ValueError
(
"k must be greater than 0 and less than %d."
%
(
shape
[
-
1
]))
helper
=
LayerHelper
(
"top_k"
,
**
locals
())
helper
=
LayerHelper
(
"top_k"
,
**
locals
())
values
=
helper
.
create_tmp_variable
(
dtype
=
input
.
dtype
)
values
=
helper
.
create_tmp_variable
(
dtype
=
input
.
dtype
)
indices
=
helper
.
create_tmp_variable
(
dtype
=
"int64"
)
indices
=
helper
.
create_tmp_variable
(
dtype
=
"int64"
)
...
...
python/paddle/fluid/tests/unittests/test_fusion_lstm_op.py
浏览文件 @
2dc23ffa
...
@@ -58,6 +58,7 @@ class TestFusionLSTMOp(OpTest):
...
@@ -58,6 +58,7 @@ class TestFusionLSTMOp(OpTest):
self
.
act_cell
=
'tanh'
self
.
act_cell
=
'tanh'
self
.
act_cand
=
'tanh'
self
.
act_cand
=
'tanh'
self
.
use_peepholes
=
False
self
.
use_peepholes
=
False
self
.
use_seq
=
False
self
.
set_conf
()
self
.
set_conf
()
T
=
sum
(
self
.
lod
[
0
])
T
=
sum
(
self
.
lod
[
0
])
...
@@ -107,6 +108,7 @@ class TestFusionLSTMOp(OpTest):
...
@@ -107,6 +108,7 @@ class TestFusionLSTMOp(OpTest):
}
}
self
.
attrs
=
{
self
.
attrs
=
{
'use_peepholes'
:
self
.
use_peepholes
,
'use_peepholes'
:
self
.
use_peepholes
,
'use_seq'
:
self
.
use_seq
,
'is_reverse'
:
self
.
is_reverse
,
'is_reverse'
:
self
.
is_reverse
,
'gate_activation'
:
self
.
act_gate
,
'gate_activation'
:
self
.
act_gate
,
'cell_activation'
:
self
.
act_cell
,
'cell_activation'
:
self
.
act_cell
,
...
@@ -159,5 +161,68 @@ class TestFusionLSTMOpBS1(TestFusionLSTMOp):
...
@@ -159,5 +161,68 @@ class TestFusionLSTMOpBS1(TestFusionLSTMOp):
self
.
D
=
16
self
.
D
=
16
class
TestFusionLSTMOpPeepholes
(
TestFusionLSTMOp
):
def
set_conf
(
self
):
self
.
use_peepholes
=
True
class
TestFusionLSTMOpPeepholesInit
(
TestFusionLSTMOp
):
def
set_conf
(
self
):
self
.
use_peepholes
=
True
self
.
has_initial_state
=
True
class
TestFusionLSTMOpPeepholesReverse
(
TestFusionLSTMOp
):
def
set_conf
(
self
):
self
.
use_peepholes
=
True
self
.
is_reverse
=
True
class
TestFusionLSTMOpPoopholesBS1
(
TestFusionLSTMOp
):
def
set_conf
(
self
):
self
.
use_peepholes
=
True
self
.
lod
=
[[
3
]]
self
.
D
=
16
class
TestFusionLSTMOpSeqInit
(
TestFusionLSTMOp
):
def
set_conf
(
self
):
self
.
use_seq
=
True
self
.
has_initial_state
=
True
class
TestFusionLSTMOpSeqReverse
(
TestFusionLSTMOp
):
def
set_conf
(
self
):
self
.
use_seq
=
True
self
.
is_reverse
=
True
class
TestFusionLSTMOpSeqInitReverse
(
TestFusionLSTMOp
):
def
set_conf
(
self
):
self
.
use_seq
=
True
self
.
has_initial_state
=
True
self
.
is_reverse
=
True
class
TestFusionLSTMOpSeqPeepholes
(
TestFusionLSTMOp
):
def
set_conf
(
self
):
self
.
use_seq
=
True
self
.
use_peepholes
=
True
class
TestFusionLSTMOpSeqPeepholesInit
(
TestFusionLSTMOp
):
def
set_conf
(
self
):
self
.
use_seq
=
True
self
.
use_peepholes
=
True
self
.
has_initial_state
=
True
class
TestFusionLSTMOpSeqPeepholesReverse
(
TestFusionLSTMOp
):
def
set_conf
(
self
):
self
.
use_seq
=
True
self
.
use_peepholes
=
True
self
.
is_reverse
=
True
if
__name__
==
'__main__'
:
if
__name__
==
'__main__'
:
unittest
.
main
()
unittest
.
main
()
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录