Commit ede4b230 (unverified)
Repository: BaiXuePrincess/Paddle (fork of PaddlePaddle/Paddle)
Authored Sep 29, 2018 by tensor-tang; committed via GitHub, Sep 29, 2018.

Merge pull request #13553 from jczaja/prv-fused_embedding_fc_lstm_op

Adding fused_embedding_fc_lstm op

Parents: 618b3297, e202f33a
Showing 10 changed files with 987 additions and 9 deletions (+987, −9).
paddle/fluid/framework/ir/CMakeLists.txt (+1, −0)
paddle/fluid/framework/ir/embedding_fc_lstm_fuse_pass.cc (+243, −0)
paddle/fluid/framework/ir/embedding_fc_lstm_fuse_pass.h (+40, −0)
paddle/fluid/framework/ir/graph_pattern_detector.cc (+18, −0)
paddle/fluid/framework/ir/graph_pattern_detector.h (+17, −0)
paddle/fluid/inference/analysis/analyzer.h (+9, −8)
paddle/fluid/inference/api/paddle_inference_api.h (+1, −1)
paddle/fluid/inference/tests/api/analyzer_text_classification_tester.cc (+13, −0)
paddle/fluid/operators/fused_embedding_fc_lstm_op.cc (+604, −0)
paddle/fluid/operators/fused_embedding_fc_lstm_op.h (+41, −0)
paddle/fluid/framework/ir/CMakeLists.txt

```diff
@@ -34,6 +34,7 @@ endif ()
 pass_library(attention_lstm_fuse_pass inference)
 pass_library(infer_clean_graph_pass inference)
 pass_library(fc_lstm_fuse_pass inference)
+pass_library(embedding_fc_lstm_fuse_pass inference)
 pass_library(fc_gru_fuse_pass inference)
 pass_library(seq_concat_fc_fuse_pass inference)
```
paddle/fluid/framework/ir/embedding_fc_lstm_fuse_pass.cc (new file, mode 100644)

```cpp
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/fluid/framework/ir/embedding_fc_lstm_fuse_pass.h"
#include <algorithm>
#include <string>
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/operators/math/blas.h"
#include "paddle/fluid/operators/math/cpu_vec.h"
#include "paddle/fluid/operators/math/fc_compute.h"
#include "paddle/fluid/platform/cpu_info.h"

namespace paddle {
namespace framework {
namespace ir {

static int BuildFusion(Graph* graph, const std::string& name_scope,
                       Scope* scope, bool with_fc_bias) {
  GraphPatternDetector gpd;
  auto* pattern = gpd.mutable_pattern();

  // Build pattern
  PDNode* x = pattern->NewNode(patterns::PDNodeName(name_scope, "x"))
                  ->assert_is_op_input("lookup_table")
                  ->assert_var_not_persistable();
  patterns::Embedding embedding_pattern(pattern, name_scope);
  // TODO(jczaja): Intermediate can only be for val that are not used anywhere
  //               but lookup table output may go into other LSTM (for reverse
  //               direction)
  auto* embedding_out = embedding_pattern(x);
  patterns::FC fc_pattern(pattern, name_scope);
  // fc_out is a tmp var, will be removed after fuse, so marked as intermediate.
  auto* fc_out = fc_pattern(embedding_out, with_fc_bias)->AsIntermediate();
  patterns::LSTM lstm_pattern(pattern, name_scope);
  lstm_pattern(fc_out);

  // Create New OpDesc
  auto embedding_lstm_creator = [&](Node* embedding, Node* W, Node* lstm,
                                    Node* input, Node* weight_x, Node* weight_h,
                                    Node* bias, Node* hidden, Node* cell,
                                    Node* xx, Node* fc_bias) {
    OpDesc op_desc;
    op_desc.SetType("fused_embedding_fc_lstm");
#define SET_IN(Key, node__) op_desc.SetInput(#Key, {node__->Name()});
    SET_IN(Ids, input);
    SET_IN(WeightH, weight_h);
    // Need to have this passed, as we need the Wc data for peephole connections
    SET_IN(Bias, bias);
#undef SET_IN

    // Multiply embeddings with Weights
    PADDLE_ENFORCE(scope);
    const std::string& embeddings = patterns::UniqueKey("Embeddings");
    auto* embeddings_var = scope->Var(embeddings);
    PADDLE_ENFORCE(embeddings_var);
    auto* embeddings_tensor =
        embeddings_var->GetMutable<framework::LoDTensor>();
    // Get WeightX size: [single_embedding, fc_size]
    // and embedding size: [dict_size, single_embedding]
    // and create new size of embeddings eg. [dict_size, hidden_size]
    auto* embedding_var = scope->FindVar(W->Name());
    PADDLE_ENFORCE(embedding_var);
    const auto& embedding_tensor = embedding_var->Get<framework::LoDTensor>();

    const auto& weightx_tensor =
        scope->FindVar(weight_x->Name())->Get<framework::LoDTensor>();
    embeddings_tensor->Resize(
        {embedding_tensor.dims()[0], weightx_tensor.dims()[1]});

    // Multiply embeddings via WeightsX and add bias
    auto embedding_data = embedding_tensor.data<float>();
    auto weightx_data = weightx_tensor.data<float>();
    auto embeddings_data =
        embeddings_tensor->mutable_data<float>(platform::CPUPlace());

    // Adding biases to the GEMM result to be
    auto* lstm_bias_var = scope->FindVar(bias->Name());
    PADDLE_ENFORCE(lstm_bias_var);
    const auto& lstm_bias_tensor = lstm_bias_var->Get<framework::LoDTensor>();

    auto alpha = 1.0f;
    auto beta = 1.0f;
    int m = embedding_tensor.dims()[0];
    int n = weightx_tensor.dims()[1];
    int k = embedding_tensor.dims()[1];

    // Copy only gate biases values (only actual bias data, not peephole
    // weights)
    std::vector<float> combined_biases;
    combined_biases.reserve(n);
    std::copy_n(lstm_bias_tensor.data<float>(), n,
                std::back_inserter(combined_biases));

    if (with_fc_bias) {
      // Add FC-bias with LSTM-bias (into GEMM result to be)
      auto* fc_bias_var = scope->FindVar(fc_bias->Name());
      const auto& fc_bias_tensor = fc_bias_var->Get<framework::LoDTensor>();
      for (int i = 0; i < fc_bias_tensor.numel(); i++) {
        combined_biases[i] += fc_bias_tensor.data<float>()[i];
      }
    }

    // broadcast biases
    std::vector<float> ones(m, 1.0f);
    paddle::operators::math::CBlas<float>::GEMM(
        CblasRowMajor, CblasNoTrans, CblasNoTrans, m, n, 1, alpha, &ones[0], 1,
        &combined_biases[0], n, 0.0f, embeddings_data, n);

    // Wx*embeddings + biases
    paddle::operators::math::CBlas<float>::GEMM(
        CblasRowMajor, CblasNoTrans, CblasNoTrans, m, n, k, alpha,
        embedding_data, k, weightx_data, n, beta, embeddings_data, n);
    op_desc.SetInput("Embeddings", {embeddings});

    // Create temp variables.
    const std::string BatchedInput = patterns::UniqueKey("BatchedInput");
    const std::string BatchedCellPreAct =
        patterns::UniqueKey("BatchedCellPreAct");
    const std::string BatchedGate = patterns::UniqueKey("BatchedGate");

    scope->Var(BatchedInput)->GetMutable<framework::LoDTensor>();
    scope->Var(BatchedCellPreAct)->GetMutable<framework::LoDTensor>();
    scope->Var(BatchedGate)->GetMutable<framework::LoDTensor>();

    op_desc.SetInput("H0", {});
    op_desc.SetInput("C0", {});
    op_desc.SetOutput("Hidden", {hidden->Name()});
    op_desc.SetOutput("Cell", {cell->Name()});
    op_desc.SetOutput("XX", {xx->Name()});
    op_desc.SetOutput("BatchedGate", {BatchedGate});
    op_desc.SetOutput("BatchCellPreAct", {BatchedCellPreAct});
    op_desc.SetOutput("BatchedInput", {BatchedInput});
    op_desc.SetAttr("is_reverse", lstm->Op()->GetAttr("is_reverse"));
    op_desc.SetAttr("use_peepholes", lstm->Op()->GetAttr("use_peepholes"));
    // TODO(TJ): get from attr
    op_desc.SetAttr("use_seq", true);

    PADDLE_ENFORCE(graph->Has(kParamScopeAttr));
    auto* scope = graph->Get<Scope*>(kParamScopeAttr);
#define OP_SET_OUT(x)                            \
  const std::string x = patterns::UniqueKey(#x); \
  op_desc.SetOutput(#x, {x});                    \
  scope->Var(x)->GetMutable<LoDTensor>()
    OP_SET_OUT(BatchedCell);
    OP_SET_OUT(BatchedHidden);
    OP_SET_OUT(ReorderedH0);
    OP_SET_OUT(ReorderedC0);
#undef OP_SET_OUT

    auto* op = graph->CreateOpNode(&op_desc);
    IR_NODE_LINK_TO(input, op);
    IR_NODE_LINK_TO(weight_x, op);
    IR_NODE_LINK_TO(weight_h, op);
    IR_NODE_LINK_TO(bias, op);
    IR_NODE_LINK_TO(op, hidden);
    return op;
  };

  int fusion_count{0};

  auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
                     Graph* g) {
    GET_IR_NODE_FROM_SUBGRAPH(lstm, lstm, lstm_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(Weight, Weight, lstm_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(Bias, Bias, lstm_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(Cell, Cell, lstm_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(Hidden, Hidden, lstm_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(lookup_table, lookup_table, embedding_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(W, W, embedding_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(w, w, fc_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(mul, mul, fc_pattern);

    // TODO(jczaja): Add support for is_sparse / is_distributed
    auto is_sparse = boost::get<bool>(lookup_table->Op()->GetAttr("is_sparse"));
    auto is_distributed =
        boost::get<bool>(lookup_table->Op()->GetAttr("is_distributed"));

    if (is_sparse == true || is_distributed == true) {
      return;
    }

    if (with_fc_bias) {
      GET_IR_NODE_FROM_SUBGRAPH(fc_out, Out, fc_pattern);
      GET_IR_NODE_FROM_SUBGRAPH(fc_bias, bias, fc_pattern);
      GET_IR_NODE_FROM_SUBGRAPH(elementwise_add, elementwise_add, fc_pattern);
      embedding_lstm_creator(lookup_table, W, lstm, subgraph.at(x), w, Weight,
                             Bias, Hidden, Cell, fc_out, fc_bias);
      // Remove unneeded nodes.
      // TODO(jczaja): Proper removing of lookup table
      std::unordered_set<const Node*> marked_nodes(
          //{lookup_table, mul, lstm, elementwise_add, fc_bias, W});
          {mul, lstm, elementwise_add, fc_bias});
      GraphSafeRemoveNodes(graph, marked_nodes);
    } else {
      GET_IR_NODE_FROM_SUBGRAPH(fc_out, mul_out, fc_pattern);
      embedding_lstm_creator(lookup_table, W, lstm, subgraph.at(x), w, Weight,
                             Bias, Hidden, Cell, fc_out, nullptr);
      // Remove unneeded nodes.
      // TODO(jczaja): Proper removing of lookup table
      // std::unordered_set<const Node*> marked_nodes({lookup_table, W, mul,
      // lstm});
      std::unordered_set<const Node*> marked_nodes({mul, lstm});
      GraphSafeRemoveNodes(graph, marked_nodes);
    }

    ++fusion_count;
  };

  gpd(graph, handler);

  return fusion_count;
}

std::unique_ptr<ir::Graph> EmbeddingFCLSTMFusePass::ApplyImpl(
    std::unique_ptr<ir::Graph> graph) const {
  FusePassBase::Init(name_scope_, graph.get());

  int fusion_count = BuildFusion(graph.get(), name_scope_, param_scope(),
                                 true /*with_fc_bias*/);

  AddStatis(fusion_count);
  return graph;
}

}  // namespace ir
}  // namespace framework
}  // namespace paddle

REGISTER_PASS(embedding_fc_lstm_fuse_pass,
              paddle::framework::ir::EmbeddingFCLSTMFusePass);
```
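The heart of this pass is the constant-folding step above: every LSTM input row is `lookup_table(Ids) * WeightX + bias`, so the pass precomputes `Embeddings_new = Embeddings_old * WeightX + bias` once at fuse time, turning the runtime FC into a plain row lookup. Below is a minimal standalone sketch of that precompute using plain loops in place of `CBlas<float>::GEMM`; the function name `FoldFCIntoEmbeddings` and the explicit loop form are illustrative, not from the PR:

```cpp
#include <cstddef>
#include <vector>

// new_embeddings = old_embeddings * WeightX + bias, with the (1 x n) bias row
// broadcast to all m rows, exactly what the two GEMM calls in the pass do.
// Dimension names follow the pass: m = dict_size, k = single_embedding,
// n = 4 * hidden_size. All matrices are row-major.
void FoldFCIntoEmbeddings(const std::vector<float>& E,     // m x k
                          const std::vector<float>& Wx,    // k x n
                          const std::vector<float>& bias,  // 1 x n
                          std::vector<float>* out,         // m x n
                          int m, int k, int n) {
  out->assign(static_cast<std::size_t>(m) * n, 0.0f);
  for (int i = 0; i < m; ++i) {
    // Bias broadcast: the pass does this with GEMM(ones(m x 1), bias(1 x n)).
    for (int j = 0; j < n; ++j) (*out)[i * n + j] = bias[j];
    // Accumulate E * Wx on top (the second GEMM, with beta = 1).
    for (int p = 0; p < k; ++p) {
      const float e = E[i * k + p];
      for (int j = 0; j < n; ++j) (*out)[i * n + j] += e * Wx[p * n + j];
    }
  }
}
```

The ones-vector GEMM in the pass achieves the same row broadcast: `ones(m×1) · bias(1×n)` with `beta = 0` writes the combined FC+LSTM gate biases into every row before the `E · Wx` product accumulates on top.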
paddle/fluid/framework/ir/embedding_fc_lstm_fuse_pass.h (new file, mode 100644)

```cpp
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

#include "paddle/fluid/framework/ir/fuse_pass_base.h"
#include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/ir/graph_pattern_detector.h"

namespace paddle {
namespace framework {
namespace ir {

// Fusing of Embedding, FC and LSTM op
// Just FC without bias
class EmbeddingFCLSTMFusePass : public FusePassBase {
 public:
  virtual ~EmbeddingFCLSTMFusePass() {}

 protected:
  std::unique_ptr<ir::Graph> ApplyImpl(std::unique_ptr<ir::Graph> graph) const;

  const std::string name_scope_{"embedding_fc_lstm_fuse"};
};

}  // namespace ir
}  // namespace framework
}  // namespace paddle
```
paddle/fluid/framework/ir/graph_pattern_detector.cc

```cpp
@@ -692,6 +692,24 @@ PDNode *patterns::FC::operator()(paddle::framework::ir::PDNode *x,
  }
}

PDNode *patterns::Embedding::operator()(PDNode *x) {
  x->assert_is_op_input("lookup_table", "Ids");
  auto *lookup_table_op =
      pattern->NewNode(lookup_table_repr())->assert_is_op("lookup_table");
#define NEW_NODE(arg__, io__)                    \
  auto *arg__ = pattern->NewNode(arg__##_repr()) \
                    ->assert_is_op_##io__("lookup_table", #arg__);

  NEW_NODE(W, input);
  NEW_NODE(Out, output);
#undef NEW_NODE

  lookup_table_op->LinksFrom({x, W});
  lookup_table_op->LinksTo({Out});
  return Out;
}

PDNode *patterns::LSTM::operator()(PDNode *x) {
  x->assert_is_op_input("lstm", "Input");
  auto *lstm_op = pattern->NewNode(lstm_repr())->assert_is_op("lstm");
```
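For orientation, here is a condensed sketch of how the new `Embedding` pattern composes with the existing `FC` and `LSTM` patterns. It only uses calls that appear in the fuse pass above, but the `"demo"` scope name, the `Graph* g` variable, and the empty handler body are placeholders:

```cpp
// Sketch: chain lookup_table -> fc (mul + elementwise_add) -> lstm, then run
// the detector over an existing graph `g` (assumed to be available).
GraphPatternDetector gpd;
auto* pattern = gpd.mutable_pattern();
PDNode* x = pattern->NewNode(patterns::PDNodeName("demo", "x"))
                ->assert_is_op_input("lookup_table");
patterns::Embedding embedding(pattern, "demo");
patterns::FC fc(pattern, "demo");
patterns::LSTM lstm(pattern, "demo");
// The FC output is a temporary that the fusion consumes, hence AsIntermediate.
lstm(fc(embedding(x), /*with_bias=*/true)->AsIntermediate());
gpd(g, [](const GraphPatternDetector::subgraph_t& subgraph, Graph* graph) {
  // inspect or rewrite each matched subgraph here
});
```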
paddle/fluid/framework/ir/graph_pattern_detector.h

```cpp
@@ -418,6 +418,23 @@ struct FC : public PatternBase {
  PATTERN_DECL_NODE(Out);
};

// Embedding
struct Embedding : public PatternBase {
  Embedding(PDPattern* pattern, const std::string& name_scope)
      : PatternBase(pattern, name_scope, "embedding") {}

  PDNode* operator()(PDNode* x);

  // declare operator node's name
  PATTERN_DECL_NODE(lookup_table);
  // Inputs
  PATTERN_DECL_NODE(Ids);
  PATTERN_DECL_NODE(W);  // embeddings
  // Outputs
  PATTERN_DECL_NODE(Out);
};

struct LSTM : public PatternBase {
  LSTM(PDPattern* pattern, const std::string& name_scope)
      : PatternBase(pattern, name_scope, "lstm") {}
```
paddle/fluid/inference/analysis/analyzer.h

```diff
@@ -66,6 +66,7 @@ class Analyzer : public OrderedRegistry<PassManager> {
       // Manual update the passes here.
       "infer_clean_graph_pass",       //
       "attention_lstm_fuse_pass",     //
+      "embedding_fc_lstm_fuse_pass",  //
       "fc_lstm_fuse_pass",            //
       "mul_lstm_fuse_pass",           //
       "fc_gru_fuse_pass",             //
```
paddle/fluid/inference/api/paddle_inference_api.h

```diff
@@ -263,7 +263,7 @@ struct AnalysisConfig : public NativeConfig {
   bool enable_ir_optim = true;
   // Manually determine the IR passes to run.
   IrPassMode ir_mode{IrPassMode::kExclude};
-  std::vector<std::string> ir_passes;
+  std::vector<std::string> ir_passes{"embedding_fc_lstm_fuse_pass"};
 
   // NOT stable yet.
   bool use_feed_fetch_ops{true};
```
paddle/fluid/inference/tests/api/analyzer_text_classification_tester.cc

```diff
@@ -104,5 +104,18 @@ TEST(Analyzer_Text_Classification, compare) {
   CompareNativeAndAnalysis(cfg, input_slots_all);
 }
 
+TEST(Analyzer_Text_Classification, compare_against_embedding_fc_lstm_fused) {
+  AnalysisConfig cfg;
+  SetConfig(&cfg);
+  // Enable embedding_fc_lstm_fuse_pass (disabled by default)
+  auto it = std::find(cfg.ir_passes.begin(), cfg.ir_passes.end(),
+                      "embedding_fc_lstm_fuse_pass");
+  if (it != cfg.ir_passes.end()) cfg.ir_passes.erase(it);
+
+  std::vector<std::vector<PaddleTensor>> input_slots_all;
+  SetInput(&input_slots_all);
+  CompareNativeAndAnalysis(cfg, input_slots_all);
+}
+
 }  // namespace inference
 }  // namespace paddle
```
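A note on the erase above: `AnalysisConfig` runs with `IrPassMode::kExclude`, so `cfg.ir_passes` acts as a skip-list, and removing `embedding_fc_lstm_fuse_pass` from it is what enables the pass. A tiny self-contained illustration of that skip-list reading (a hypothetical `main`, not part of the test):

```cpp
#include <algorithm>
#include <string>
#include <vector>

int main() {
  // Under IrPassMode::kExclude, ir_passes names passes to SKIP: a pass runs
  // only if its name is absent. The default list excludes the new fuse pass.
  std::vector<std::string> ir_passes{"embedding_fc_lstm_fuse_pass"};
  auto it = std::find(ir_passes.begin(), ir_passes.end(),
                      "embedding_fc_lstm_fuse_pass");
  if (it != ir_passes.end()) ir_passes.erase(it);  // erase => pass enabled
  return ir_passes.empty() ? 0 : 1;
}
```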
paddle/fluid/operators/fused_embedding_fc_lstm_op.cc (new file, mode 100644)

```cpp
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/operators/fused_embedding_fc_lstm_op.h"
#include <string>
#include "paddle/fluid/operators/math/blas.h"
#include "paddle/fluid/operators/math/cpu_vec.h"
#include "paddle/fluid/operators/math/fc_compute.h"
#include "paddle/fluid/operators/math/sequence2batch.h"
#include "paddle/fluid/platform/cpu_info.h"

namespace paddle {
namespace operators {

void FusedEmbeddingFCLSTMOp::InferShape(
    framework::InferShapeContext* ctx) const {
  PADDLE_ENFORCE(ctx->HasInput("Embeddings"),
                 "Assert only one Input(Embeddings) of LSTM.");
  PADDLE_ENFORCE(ctx->HasInput("WeightH"),
                 "Assert only one Input(WeightH) of LSTM.");
  PADDLE_ENFORCE(ctx->HasInput("Bias"),
                 "Assert only one Input(Bias) of LSTM.");
  PADDLE_ENFORCE(ctx->HasOutput("XX"), "Assert only one Output(XX) of LSTM.");
  PADDLE_ENFORCE(ctx->HasOutput("Hidden"),
                 "Assert only one Output(Hidden) of LSTM.");
  PADDLE_ENFORCE(ctx->HasOutput("Cell"),
                 "Assert only one Output(Cell) of LSTM.");
  PADDLE_ENFORCE(ctx->HasInput("Ids"),
                 "Input(Ids) of LookupTableOp should not be null.");

  auto table_dims = ctx->GetInputDim("Embeddings");
  auto ids_dims = ctx->GetInputDim("Ids");
  int ids_rank = ids_dims.size();

  PADDLE_ENFORCE_EQ(table_dims.size(), 2);
  PADDLE_ENFORCE_EQ(ids_dims[ids_rank - 1], 1,
                    "The last dimension of the 'Ids' tensor must be 1.");

  auto x_dims = ctx->GetInputDim("Ids");
  PADDLE_ENFORCE_EQ(x_dims.size(), 2, "Input(Ids)'s rank must be 2.");

  if (ctx->HasInput("H0")) {
    PADDLE_ENFORCE(ctx->HasInput("C0"),
                   "Input(Cell) and Input(Hidden) of LSTM should not "
                   "be null at the same time.");
    auto h_dims = ctx->GetInputDim("H0");
    auto c_dims = ctx->GetInputDim("C0");
    PADDLE_ENFORCE(h_dims == c_dims,
                   "The dimension of Input(H0) and Input(C0) "
                   "should be the same.");
  }

  auto embeddings_dims = ctx->GetInputDim("Embeddings");
  PADDLE_ENFORCE_EQ(embeddings_dims.size(), 2,
                    "The rank of Input(Embeddings) should be 2.");

  auto wh_dims = ctx->GetInputDim("WeightH");
  int frame_size = wh_dims[1] / 4;
  PADDLE_ENFORCE_EQ(wh_dims.size(), 2,
                    "The rank of Input(WeightH) should be 2.");
  PADDLE_ENFORCE_EQ(wh_dims[0], frame_size,
                    "The first dimension of Input(WeightH) "
                    "should be %d.",
                    frame_size);
  PADDLE_ENFORCE_EQ(wh_dims[1], 4 * frame_size,
                    "The second dimension of Input(WeightH) "
                    "should be 4 * %d.",
                    frame_size);

  auto b_dims = ctx->GetInputDim("Bias");
  PADDLE_ENFORCE_EQ(b_dims.size(), 2, "The rank of Input(Bias) should be 2.");
  PADDLE_ENFORCE_EQ(b_dims[0], 1,
                    "The first dimension of Input(Bias) should be 1.");
  PADDLE_ENFORCE_EQ(
      b_dims[1],
      (ctx->Attrs().Get<bool>("use_peepholes") ? 7 : 4) * frame_size,
      "The second dimension of Input(Bias) should be "
      "7 * %d if peephole connections are enabled or "
      "4 * %d if disabled.",
      frame_size, frame_size);

  framework::DDim out_dims({x_dims[0], frame_size});
  ctx->SetOutputDim("Hidden", out_dims);
  ctx->SetOutputDim("Cell", out_dims);
  ctx->ShareLoD("Ids", "Hidden");
  ctx->ShareLoD("Ids", "Cell");
  int xx_width;
  if (ctx->Attrs().Get<bool>("use_seq")) {
    xx_width = wh_dims[1];
  } else {
    xx_width = x_dims[1] > wh_dims[1] ? wh_dims[1] : x_dims[1];
    PADDLE_ENFORCE(ctx->HasOutput("BatchedInput"),
                   "Assert only one Output(BatchedInput) of LSTM.");
    PADDLE_ENFORCE(ctx->HasOutput("BatchedHidden"),
                   "Assert only one Output(BatchedHidden) of LSTM.");
    PADDLE_ENFORCE(ctx->HasOutput("BatchedCell"),
                   "Assert only one Output(BatchedCell) of LSTM.");
    PADDLE_ENFORCE(ctx->HasOutput("ReorderedH0"),
                   "Assert only one Output(ReorderedH0) of LSTM.");
    PADDLE_ENFORCE(ctx->HasOutput("ReorderedC0"),
                   "Assert only one Output(ReorderedC0) of LSTM.");
    ctx->SetOutputDim("BatchedInput", {x_dims[0], wh_dims[1]});
    ctx->SetOutputDim("BatchedHidden", out_dims);
    ctx->SetOutputDim("BatchedCell", out_dims);
  }
  ctx->SetOutputDim("XX", {x_dims[0], xx_width});
  ctx->ShareLoD("Ids", "XX");
}

framework::OpKernelType FusedEmbeddingFCLSTMOp::GetExpectedKernelType(
    const framework::ExecutionContext& ctx) const {
  return framework::OpKernelType(
      framework::ToDataType(
          ctx.Input<framework::LoDTensor>("Embeddings")->type()),
      ctx.device_context());
}

void FusedEmbeddingFCLSTMOpMaker::Make() {
  AddInput("Ids",
           "An input with type int32 or int64 "
           "containing the ids to be looked up in W. "
           "The last dimension size must be 1.");
  AddInput("Embeddings",
           "(Tensor) the learnable weights of X."
           " - The shape is (M x 4D), where M is the dim size of x, D is the "
           "hidden size. "
           " - Weight = {W_cx, W_ix, W_fx, W_ox}");
  AddInput("WeightH",
           "(Tensor) same as LSTMOp, the learnable hidden-hidden weights."
           " - The shape is (D x 4D), where D is the hidden size. "
           " - Weight = {W_ch, W_ih, W_fh, W_oh}");
  AddInput("Bias",
           "(Tensor) the learnable weights. Almost the same as LSTMOp. "
           "Note: the fc bias is added into this (1x4D) bias. It holds the "
           "input-hidden bias weight and the peephole connection weights if "
           "`use_peepholes` is set to True. "
           "1. `use_peepholes = False` "
           " - The shape is (1 x 4D). "
           " - Bias = {b_c, b_i, b_f, b_o}. "
           "2. `use_peepholes = True` "
           " - The shape is (1 x 7D). "
           " - Bias = {b_c, b_i, b_f, b_o, W_ic, W_fc, W_oc}.");
  AddInput("H0",
           "(Tensor, optional) (same as LSTMOp) the initial hidden state is "
           "an optional input. This is a tensor with shape (N x D), where N "
           "is the batch size and D is the hidden size.")
      .AsDispensable();
  AddInput("C0",
           "(Tensor, optional) (same as LSTMOp) the initial cell state is an "
           "optional input. This is a tensor with shape (N x D), where N is "
           "the batch size. `H0` and `C0` can be NULL but only at the same "
           "time.")
      .AsDispensable();
  AddOutput("Hidden",
            "(LoDTensor) (same as LSTMOp) the hidden state of LSTM operator. "
            "The shape is (T x D), and lod is the same with the `Input`.");
  AddOutput("Cell",
            "(LoDTensor) (same as LSTMOp) the cell state of LSTM operator. "
            "The shape is (T x D), and lod is the same with the `Input`.");
  AddOutput("XX",
            "(LoDTensor) the result after X * WeightX (size is T x 4D)"
            " or batched_X (size is T x M), this will be automatically chosen,"
            " where T is the total time steps in this mini-batch,"
            " D is the hidden size, M is the dim size of x input.")
      .AsIntermediate();
  AddOutput("BatchedInput", "(LoDTensor) (T x 4D).").AsIntermediate();
  AddOutput("BatchedHidden", "(LoDTensor) (T x D).").AsIntermediate();
  AddOutput("BatchedCell", "(LoDTensor) (T x D).").AsIntermediate();
  AddOutput("ReorderedH0", "(LoDTensor) (N x D).").AsIntermediate();
  AddOutput("ReorderedC0", "(LoDTensor) (N x D).").AsIntermediate();
  AddAttr<bool>("use_peepholes",
                "(bool, default: True) "
                "whether to enable diagonal/peephole connections.")
      .SetDefault(true);
  AddAttr<bool>("is_reverse",
                "(bool, default: False) "
                "whether to compute reversed LSTM.")
      .SetDefault(false);
  AddAttr<bool>("use_seq",
                "(bool, default: True) "
                "whether to use seq mode to compute.")
      .SetDefault(true);
  AddAttr<std::string>("gate_activation",
                       "(string, default: sigmoid) "
                       "The activation for input gate, forget gate and output "
                       "gate, `sigmoid` by default.")
      .SetDefault("sigmoid")
      .InEnum({"sigmoid", "tanh", "relu", "identity"});
  AddAttr<std::string>("cell_activation",
                       "(string, default: tanh) "
                       "The activation for cell output, `tanh` by default.")
      .SetDefault("tanh")
      .InEnum({"sigmoid", "tanh", "relu", "identity"});
  AddAttr<std::string>("candidate_activation",
                       "(string, default: tanh) "
                       "The activation for candidate hidden state, "
                       "`tanh` by default.")
      .SetDefault("tanh")
      .InEnum({"sigmoid", "tanh", "relu", "identity"});
  AddComment(R"DOC(
Fusion Long-Short Term Memory (LSTM) Operator.
This operator fuses the embedding lookup and the input projection X into LSTM;
for more details, refer to the LSTM op.
)DOC");
}

template <typename T>
class FusedEmbeddingFCLSTMKernel : public framework::OpKernel<T> {
 public:
#define INIT_VEC_FUNC                                                          \
  std::function<void(const int, const T *, T *)> act_gate, act_cell, act_cand; \
  auto& act_gate_str = ctx.Attr<std::string>("gate_activation");               \
  auto& act_cell_str = ctx.Attr<std::string>("cell_activation");               \
  auto& act_cand_str = ctx.Attr<std::string>("candidate_activation");          \
  if (platform::jit::MayIUse(platform::jit::avx)) {                            \
    math::VecActivations<T, platform::jit::avx> act_functor;                   \
    act_gate = act_functor(act_gate_str);                                      \
    act_cell = act_functor(act_cell_str);                                      \
    act_cand = act_functor(act_cand_str);                                      \
  } else {                                                                     \
    math::VecActivations<T, platform::jit::isa_any> act_functor;               \
    act_gate = act_functor(act_gate_str);                                      \
    act_cell = act_functor(act_cell_str);                                      \
    act_cand = act_functor(act_cand_str);                                      \
  }

#define INIT_BASE_INPUT_OUTPUT                        \
  auto* ids = ctx.Input<LoDTensor>("Ids");            \
  auto* h0 = ctx.Input<Tensor>("H0");                 \
  auto* c0 = ctx.Input<Tensor>("C0");                 \
  auto* embeddings = ctx.Input<Tensor>("Embeddings"); \
  auto* wh = ctx.Input<Tensor>("WeightH");            \
  auto* bias = ctx.Input<Tensor>("Bias");             \
  auto* xx = ctx.Output<LoDTensor>("XX");             \
  auto* hidden_out = ctx.Output<LoDTensor>("Hidden"); \
  auto* cell_out = ctx.Output<LoDTensor>("Cell");     \
  bool is_reverse = ctx.Attr<bool>("is_reverse");     \
  bool use_peepholes = ctx.Attr<bool>("use_peepholes");

#define INIT_BASE_SIZES                       \
  auto ids_dims = ids->dims();   /* T x M*/   \
  auto ids_numel = ids->numel(); /* T x 1*/   \
  auto wh_dims = wh->dims();     /* D x 4D*/  \
  const int D = wh_dims[0];                   \
  const int D2 = D * 2;                       \
  const int D3 = D * 3;                       \
  int64_t row_number = embeddings->dims()[0]; \
  int64_t row_width = embeddings->dims()[1];  \
  const int D4 = wh_dims[1];

#define INIT_BASE_INPUT_DATAS                                        \
  const int64_t* ids_data = ids->data<int64_t>();                    \
  const T* embeddings_data = embeddings->data<T>();                  \
  const T* wh_data = wh->data<T>();                                  \
  /* diagonal weight*/                                               \
  const T* wc_data = bias->data<T>() + D4;                           \
  /* for peephole only*/                                             \
  Tensor checked_cell;                                               \
  T* checked_cell_data = nullptr;                                    \
  auto place = ctx.GetPlace();                                       \
  if (use_peepholes) {                                               \
    /* w_ic * Ct-1, w_fc * Ct-1 ; w_oc * Ct => ih*/                  \
    checked_cell_data = checked_cell.mutable_data<T>({2, D}, place); \
  }

/// Compute LSTM
#define GEMM_WH_ADDON(bs, prev, out)                                           \
  blas.GEMM(CblasNoTrans, CblasNoTrans, bs, D4, D, static_cast<T>(1), prev, D, \
            wh_data, D4, static_cast<T>(1), out, D4)

// gates: W_ch, W_ih, W_fh, W_oh
#define GET_Ct(ct_1, gates, ct)                   \
  /* C_t = C_t-1 * fgated + cand_gated * igated*/ \
  act_cand(D, gates, gates);                      \
  blas.VMUL(D, gates, gates + D, gates + D);      \
  blas.VMUL(D, ct_1, gates + D2, gates + D2);     \
  blas.VADD(D, gates + D, gates + D2, ct)

#define GET_Ht(ct, gates, ht)        \
  /* H_t = act_cell(C_t) * ogated */ \
  act_cell(D, ct, gates + D2);       \
  blas.VMUL(D, gates + D2, gates + D3, ht)

#define GET_Ct_NOH0C0(gates, ct)     \
  /* C_t = igated * cgated*/         \
  act_gate(D, gates + D, gates + D); \
  act_cand(D, gates, gates);         \
  blas.VMUL(D, gates, gates + D, ct)

#define COMPUTE_CtHt_NOH0C0(gates, ct, ht) \
  GET_Ct_NOH0C0(gates, ct);                \
  act_gate(D, gates + D3, gates + D3);     \
  GET_Ht(ct, gates, ht)

#define COMPUTE_CtHt_PEEPHOLE_NOH0C0(gates, ct, ht) \
  GET_Ct_NOH0C0(gates, ct);                         \
  /* get outgated, put W_oc * C_t on igated */      \
  blas.VMUL(D, wc_data + D2, ct, gates + D);        \
  blas.VADD(D, gates + D, gates + D3, gates + D3);  \
  act_gate(D, gates + D3, gates + D3);              \
  GET_Ht(ct, gates, ht)

#define COMPUTE_CtHt(gates, ct_1, ct, ht) \
  act_gate(D3, gates + D, gates + D);     \
  GET_Ct(ct_1, gates, ct);                \
  GET_Ht(ct, gates, ht)

#define COMPUTE_CtHt_PEEPHOLE(gates, ct_1, ct, ht)        \
  /* get fgated and igated*/                              \
  blas.VMUL(D, wc_data, ct_1, checked_cell_data);         \
  blas.VMUL(D, wc_data + D, ct_1, checked_cell_data + D); \
  blas.VADD(D2, checked_cell_data, gates + D, gates + D); \
  act_gate(D2, gates + D, gates + D);                     \
  GET_Ct(ct_1, gates, ct);                                \
  /* get ogated*/                                         \
  blas.VMUL(D, wc_data + D2, ct, gates + D);              \
  blas.VADD(D, gates + D, gates + D3, gates + D3);        \
  act_gate(D, gates + D3, gates + D3);                    \
  GET_Ht(ct, gates, ht)

  void SeqCompute(const framework::ExecutionContext& ctx) const {
    using DeviceContext = paddle::platform::CPUDeviceContext;
    INIT_BASE_INPUT_OUTPUT
    INIT_BASE_SIZES
    INIT_VEC_FUNC
    INIT_BASE_INPUT_DATAS

    // std::cout << "====> SeqCompute" << std::endl;
    auto ids_lod = ids->lod();
    const int total_T = ids_dims[0];
    const int N = ids_lod[0].size() - 1;
    const T* h0_data = h0 ? h0->data<T>() : nullptr;
    const T* c0_data = c0 ? c0->data<T>() : nullptr;
    T* xx_data = xx->mutable_data<T>(place);
    T* h_out_data = hidden_out->mutable_data<T>(place);
    T* c_out_data = cell_out->mutable_data<T>(place);
    auto blas = math::GetBlas<DeviceContext, T>(ctx);

    for (int64_t i = 0; i < ids_numel; ++i) {
      PADDLE_ENFORCE_LT(ids_data[i], row_number);
      PADDLE_ENFORCE_GE(ids_data[i], 0, "ids %d", i);
      memcpy(xx_data + i * row_width,
             embeddings_data + ids_data[i] * row_width,
             row_width * sizeof(T));
    }

    int xx_offset = D4;
    int gate_offset = D;
    if (is_reverse) {
      const int offset = (total_T - 1) * D;
      xx_data = xx_data + offset * 4;
      h_out_data = h_out_data + offset;
      c_out_data = c_out_data + offset;
      xx_offset = -D4;
      gate_offset = -D;
    }

#define MOVE_ONE_STEP                    \
  prev_h_data = h_out_data;              \
  prev_c_data = c_out_data;              \
  xx_data = xx_data + xx_offset;         \
  h_out_data = h_out_data + gate_offset; \
  c_out_data = c_out_data + gate_offset

#define PROCESS_H0C0_DEFINES                              \
  int bid = is_reverse ? N - 1 - i : i;                   \
  int seq_len = ids_lod[0][bid + 1] - ids_lod[0][bid];    \
  const T* prev_c_data = nullptr;                         \
  const T* prev_h_data = nullptr;                         \
  int tstart = 0

#define PROCESS_H0C0_PEEPHOLE                                      \
  PROCESS_H0C0_DEFINES;                                            \
  if (h0_data) {                                                   \
    prev_h_data = h0_data + bid * D;                               \
    prev_c_data = c0_data + bid * D;                               \
  } else {                                                         \
    COMPUTE_CtHt_PEEPHOLE_NOH0C0(xx_data, c_out_data, h_out_data); \
    MOVE_ONE_STEP;                                                 \
    tstart = 1;                                                    \
  }

#define PROCESS_H0C0                                      \
  PROCESS_H0C0_DEFINES;                                   \
  if (h0_data) {                                          \
    prev_h_data = h0_data + bid * D;                      \
    prev_c_data = c0_data + bid * D;                      \
  } else {                                                \
    COMPUTE_CtHt_NOH0C0(xx_data, c_out_data, h_out_data); \
    MOVE_ONE_STEP;                                        \
    tstart = 1;                                           \
  }

    if (use_peepholes) {
      for (int i = 0; i < N; ++i) {
        PROCESS_H0C0_PEEPHOLE
        for (int step = tstart; step < seq_len; ++step) {
          GEMM_WH_ADDON(1, prev_h_data, xx_data);
          COMPUTE_CtHt_PEEPHOLE(xx_data, prev_c_data, c_out_data, h_out_data);
          MOVE_ONE_STEP;
        }
      }
    } else {
      for (int i = 0; i < N; ++i) {
        PROCESS_H0C0
        for (int step = tstart; step < seq_len; ++step) {
          GEMM_WH_ADDON(1, prev_h_data, xx_data);
          COMPUTE_CtHt(xx_data, prev_c_data, c_out_data, h_out_data);
          MOVE_ONE_STEP;
        }
      }
    }
#undef PROCESS_H0C0_DEFINES
#undef PROCESS_H0C0_PEEPHOLE
#undef PROCESS_H0C0
#undef MOVE_ONE_STEP
  }

  void BatchCompute(const framework::ExecutionContext& ctx) const {
    using DeviceContext = platform::CPUDeviceContext;
    INIT_BASE_INPUT_OUTPUT
    if (ids->lod()[0].size() == 2) {
      SeqCompute(ctx);
      return;
    }
    INIT_BASE_SIZES
    INIT_VEC_FUNC
    INIT_BASE_INPUT_DATAS

    // std::cout << "===> Batch Compute" << std::endl;
    auto* reordered_h0 = ctx.Output<Tensor>("ReorderedH0");
    auto* reordered_c0 = ctx.Output<Tensor>("ReorderedC0");
    auto* batched_input = ctx.Output<LoDTensor>("BatchedInput");
    auto* batched_c_out = ctx.Output<LoDTensor>("BatchedCell");
    auto* batched_h_out = ctx.Output<LoDTensor>("BatchedHidden");
    T* xx_data = xx->mutable_data<T>(place);
    T* batched_input_data = batched_input->mutable_data<T>(place);
    T* batched_c_out_data = batched_c_out->mutable_data<T>(place);
    T* batched_h_out_data = batched_h_out->mutable_data<T>(place);
    hidden_out->mutable_data<T>(place);
    cell_out->mutable_data<T>(place);

    math::LoDTensor2BatchFunctor<DeviceContext, T> to_batch;
    auto& dev_ctx = ctx.template device_context<DeviceContext>();
    auto blas = math::GetBlas<DeviceContext, T>(dev_ctx);

    for (int64_t i = 0; i < ids_numel; ++i) {
      PADDLE_ENFORCE_LT(ids_data[i], row_number);
      PADDLE_ENFORCE_GE(ids_data[i], 0, "ids %d", i);
      memcpy(xx_data + i * row_width,
             embeddings_data + ids_data[i] * row_width,
             row_width * sizeof(T));
    }

    to_batch(dev_ctx, *xx, batched_input, true, is_reverse);

    auto batched_lod = batched_input->lod();
    const auto& seq_order = batched_lod[2];
    const int max_bs = seq_order.size();
    reordered_h0->Resize({max_bs, D});
    reordered_c0->Resize({max_bs, D});

    int tstart = 0;
    T* prev_h_data = nullptr;
    T* prev_c_data = nullptr;
    if (h0) {
      // reorder h0, c0
      T* reordered_h0_data = reordered_h0->mutable_data<T>(place);
      T* reordered_c0_data = reordered_c0->mutable_data<T>(place);
      const T* h0_data = h0->data<T>();
      const T* c0_data = c0->data<T>();
      prev_h_data = reordered_h0_data;
      prev_c_data = reordered_c0_data;
      size_t sz = sizeof(T) * D;
      for (int i = 0; i < max_bs; ++i) {
        std::memcpy(reordered_h0_data, h0_data + seq_order[i] * D, sz);
        std::memcpy(reordered_c0_data, c0_data + seq_order[i] * D, sz);
        reordered_h0_data += D;
        reordered_c0_data += D;
      }
    } else {
      // compute without h0, c0
      T* cur_in_data = batched_input_data;
      T* cur_h_out_data = batched_h_out_data;
      T* cur_c_out_data = batched_c_out_data;
      for (int i = 0; i < max_bs; ++i) {
        GET_Ct_NOH0C0(cur_in_data, cur_c_out_data);
        if (use_peepholes) {
          blas.VMUL(D, wc_data + D2, cur_c_out_data, cur_in_data + D);
          blas.VADD(D, cur_in_data + D, cur_in_data + D3, cur_in_data + D3);
        }
        act_gate(D, cur_in_data + D3, cur_in_data + D3);
        GET_Ht(cur_c_out_data, cur_in_data, cur_h_out_data);
        cur_in_data += D4;
        cur_c_out_data += D;
        cur_h_out_data += D;
      }
      tstart = 1;
      prev_h_data = batched_h_out_data;
      prev_c_data = batched_c_out_data;
    }
    const auto& batch_starts = batched_lod[0];
    const int max_seq_len = batch_starts.size() - 1;
    const int offset = tstart * max_bs * D;
    batched_input_data = batched_input_data + offset * 4;
    batched_h_out_data = batched_h_out_data + offset;
    batched_c_out_data = batched_c_out_data + offset;

#define DEFINE_CUR                        \
  T* cur_in_data = batched_input_data;    \
  T* cur_prev_c_data = prev_c_data;       \
  T* cur_c_out_data = batched_c_out_data; \
  T* cur_h_out_data = batched_h_out_data

#define MOVE_ONE_BATCH  \
  cur_in_data += D4;    \
  cur_prev_c_data += D; \
  cur_c_out_data += D;  \
  cur_h_out_data += D

#define MOVE_ONE_STEP                  \
  prev_c_data = batched_c_out_data;    \
  prev_h_data = batched_h_out_data;    \
  batched_c_out_data = cur_c_out_data; \
  batched_h_out_data = cur_h_out_data; \
  batched_input_data = cur_in_data

    if (use_peepholes) {
      for (int step = tstart; step < max_seq_len; ++step) {
        const int cur_bs = batch_starts[step + 1] - batch_starts[step];
        GEMM_WH_ADDON(cur_bs, prev_h_data, batched_input_data);
        DEFINE_CUR;
        for (int i = 0; i < cur_bs; ++i) {
          COMPUTE_CtHt_PEEPHOLE(cur_in_data, cur_prev_c_data, cur_c_out_data,
                                cur_h_out_data);
          MOVE_ONE_BATCH;
        }
        MOVE_ONE_STEP;
      }
    } else {
      for (int step = tstart; step < max_seq_len; ++step) {
        const int cur_bs = batch_starts[step + 1] - batch_starts[step];
        GEMM_WH_ADDON(cur_bs, prev_h_data, batched_input_data);
        DEFINE_CUR;
        for (int i = 0; i < cur_bs; ++i) {
          COMPUTE_CtHt(cur_in_data, cur_prev_c_data, cur_c_out_data,
                       cur_h_out_data);
          MOVE_ONE_BATCH;
        }
        MOVE_ONE_STEP;
      }
    }
#undef MOVE_ONE_STEP
#undef MOVE_ONE_BATCH
#undef DEFINE_CUR

    math::Batch2LoDTensorFunctor<DeviceContext, T> to_seq;
    batched_h_out->set_lod(batched_lod);
    to_seq(dev_ctx, *batched_h_out, hidden_out);
    batched_c_out->set_lod(batched_lod);
    to_seq(dev_ctx, *batched_c_out, cell_out);
  }

  void Compute(const framework::ExecutionContext& ctx) const override {
    if (ctx.Attr<bool>("use_seq")) {
      SeqCompute(ctx);
    } else {
      BatchCompute(ctx);
    }
  }

#undef COMPUTE_CtHt_PEEPHOLE
#undef COMPUTE_CtHt
#undef GET_Ct_NOH0C0
#undef COMPUTE_CtHt_NOH0C0
#undef COMPUTE_CtHt_PEEPHOLE_NOH0C0
#undef GET_Ht
#undef GET_Ct
#undef GEMM_WH_ADDON
#undef INIT_BASE_INPUT_DATAS
#undef INIT_BASE_SIZES
#undef INIT_BASE_INPUT_OUTPUT
#undef INIT_VEC_FUNC
};

}  // namespace operators
}  // namespace paddle

namespace ops = paddle::operators;
REGISTER_OPERATOR(fused_embedding_fc_lstm, ops::FusedEmbeddingFCLSTMOp,
                  ops::FusedEmbeddingFCLSTMOpMaker,
                  paddle::framework::DefaultGradOpDescMaker<true>);

REGISTER_OP_CPU_KERNEL(fused_embedding_fc_lstm,
                       ops::FusedEmbeddingFCLSTMKernel<float>,
                       ops::FusedEmbeddingFCLSTMKernel<double>);
```
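Reading the `GET_Ct`/`GET_Ht` macros back into standard LSTM notation (the gate buffer is laid out as `[c̃, i, f, o]` blocks of width `D`), the kernel computes the recurrence below; the bracketed peephole terms apply only when `use_peepholes` is set. This is a reading of the code above, not text from the PR:

```latex
\begin{aligned}
\tilde{c}_t &= \mathrm{act\_cand}(W_{cx} x_t + W_{ch} h_{t-1} + b_c) \\
i_t &= \mathrm{act\_gate}(W_{ix} x_t + W_{ih} h_{t-1} + b_i \,[{}+ W_{ic} \odot c_{t-1}]) \\
f_t &= \mathrm{act\_gate}(W_{fx} x_t + W_{fh} h_{t-1} + b_f \,[{}+ W_{fc} \odot c_{t-1}]) \\
c_t &= f_t \odot c_{t-1} + i_t \odot \tilde{c}_t \\
o_t &= \mathrm{act\_gate}(W_{ox} x_t + W_{oh} h_{t-1} + b_o \,[{}+ W_{oc} \odot c_t]) \\
h_t &= o_t \odot \mathrm{act\_cell}(c_t)
\end{aligned}
```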
paddle/fluid/operators/fused_embedding_fc_lstm_op.h (new file, mode 100644)

```cpp
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once

#include "paddle/fluid/framework/op_registry.h"

namespace paddle {
namespace operators {

using LoDTensor = framework::LoDTensor;
using Tensor = framework::Tensor;

class FusedEmbeddingFCLSTMOp : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;

  void InferShape(framework::InferShapeContext* ctx) const override;

 protected:
  framework::OpKernelType GetExpectedKernelType(
      const framework::ExecutionContext& ctx) const override;
};

class FusedEmbeddingFCLSTMOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  void Make() override;
};

}  // namespace operators
}  // namespace paddle
```