diff --git a/Dockerfile b/Dockerfile
index 402adee2ea2822250ebc8f6229fd6a44545d58e5..634be18a51bf61e96a8bf6f263b6674a7932d6e4 100644
--- a/Dockerfile
+++ b/Dockerfile
@@ -53,7 +53,7 @@ RUN curl -s -q https://glide.sh/get | sh
# and its size is only one-third of the official one.
# 2. Manually add ~IPluginFactory() to the IPluginFactory class in NvInfer.h; otherwise it cannot work in Paddle.
# See https://github.com/PaddlePaddle/Paddle/issues/10129 for details.
-RUN wget -qO- http://paddlepaddledeps.bj.bcebos.com/TensorRT-4.0.0.3.Ubuntu-16.04.4.x86_64-gnu.cuda-8.0.cudnn7.0.tar.gz | \
+RUN wget -qO- http://paddlepaddledeps.cdn.bcebos.com/TensorRT-4.0.0.3.Ubuntu-16.04.4.x86_64-gnu.cuda-8.0.cudnn7.0.tar.gz | \
tar -xz -C /usr/local && \
cp -rf /usr/local/TensorRT/include /usr && \
cp -rf /usr/local/TensorRT/lib /usr
diff --git a/cmake/inference_lib.cmake b/cmake/inference_lib.cmake
index bc36683a9facc253e7b9feb0c5a56e79491fb9b0..f61770514eb05a99c140cdb18575c89aa5235c14 100644
--- a/cmake/inference_lib.cmake
+++ b/cmake/inference_lib.cmake
@@ -128,16 +128,13 @@ set(src_dir "${PADDLE_SOURCE_DIR}/paddle/fluid")
set(dst_dir "${FLUID_INSTALL_DIR}/paddle/fluid")
set(module "framework")
if (NOT WIN32)
-copy(framework_lib DEPS framework_py_proto
- SRCS ${src_dir}/${module}/*.h ${src_dir}/${module}/details/*.h ${PADDLE_BINARY_DIR}/paddle/fluid/framework/framework.pb.h
- DSTS ${dst_dir}/${module} ${dst_dir}/${module}/details ${dst_dir}/${module}
-)
-else()
-copy(framework_lib
+set(framework_lib_deps framework_py_proto)
+endif(NOT WIN32)
+copy(framework_lib DEPS ${framework_lib_deps}
SRCS ${src_dir}/${module}/*.h ${src_dir}/${module}/details/*.h ${PADDLE_BINARY_DIR}/paddle/fluid/framework/framework.pb.h
- DSTS ${dst_dir}/${module} ${dst_dir}/${module}/details ${dst_dir}/${module}
+ ${src_dir}/${module}/ir/*.h
+ DSTS ${dst_dir}/${module} ${dst_dir}/${module}/details ${dst_dir}/${module} ${dst_dir}/${module}/ir
)
-endif(NOT WIN32)
set(module "memory")
copy(memory_lib
@@ -161,7 +158,8 @@ set(module "inference")
copy(inference_lib DEPS ${inference_deps}
SRCS ${src_dir}/${module}/*.h ${PADDLE_BINARY_DIR}/paddle/fluid/inference/libpaddle_fluid.*
${src_dir}/${module}/api/paddle_inference_api.h ${src_dir}/${module}/api/demo_ci
- DSTS ${dst_dir}/${module} ${dst_dir}/${module} ${dst_dir}/${module} ${dst_dir}/${module}
+ ${PADDLE_BINARY_DIR}/paddle/fluid/inference/api/paddle_inference_pass.h
+ DSTS ${dst_dir}/${module} ${dst_dir}/${module} ${dst_dir}/${module} ${dst_dir}/${module} ${dst_dir}/${module}
)
set(module "platform")
diff --git a/doc/fluid/new_docs/beginners_guide/basics/machine_translation/README.cn.md b/doc/fluid/new_docs/beginners_guide/basics/machine_translation/README.cn.md
index fa2b930be0d26d816566599cece8afbedc1157e0..6e5f77fec8a894c390ced8c93ee344fd8d27370e 100644
--- a/doc/fluid/new_docs/beginners_guide/basics/machine_translation/README.cn.md
+++ b/doc/fluid/new_docs/beginners_guide/basics/machine_translation/README.cn.md
@@ -60,6 +60,7 @@
Figure 3. Encoder-decoder framework
+
#### Encoder
The encoding stage consists of three steps:
@@ -81,7 +82,7 @@
During training for machine translation, the goal of the decoding stage is to maximize the probability of the next correct target-language word. The idea is:
1. At each time step, compute the next hidden state `$z_{i+1}$` from the encoded source sentence (also called the context vector) `$c$`, the `$i$`-th word `$u_i$` of the ground-truth target-language sequence, and the hidden state `$z_i$` of the RNN at time `$i$`. The formula is:
$$z_{i+1}=\phi_{\theta '} \left ( c,u_i,z_i \right )$$
-where `$\phi _{\theta '}$` is a nonlinear activation function; `$c=q\mathbf{h}$` is the context vector of the source sentence, and when the [attention mechanism](#注意力机制) is not used, if the output of the [encoder](#编码器) is the last element of the encoded source sentence, we can define `$c=h_T$`; `$u_i$` is the `$i$`-th word of the target-language sequence, and `$u_0$` is its start token `<s>`, which marks the beginning of decoding; `$z_i$` is the hidden state of the decoder RNN at time `$i$`, and `$z_0$` is an all-zero vector.
+where `$\phi _{\theta '}$` is a nonlinear activation function; `$c=q\mathbf{h}$` is the context vector of the source sentence, and when the attention mechanism is not used, if the output of the [encoder](#编码器) is the last element of the encoded source sentence, we can define `$c=h_T$`; `$u_i$` is the `$i$`-th word of the target-language sequence, and `$u_0$` is its start token `<s>`, which marks the beginning of decoding; `$z_i$` is the hidden state of the decoder RNN at time `$i$`, and `$z_0$` is an all-zero vector.
2. Normalize `$z_{i+1}$` with `softmax` to obtain the probability distribution `$p_{i+1}$` over the `$i+1$`-th word of the target-language sequence. The formula is:
$$p\left ( u_{i+1}|u_{<i+1},\mathbf{x} \right )=softmax(W_sz_{i+1}+b_z)$$
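As a hedged NumPy sketch of the decoding-step probability above (the names `W_s`, `b_z`, and `z_next` mirror the formula; the shapes are assumptions, not this chapter's code):

```python
import numpy as np

def next_word_distribution(W_s, b_z, z_next):
    """p_{i+1} = softmax(W_s z_{i+1} + b_z); W_s: (vocab, hidden), z_next: (hidden,)."""
    logits = W_s @ z_next + b_z
    exps = np.exp(logits - logits.max())  # subtract the max for numerical stability
    return exps / exps.sum()
```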
@@ -93,6 +94,7 @@ $$p\left ( u_{i+1}|u_{<i+1},\mathbf{x} \right )=softmax(W_sz_{i+1}+b_z)$$
The generation process of a machine translation task is, informally, translating a source-language sentence with the pre-trained model. The decoding stage during generation differs from that of the training process described above; see the [beam search algorithm](#柱搜索算法) for details.
+
### Beam Search Algorithm
Beam search ([beam search](http://en.wikipedia.org/wiki/Beam_search)) is a heuristic graph-search algorithm that searches a graph or tree for the optimal expansion nodes within a limited set. It is typically used in systems whose solution space is very large (such as machine translation and speech recognition), because memory cannot hold all the expanded solutions of the graph or tree. For example, to translate "`你好`" in a machine translation task, even if the target-language dictionary contains only three words (`<s>`, `<e>`, `hello`), infinitely many sentences can be generated (the number of times `hello` appears is unbounded); to find the better translations among them, we can use beam search.
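As a hedged sketch of the idea only (the `step` callback below, which returns `(token, log_prob)` candidates for a partial translation, is hypothetical and not part of this chapter's code):

```python
import heapq

def beam_search(step, start_token, end_token, beam_size=3, max_len=20):
    # Each hypothesis is (score, tokens); the score is a sum of log-probabilities.
    beams = [(0.0, [start_token])]
    finished = []
    for _ in range(max_len):
        candidates = []
        for score, tokens in beams:
            if tokens[-1] == end_token:
                finished.append((score, tokens))  # this hypothesis is complete
                continue
            for token, log_prob in step(tokens):  # expand each live hypothesis
                candidates.append((score + log_prob, tokens + [token]))
        if not candidates:
            break
        # Prune: keep only the beam_size highest-scoring partial translations.
        beams = heapq.nlargest(beam_size, candidates, key=lambda c: c[0])
    return max(finished + beams, key=lambda c: c[0])
```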
diff --git a/doc/fluid/new_docs/beginners_guide/basics/understand_sentiment/README.cn.md b/doc/fluid/new_docs/beginners_guide/basics/understand_sentiment/README.cn.md
index 9900dfb9a67dc6f8940bd7dd3abfa15ac8a3488f..8477cf32146c33947ced447c8bdd287a3e1e71f5 100644
--- a/doc/fluid/new_docs/beginners_guide/basics/understand_sentiment/README.cn.md
+++ b/doc/fluid/new_docs/beginners_guide/basics/understand_sentiment/README.cn.md
@@ -149,6 +149,8 @@ def convolution_net(data, input_dim, class_dim, emb_dim, hid_dim):
The network input `input_dim` denotes the size of the dictionary, and `class_dim` denotes the number of classes. Here, we implement the convolution and pooling operations with the [`sequence_conv_pool`](https://github.com/PaddlePaddle/Paddle/blob/develop/python/paddle/trainer_config_helpers/networks.py) API.
+
+
### Stacked Bidirectional LSTM
The code snippet for the stacked bidirectional network `stacked_lstm_net` is as follows:
diff --git a/doc/fluid/new_docs/beginners_guide/basics/word2vec/README.cn.md b/doc/fluid/new_docs/beginners_guide/basics/word2vec/README.cn.md
index 2c68cdac4f10319359b74bc92569dfd3f65380b5..904d99fe2ffc9ead69a86c9763568a5c098348d5 100644
--- a/doc/fluid/new_docs/beginners_guide/basics/word2vec/README.cn.md
+++ b/doc/fluid/new_docs/beginners_guide/basics/word2vec/README.cn.md
@@ -50,7 +50,7 @@ similarity: -0.0997506977351
```
-The results above can be obtained by running `calculate_dis.py`, which loads the words in the dictionary and their corresponding trained embeddings; we describe the usage in detail in [Applying the Model](#应用模型).
+The results above can be obtained by running `calculate_dis.py`, which loads the words in the dictionary and their corresponding trained embeddings; we describe the usage in detail in [Model Application](#模型应用).
## Model Overview
@@ -189,6 +189,7 @@ dream that one day
Finally, each input is converted into a sequence of integer indices according to the position of each of its words in the dictionary, and used as the input to PaddlePaddle.
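A minimal sketch of this conversion with a toy dictionary (the dictionary below is made up for illustration, not the one built in this chapter):

```python
word_dict = {'<s>': 0, '<e>': 1, '<unk>': 2, 'dream': 3, 'that': 4, 'one': 5, 'day': 6}
sentence = ['<s>', 'dream', 'that', 'one', 'day', '<e>']
ids = [word_dict.get(w, word_dict['<unk>']) for w in sentence]  # OOV words map to <unk>
print(ids)  # [0, 3, 4, 5, 6, 1]
```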
+
## Implementation
The model structure of this configuration is shown in the figure below:
@@ -349,6 +350,7 @@ Step 20: Average Cost 5.766995
...
```
+
## Model Application
After the model has been trained, we can use it to make some predictions.
diff --git a/doc/fluid/new_docs/beginners_guide/quick_start/recognize_digits/README.cn.md b/doc/fluid/new_docs/beginners_guide/quick_start/recognize_digits/README.cn.md
index e6f89b23a95d1a07565f3e0a285e9c3f921930df..ac36c4ecf6b9b716fe5f0dbe2346e64918c22242 100644
--- a/doc/fluid/new_docs/beginners_guide/quick_start/recognize_digits/README.cn.md
+++ b/doc/fluid/new_docs/beginners_guide/quick_start/recognize_digits/README.cn.md
@@ -102,7 +102,7 @@ The Softmax regression model uses the simplest two-layer neural network, i.e., with only an input layer
Pooling is a form of nonlinear downsampling. Its main purpose is to reduce the amount of computation by reducing the number of network parameters, and it also controls overfitting to some extent. A pooling layer is usually added after a convolutional layer. Pooling includes max pooling, average pooling, and so on. Max pooling partitions the input layer into non-overlapping rectangular regions and takes the maximum value within each rectangle as the output, as shown in Figure 6.
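As a hedged NumPy illustration of max pooling with non-overlapping 2x2 windows (independent of the PaddlePaddle layers used in this chapter):

```python
import numpy as np

def max_pool_2x2(x):
    """Max pooling over non-overlapping 2x2 windows of a 2-D input."""
    h, w = x.shape
    x = x[:h - h % 2, :w - w % 2]  # drop odd border rows/columns
    return x.reshape(h // 2, 2, w // 2, 2).max(axis=(1, 3))

feature_map = np.arange(16, dtype=np.float32).reshape(4, 4)
print(max_pool_2x2(feature_map))  # each output value is the max of one 2x2 block
```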
-For more detailed knowledge about convolutional neural networks, please refer to the [Stanford open course]( http://cs231n.github.io/convolutional-networks/ ) and the [image classification](https://github.com/PaddlePaddle/book/blob/develop/image_classification/README.md) tutorial.
+For more detailed knowledge about convolutional neural networks, please refer to the [Stanford open course]( http://cs231n.github.io/convolutional-networks/ ) and the [image classification]( https://github.com/PaddlePaddle/book/tree/develop/03.image_classification ) tutorial.
### Common Activation Functions
- Sigmoid activation function: $ f(x) = sigmoid(x) = \frac{1}{1+e^{-x}} $
diff --git a/doc/fluid/new_docs/user_guides/howto/debug/visualdl.md b/doc/fluid/new_docs/user_guides/howto/debug/visualdl.md
index a2f30823a6fcd379f94e6e98d043b0d00681827f..99f8bee5ca1519ccf5d7c35ad2a64da4a8841ada 100644
--- a/doc/fluid/new_docs/user_guides/howto/debug/visualdl.md
+++ b/doc/fluid/new_docs/user_guides/howto/debug/visualdl.md
@@ -104,6 +104,7 @@ visualDL --logdir=scratch_log --port=8080
# visit http://127.0.0.1:8080
```
+If you see `TypeError: __init__() got an unexpected keyword argument 'file'`, it is because your protobuf version is lower than 3.5; running `pip install --upgrade protobuf` will fix it.
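To check the installed version before upgrading (standard protobuf Python API, shown as a hedged example):

```python
import google.protobuf
print(google.protobuf.__version__)  # VisualDL needs protobuf >= 3.5
```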
If you still run into installation problems inside a virtual environment, try the following approaches.
@@ -149,7 +150,7 @@ python setup.py bdist_wheel
pip install --upgrade dist/visualdl-*.whl
```
-If you run into other packaging or installation problems, or only want to run VisualDL without installing it, see [here](https://github.com/PaddlePaddle/VisualDL/blob/develop/docs/how_to_dev_frontend_en.md)
+If you run into other packaging or installation problems, or only want to run VisualDL without installing it, see [here](https://github.com/PaddlePaddle/VisualDL/blob/develop/docs/develop/how_to_dev_frontend_cn.md)
## SDK
diff --git a/doc/fluid/new_docs/user_guides/howto/inference/native_infer.rst b/doc/fluid/new_docs/user_guides/howto/inference/native_infer.rst
index 21a6fe5cf54d0c0c760ade4ba602024ffa29675f..6d6f3035c0b5c985cd39d45df9f1bcce50dcefa0 100644
--- a/doc/fluid/new_docs/user_guides/howto/inference/native_infer.rst
+++ b/doc/fluid/new_docs/user_guides/howto/inference/native_infer.rst
@@ -4,13 +4,12 @@ Paddle Inference API
To make inference deployment simpler and more convenient, Fluid provides a set of high-level APIs
that hide the different low-level optimized implementations.
-`Inference library code `__
+`Inference library code `_
includes
- the header file ``paddle_inference_api.h``, which defines all the interfaces
- the library file\ ``libpaddle_fluid.so`` or ``libpaddle_fluid.a``
-- the library file ``libpaddle_inference_api.so`` or
-  ``libpaddle_inference_api.a``
+
For compilation and dependencies, refer to :ref:`install_or_build_cpp_inference_lib` .
@@ -97,8 +96,7 @@ engine
CHECK(predictor->Run(slots, &outputs));
// fetch outputs ...
-When compiling, it suffices to link ``libpaddle_fluid.a/.so`` and
-``libpaddle_inference_api.a/.so``.
+When compiling, it suffices to link ``libpaddle_fluid.a/.so``.
Detailed code reference
-----------------------
diff --git a/paddle/fluid/API.spec b/paddle/fluid/API.spec
index bb5f2894c08b5d8941ad8914f6b83280aa053e37..c2694144d708161a3bed214ceca745505656456f 100644
--- a/paddle/fluid/API.spec
+++ b/paddle/fluid/API.spec
@@ -43,6 +43,7 @@ paddle.fluid.Executor.run ArgSpec(args=['self', 'program', 'feed', 'fetch_list',
paddle.fluid.global_scope ArgSpec(args=[], varargs=None, keywords=None, defaults=None)
paddle.fluid.scope_guard ArgSpec(args=[], varargs='args', keywords='kwds', defaults=None)
paddle.fluid.Trainer.__init__ ArgSpec(args=['self', 'train_func', 'optimizer_func', 'param_path', 'place', 'parallel', 'checkpoint_config'], varargs=None, keywords=None, defaults=(None, None, False, None))
+paddle.fluid.Trainer.save_inference_model ArgSpec(args=['self', 'param_path', 'feeded_var_names', 'target_var_indexes'], varargs=None, keywords=None, defaults=None)
paddle.fluid.Trainer.save_params ArgSpec(args=['self', 'param_path'], varargs=None, keywords=None, defaults=None)
paddle.fluid.Trainer.stop ArgSpec(args=['self'], varargs=None, keywords=None, defaults=None)
paddle.fluid.Trainer.test ArgSpec(args=['self', 'reader', 'feed_order'], varargs=None, keywords=None, defaults=None)
@@ -312,7 +313,7 @@ paddle.fluid.layers.iou_similarity ArgSpec(args=[], varargs='args', keywords='kw
paddle.fluid.layers.box_coder ArgSpec(args=[], varargs='args', keywords='kwargs', defaults=None)
paddle.fluid.layers.polygon_box_transform ArgSpec(args=[], varargs='args', keywords='kwargs', defaults=None)
paddle.fluid.layers.accuracy ArgSpec(args=['input', 'label', 'k', 'correct', 'total'], varargs=None, keywords=None, defaults=(1, None, None))
-paddle.fluid.layers.auc ArgSpec(args=['input', 'label', 'curve', 'num_thresholds', 'topk'], varargs=None, keywords=None, defaults=('ROC', 200, 1))
+paddle.fluid.layers.auc ArgSpec(args=['input', 'label', 'curve', 'num_thresholds', 'topk'], varargs=None, keywords=None, defaults=('ROC', 4095, 1))
paddle.fluid.layers.exponential_decay ArgSpec(args=['learning_rate', 'decay_steps', 'decay_rate', 'staircase'], varargs=None, keywords=None, defaults=(False,))
paddle.fluid.layers.natural_exp_decay ArgSpec(args=['learning_rate', 'decay_steps', 'decay_rate', 'staircase'], varargs=None, keywords=None, defaults=(False,))
paddle.fluid.layers.inverse_time_decay ArgSpec(args=['learning_rate', 'decay_steps', 'decay_rate', 'staircase'], varargs=None, keywords=None, defaults=(False,))
@@ -376,7 +377,7 @@ paddle.fluid.optimizer.DecayedAdagradOptimizer.__init__ ArgSpec(args=['self', 'l
paddle.fluid.optimizer.DecayedAdagradOptimizer.minimize ArgSpec(args=['self', 'loss', 'startup_program', 'parameter_list', 'no_grad_set'], varargs=None, keywords=None, defaults=(None, None, None))
paddle.fluid.optimizer.FtrlOptimizer.__init__ ArgSpec(args=['self', 'learning_rate', 'l1', 'l2', 'lr_power'], varargs=None, keywords='kwargs', defaults=(0.0, 0.0, -0.5))
paddle.fluid.optimizer.FtrlOptimizer.minimize ArgSpec(args=['self', 'loss', 'startup_program', 'parameter_list', 'no_grad_set'], varargs=None, keywords=None, defaults=(None, None, None))
-paddle.fluid.optimizer.RMSPropOptimizer.__init__ ArgSpec(args=['self', 'learning_rate', 'rho', 'epsilon', 'momentum'], varargs=None, keywords='kwargs', defaults=(0.95, 1e-06, 0.0))
+paddle.fluid.optimizer.RMSPropOptimizer.__init__ ArgSpec(args=['self', 'learning_rate', 'rho', 'epsilon', 'momentum', 'centered'], varargs=None, keywords='kwargs', defaults=(0.95, 1e-06, 0.0, False))
paddle.fluid.optimizer.RMSPropOptimizer.minimize ArgSpec(args=['self', 'loss', 'startup_program', 'parameter_list', 'no_grad_set'], varargs=None, keywords=None, defaults=(None, None, None))
paddle.fluid.optimizer.AdadeltaOptimizer.__init__ ArgSpec(args=['self', 'learning_rate', 'epsilon', 'rho'], varargs=None, keywords='kwargs', defaults=(1e-06, 0.95))
paddle.fluid.optimizer.AdadeltaOptimizer.minimize ArgSpec(args=['self', 'loss', 'startup_program', 'parameter_list', 'no_grad_set'], varargs=None, keywords=None, defaults=(None, None, None))
diff --git a/paddle/fluid/framework/details/multi_devices_graph_pass.cc b/paddle/fluid/framework/details/multi_devices_graph_pass.cc
index 0bfff745493d069e948e6d277ec2bbfb0673a70b..7a99169849debcbc57d6f197b36c5045b211f3ef 100644
--- a/paddle/fluid/framework/details/multi_devices_graph_pass.cc
+++ b/paddle/fluid/framework/details/multi_devices_graph_pass.cc
@@ -326,7 +326,7 @@ std::unique_ptr<ir::Graph> MultiDevSSAGraphBuilder::ApplyImpl(
ir::Graph &result = *graph;
for (auto &node : nodes) {
- if (node->NodeType() == ir::Node::Type::kVariable && node->Var()) {
+ if (node->IsVar() && node->Var()) {
all_vars_.emplace(node->Name(), node->Var());
}
}
@@ -583,18 +583,6 @@ void MultiDevSSAGraphBuilder::InsertDataBalanceOp(
}
}
-bool MultiDevSSAGraphBuilder::IsParameterGradientOnce(
- const std::string &og,
- std::unordered_set<std::string> *og_has_been_broadcast) const {
- bool is_pg_once =
- grad_names_.count(og) != 0 && og_has_been_broadcast->count(og) == 0;
- if (is_pg_once) {
- // Insert NCCL AllReduce Op
- og_has_been_broadcast->insert(og);
- }
- return is_pg_once;
-}
-
int MultiDevSSAGraphBuilder::GetOpDeviceID(const ir::Graph &graph,
ir::Node *node) const {
if (strategy_.reduce_ != BuildStrategy::ReduceStrategy::kReduce) {
@@ -688,20 +676,6 @@ VarHandle *MultiDevSSAGraphBuilder::CreateReduceOp(ir::Graph *result,
return var;
}
-// Find the first occurence of `prev_op_name` and make current `op` depend
-// on it.
-void MultiDevSSAGraphBuilder::ConnectOp(ir::Graph *result, OpHandleBase *op,
- const std::string &prev_op_name) const {
- for (auto &prev_op : result->Get<GraphOps>(kGraphOps)) {
- if (prev_op->Name() == prev_op_name) {
- auto *dep_var = new DummyVarHandle(result->CreateControlDepVar());
- prev_op->AddOutput(dep_var);
- result->Get<GraphDepVars>(kGraphDepVars).emplace(dep_var);
- op->AddInput(dep_var);
- }
- }
-}
-
void MultiDevSSAGraphBuilder::CreateDistTrainOp(ir::Graph *result,
ir::Node *node) const {
int op_dev_id = -1;
diff --git a/paddle/fluid/framework/details/multi_devices_graph_pass.h b/paddle/fluid/framework/details/multi_devices_graph_pass.h
index 7a6f238f9cf7af18cb10ea271e453fec1902c833..ac6d9c5a64cfde60f75c76dae0a30cc7d735e996 100644
--- a/paddle/fluid/framework/details/multi_devices_graph_pass.h
+++ b/paddle/fluid/framework/details/multi_devices_graph_pass.h
@@ -69,9 +69,6 @@ class MultiDevSSAGraphBuilder : public ir::Pass {
std::vector<std::string> FindDistTrainRecvVars(
const std::vector<ir::Node *> &nodes) const;
- void ConnectOp(ir::Graph *result, OpHandleBase *op,
- const std::string &prev_op_name) const;
-
void CreateComputationalOps(ir::Graph *result, ir::Node *node,
size_t num_places) const;
@@ -83,10 +80,6 @@ class MultiDevSSAGraphBuilder : public ir::Pass {
void CreateComputationalOp(ir::Graph *result, ir::Node *node,
int dev_id) const;
- bool IsParameterGradientOnce(
- const std::string &og,
- std::unordered_set<std::string> *og_has_been_broadcast) const;
-
int GetOpDeviceID(const ir::Graph &graph, ir::Node *node) const;
void InsertAllReduceOp(ir::Graph *result, const std::string &og) const;
diff --git a/paddle/fluid/framework/ir/CMakeLists.txt b/paddle/fluid/framework/ir/CMakeLists.txt
index bfc649017f19d67660bd11d590134cf56772bb27..f5235f70ad79616801110644999d511eeda33a32 100644
--- a/paddle/fluid/framework/ir/CMakeLists.txt
+++ b/paddle/fluid/framework/ir/CMakeLists.txt
@@ -1,20 +1,35 @@
+set(pass_file ${PADDLE_BINARY_DIR}/paddle/fluid/inference/api/paddle_inference_pass.h)
+file(WRITE ${pass_file} "// Generated by the paddle/fluid/framework/ir/CMakeLists.txt. DO NOT EDIT!\n\n")
+file(APPEND ${pass_file} "\#include \"paddle/fluid/framework/ir/pass.h\"\n")
+function(pass_library TARGET)
+ set(options "")
+ set(oneValueArgs "")
+ set(multiValueArgs SRCS DEPS)
+ cmake_parse_arguments(op_library "${options}" "${oneValueArgs}" "${multiValueArgs}" ${ARGN})
+ cc_library(${TARGET} SRCS ${TARGET}.cc DEPS graph_pattern_detector pass)
+ file(APPEND ${pass_file} "USE_PASS(${TARGET});\n")
+ set(PASS_LIBRARY ${TARGET} ${PASS_LIBRARY} PARENT_SCOPE)
+endfunction()
+
cc_library(node SRCS node.cc DEPS proto_desc)
cc_library(graph SRCS graph.cc DEPS node)
cc_library(graph_helper SRCS graph_helper.cc DEPS graph)
cc_library(pass SRCS pass.cc DEPS graph node graph_helper)
-cc_library(graph_viz_pass SRCS graph_viz_pass.cc DEPS graph pass graph_helper)
-cc_library(graph_to_program_pass SRCS graph_to_program_pass.cc DEPS graph pass graph_helper)
cc_library(graph_traits SRCS graph_traits.cc DEPS graph)
cc_library(graph_pattern_detector SRCS graph_pattern_detector.cc DEPS graph graph_helper graph_traits)
-cc_library(fc_fuse_pass SRCS fc_fuse_pass.cc DEPS graph graph_pattern_detector)
-cc_library(attention_lstm_fuse_pass SRCS attention_lstm_fuse_pass.cc DEPS graph graph_pattern_detector)
-cc_library(infer_clean_graph_pass SRCS infer_clean_graph_pass.cc DEPS graph pass)
-cc_library(fc_lstm_fuse_pass SRCS fc_lstm_fuse_pass.cc DEPS graph graph_pattern_detector)
-cc_library(seq_concat_fc_fuse_pass SRCS seq_concat_fc_fuse_pass.cc DEPS graph graph_pattern_detector)
+
+pass_library(graph_to_program_pass)
+pass_library(graph_viz_pass)
+pass_library(fc_fuse_pass)
+pass_library(attention_lstm_fuse_pass)
+pass_library(infer_clean_graph_pass)
+pass_library(fc_lstm_fuse_pass)
+pass_library(seq_concat_fc_fuse_pass)
+set(GLOB_PASS_LIB ${PASS_LIBRARY} CACHE INTERNAL "Global PASS library")
cc_test(pass_test SRCS pass_test.cc DEPS graph pass graph_helper)
cc_test(graph_test SRCS graph_test.cc DEPS graph graph_helper op_registry)
cc_test(graph_helper_test SRCS graph_helper_test.cc DEPS graph graph_helper op_registry)
cc_test(graph_to_program_pass_test SRCS graph_to_program_pass_test.cc DEPS graph_to_program_pass)
cc_test(test_graph_pattern_detector SRCS graph_pattern_detector_tester.cc DEPS graph_pattern_detector)
-cc_test(test_fc_fuse_pass SRCS fc_fuse_pass_tester.cc DEPS fc_fuse_pass graph_pattern_detector graph pass graph_traits framework_proto)
+cc_test(test_fc_fuse_pass SRCS fc_fuse_pass_tester.cc DEPS fc_fuse_pass framework_proto)
diff --git a/paddle/fluid/framework/ir/attention_lstm_fuse_pass.cc b/paddle/fluid/framework/ir/attention_lstm_fuse_pass.cc
index 0278ade6763ec614701674691797d766878a378e..bb52d7e498e55c02ddc2cd6d07ccccd51ce4edc5 100644
--- a/paddle/fluid/framework/ir/attention_lstm_fuse_pass.cc
+++ b/paddle/fluid/framework/ir/attention_lstm_fuse_pass.cc
@@ -13,13 +13,10 @@
// limitations under the License.
#include "paddle/fluid/framework/ir/attention_lstm_fuse_pass.h"
-
#include <string>
-
#include "paddle/fluid/framework/ir/graph_pattern_detector.h"
#include "paddle/fluid/framework/ir/graph_viz_pass.h"
#include "paddle/fluid/framework/lod_tensor.h"
-#include "paddle/fluid/inference/api/helper.h"
namespace paddle {
namespace framework {
@@ -99,17 +96,13 @@ void FindWhileOp(Graph* graph) {
auto* cell_init = graph->RetriveNode(6);
auto* hidden_init = graph->RetriveNode(8);
-#define LINK_TO(node0, node1) \
- node0->outputs.push_back(node1); \
- node1->inputs.push_back(node0);
-
auto* lstm_op = graph->CreateOpNode(&op_desc);
PrepareParameters(graph, param);
- LINK_TO(X, lstm_op);
- LINK_TO(cell_init, lstm_op);
- LINK_TO(hidden_init, lstm_op);
- LINK_TO(lstm_op, LSTMOUT);
+ IR_NODE_LINK_TO(X, lstm_op);
+ IR_NODE_LINK_TO(cell_init, lstm_op);
+ IR_NODE_LINK_TO(hidden_init, lstm_op);
+ IR_NODE_LINK_TO(lstm_op, LSTMOUT);
GraphSafeRemoveNodes(graph, marked_nodes);
}
diff --git a/paddle/fluid/framework/ir/fc_fuse_pass.cc b/paddle/fluid/framework/ir/fc_fuse_pass.cc
index 513742bab69d465aac1bfb7bcef2fe89108c14a0..5a4ebd6f3de555acccd72c61bd377ffd8ce69780 100644
--- a/paddle/fluid/framework/ir/fc_fuse_pass.cc
+++ b/paddle/fluid/framework/ir/fc_fuse_pass.cc
@@ -21,74 +21,26 @@ namespace paddle {
namespace framework {
namespace ir {
-bool VarOutLinksToOp(Node* node, const std::string& op_type) {
- for (auto* out : node->outputs) {
- if (out->IsOp() && out->Op()->Type() == op_type) {
- return true;
- }
- }
- return false;
-}
-
-void BuildFCPattern(PDPattern* pattern) {
- // Create Operators
- auto* mul_op = pattern->NewNode("mul")->assert_is_op("mul");
- auto* elementwise_add_op =
- pattern->NewNode("elementwise_add")->assert_is_op("elementwise_add");
- // Create variables
- // w
- auto* mul_weight_var = pattern->NewNode("mul_weight")
- ->AsInput()
- ->assert_is_op_nth_input("mul", "Y", 0);
- // x
- auto* mul_tmp_var = pattern->NewNode("mul_tmp_var")
- ->AsInput()
- ->assert_is_op_nth_input("mul", "X", 0);
- // intermediate variable, will be removed in the IR after fuse.
- auto* mul_out_var = pattern->NewNode("mul_out")
- ->AsIntermediate()
- ->assert_is_only_output_of_op("mul")
- ->assert_is_op_input("elementwise_add");
- // bias
- auto* elementwise_add_tmp_var = pattern->NewNode("elementwise_add_tmpvar")
- ->assert_is_op_input("elementwise_add")
- ->AsInput();
- // output
- auto* elementwise_add_out_var = pattern->NewNode("elementwise_add_out")
- ->AsOutput()
- ->assert_is_op_output("elementwise_add");
-
- mul_op->LinksFrom({mul_weight_var, mul_tmp_var}).LinksTo({mul_out_var});
- elementwise_add_op->LinksFrom({mul_out_var, elementwise_add_tmp_var})
- .LinksTo({elementwise_add_out_var});
-}
-
-// Replace the node `from` in the links to `to`
-bool LinksReplace(std::vector<Node*>* links, Node* from, Node* to) {
- for (auto*& n : *links) {
- if (n == from) {
- n = to;
- return true;
- }
- }
- return false;
-}
-
std::unique_ptr<ir::Graph> FCFusePass::ApplyImpl(
std::unique_ptr<ir::Graph> graph) const {
PADDLE_ENFORCE(graph.get());
- FusePassBase::Init("fc", graph.get());
+ FusePassBase::Init("fc_fuse", graph.get());
std::unordered_set<Node*> nodes2delete;
GraphPatternDetector gpd;
- BuildFCPattern(gpd.mutable_pattern());
-
-#define GET_NODE(id) \
- PADDLE_ENFORCE(subgraph.count(gpd.pattern().RetrieveNode(#id)), \
- "pattern has no Node called %s", #id); \
- auto* id = subgraph.at(gpd.pattern().RetrieveNode(#id)); \
- PADDLE_ENFORCE_NOT_NULL(id, "subgraph has no node %s", #id);
+ // BuildFCPattern(gpd.mutable_pattern());
+ auto* x = gpd.mutable_pattern()
+ ->NewNode("fc_fuse/x")
+ ->AsInput()
+ ->assert_is_op_input("mul", "X");
+ patterns::FC(gpd.mutable_pattern(), "fc_fuse", x, true /*with bias*/);
+
+#define GET_NODE(id) \
+ PADDLE_ENFORCE(subgraph.count(gpd.pattern().RetrieveNode("fc_fuse/" #id)), \
+ "pattern has no Node called %s", #id); \
+ auto* id = subgraph.at(gpd.pattern().RetrieveNode("fc_fuse/" #id)); \
+ PADDLE_ENFORCE_NOT_NULL(id, "subgraph has no node %s", "fc_fuse/" #id);
int found_fc_count = 0;
auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
@@ -98,43 +50,33 @@ std::unique_ptr<ir::Graph> FCFusePass::ApplyImpl(
// scenario.
// FC's fusion is simple, just op fuse, no need to process the
// parameters.
- GET_NODE(mul_tmp_var); // x
- GET_NODE(mul_weight); // Y
- GET_NODE(elementwise_add_tmpvar); // bias
- GET_NODE(elementwise_add_out); // Out
- GET_NODE(mul); // MUL op
- GET_NODE(elementwise_add); // ELEMENT_ADD op
- GET_NODE(mul_out); // tmp
+ GET_NODE(x); // x
+ GET_NODE(w); // Y
+ GET_NODE(fc_bias); // bias
+ GET_NODE(fc_out); // Out
+ GET_NODE(mul); // MUL op
+ GET_NODE(elementwise_add); // ELEMENT_ADD op
+ GET_NODE(mul_out); // tmp
#undef GET_NODE
// Create an FC Node.
OpDesc desc;
- std::string fc_x_in = mul_tmp_var->Name();
- std::string fc_Y_in = mul_weight->Name();
- std::string fc_bias_in = elementwise_add_tmpvar->Name();
- std::string fc_out = elementwise_add_out->Name();
+ std::string fc_x_in = x->Name();
+ std::string fc_Y_in = w->Name();
+ std::string fc_bias_in = fc_bias->Name();
+ std::string fc_out_out = fc_out->Name();
desc.SetInput("Input", std::vector({fc_x_in}));
desc.SetInput("W", std::vector({fc_Y_in}));
desc.SetInput("Bias", std::vector({fc_bias_in}));
- desc.SetOutput("Out", std::vector({fc_out}));
+ desc.SetOutput("Out", std::vector({fc_out_out}));
desc.SetType("fc");
auto fc_node = g->CreateOpNode(&desc); // OpDesc will be copied.
- fc_node->inputs =
- std::vector<Node*>({mul_tmp_var, mul_weight, elementwise_add_tmpvar});
- fc_node->outputs.push_back(elementwise_add_out);
-
- // Update link relatons
- PADDLE_ENFORCE(LinksReplace(&mul_tmp_var->outputs, mul, fc_node));
- PADDLE_ENFORCE(LinksReplace(&mul_weight->outputs, mul, fc_node));
- PADDLE_ENFORCE(LinksReplace(&elementwise_add_tmpvar->outputs,
- elementwise_add, fc_node));
- PADDLE_ENFORCE(
- LinksReplace(&elementwise_add_out->inputs, elementwise_add, fc_node));
+ GraphSafeRemoveNodes(graph.get(), {mul, elementwise_add, mul_out});
- // Drop old nodes
- graph->RemoveNode(mul);
- graph->RemoveNode(elementwise_add);
- graph->RemoveNode(mul_out); // tmp variable
+ IR_NODE_LINK_TO(x, fc_node);
+ IR_NODE_LINK_TO(w, fc_node);
+ IR_NODE_LINK_TO(fc_bias, fc_node);
+ IR_NODE_LINK_TO(fc_node, fc_out);
found_fc_count++;
};
diff --git a/paddle/fluid/framework/ir/fc_lstm_fuse_pass.cc b/paddle/fluid/framework/ir/fc_lstm_fuse_pass.cc
index c404a6c44ccea8287ddfad976889a9f80cf6bad9..0d69dfa79aa26940f8f56f84b35ffed34f29f703 100644
--- a/paddle/fluid/framework/ir/fc_lstm_fuse_pass.cc
+++ b/paddle/fluid/framework/ir/fc_lstm_fuse_pass.cc
@@ -11,7 +11,6 @@
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
-
#include "paddle/fluid/framework/ir/fc_lstm_fuse_pass.h"
#include <string>
#include "paddle/fluid/framework/lod_tensor.h"
@@ -87,15 +86,24 @@ int BuildFusion(Graph* graph, const std::string& name_scope, Scope* scope,
}
op_desc.SetInput("Bias", {new_bias_var});
}
-
#undef GET_NODE
+ // Create temp variables.
+ scope->Var(name_scope + "/BatchedInput.new")
+ ->GetMutable();
+ scope->Var(name_scope + "/BatchCellPreAct.new")
+ ->GetMutable();
+ scope->Var(name_scope + "/BatchedGate.new")
+ ->GetMutable();
+
op_desc.SetInput("H0", {});
op_desc.SetInput("C0", {});
op_desc.SetOutput("Hidden", {hidden_n->Name()});
op_desc.SetOutput("Cell", {cell_n->Name()});
op_desc.SetOutput("XX", {xx_n->Name()});
- op_desc.SetOutput("BatchedInput", {"blstm_0.tmp_2"});
+ op_desc.SetOutput("BatchedGate", {name_scope + "/BatchedGate.new"});
+ op_desc.SetOutput("BatchCellPreAct", {name_scope + "/BatchCellPreAct.new"});
+ op_desc.SetOutput("BatchedInput", {name_scope + "/BatchedInput.new"});
op_desc.SetAttr("is_reverse", lstm_n->Op()->GetAttr("is_reverse"));
op_desc.SetAttr("use_peepholes", lstm_n->Op()->GetAttr("use_peepholes"));
// TODO(TJ): get from attr
@@ -121,22 +129,18 @@ int BuildFusion(Graph* graph, const std::string& name_scope, Scope* scope,
#undef TMP_NEW
#undef TMP_NAME
-#define LINK_TO(a, b) \
- a->outputs.push_back(b); \
- b->inputs.push_back(a);
- LINK_TO(input_n, op);
- LINK_TO(weight_x_n, op);
- LINK_TO(weight_h_n, op);
- LINK_TO(bias_n, op);
- LINK_TO(op, hidden_n);
-#undef LINK_TO
+ IR_NODE_LINK_TO(input_n, op);
+ IR_NODE_LINK_TO(weight_x_n, op);
+ IR_NODE_LINK_TO(weight_h_n, op);
+ IR_NODE_LINK_TO(bias_n, op);
+ IR_NODE_LINK_TO(op, hidden_n);
return op;
};
int fusion_count{0};
- auto fc_no_bias_handler = [&](
- const GraphPatternDetector::subgraph_t& subgraph, Graph* g) {
+ auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
+ Graph* g) {
#define GET_NODE(name__) \
std::string name__##key = name_scope + "/" + #name__; \
auto* name__##n = pattern->RetrieveNode(name__##key); \
@@ -157,21 +161,24 @@ int BuildFusion(Graph* graph, const std::string& name_scope, Scope* scope,
if (with_fc_bias) {
GET_NODE(fc_bias);
+ GET_NODE(elementwise_add);
lstm_creator(lstm, x, w, Weight, Bias, Hidden, Cell, fc_out, fc_bias);
+ // Remove unneeded nodes.
+ std::unordered_set<const Node*> marked_nodes(
+ {mul_n, lstm_n, elementwise_add_n});
+ GraphSafeRemoveNodes(graph, marked_nodes);
} else {
lstm_creator(lstm, x, w, Weight, Bias, Hidden, Cell, fc_out, -1);
+ // Remove unneeded nodes.
+ std::unordered_set<const Node*> marked_nodes({mul_n, lstm_n});
+ GraphSafeRemoveNodes(graph, marked_nodes);
}
#undef GET_NODE
- // Remove unneeded nodes.
- std::unordered_set<const Node*> marked_nodes({mul_n, lstm_n});
-
- GraphSafeRemoveNodes(graph, marked_nodes);
-
++fusion_count;
};
- gpd(graph, fc_no_bias_handler);
+ gpd(graph, handler);
return fusion_count;
}
diff --git a/paddle/fluid/framework/ir/fc_lstm_fuse_pass.h b/paddle/fluid/framework/ir/fc_lstm_fuse_pass.h
index 5a6687872eb3ab4a032227fda9ff0e7f5254670b..3ee32c63a46fcc34bdccd1e14d4bbaf9668c49e9 100644
--- a/paddle/fluid/framework/ir/fc_lstm_fuse_pass.h
+++ b/paddle/fluid/framework/ir/fc_lstm_fuse_pass.h
@@ -12,6 +12,8 @@
// See the License for the specific language governing permissions and
// limitations under the License.
+#pragma once
+
#include "paddle/fluid/framework/ir/fuse_pass_base.h"
#include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/ir/graph_pattern_detector.h"
diff --git a/paddle/fluid/framework/ir/graph_pattern_detector.cc b/paddle/fluid/framework/ir/graph_pattern_detector.cc
index f651ab635eadc9f248964e91dceebf3aa9c42926..731b89423354532f684e19305dfa87e8eb75d4b1 100644
--- a/paddle/fluid/framework/ir/graph_pattern_detector.cc
+++ b/paddle/fluid/framework/ir/graph_pattern_detector.cc
@@ -73,7 +73,6 @@ void PDPattern::AddEdge(PDNode* a, PDNode* b) {
void GraphPatternDetector::operator()(Graph* graph,
GraphPatternDetector::handle_t handler) {
if (!MarkPDNodesInGraph(*graph)) {
- LOG(INFO) << "Mark failed";
return;
}
@@ -86,7 +85,7 @@ void GraphPatternDetector::operator()(Graph* graph,
LOG(INFO) << "detect " << subgraphs.size() << " subgraph matches the pattern";
int id = 0;
for (auto& g : subgraphs) {
- LOG(INFO) << "optimizing #" << id++ << " subgraph";
+ VLOG(3) << "optimizing #" << id++ << " subgraph";
handler(g, graph);
}
}
@@ -111,6 +110,11 @@ bool GraphPatternDetector::MarkPDNodesInGraph(const ir::Graph& graph) {
return false;
}
}
+ for (auto& item : pdnodes2nodes_) {
+ for (auto& n : item.second) {
+ GetMarkedNodes(const_cast<Graph*>(&graph)).insert(n);
+ }
+ }
VLOG(3) << pdnodes2nodes_.size() << " nodes marked";
return !pdnodes2nodes_.empty();
@@ -278,7 +282,7 @@ void GraphPatternDetector::RemoveOverlappedMatch(
for (const auto& subgraph : *subgraphs) {
bool valid = true;
for (auto& item : subgraph) {
- if (node_set.count(item.second)) {
+ if (item.first->IsIntermediate() && node_set.count(item.second)) {
valid = false;
break;
}
@@ -334,22 +338,22 @@ PDNode& PDNode::LinksFrom(const std::vector<PDNode*>& others) {
}
PDNode* PDNode::assert_is_op() {
- asserts_.emplace_back([this](Node* x) { return x && x->IsOp(); });
+ asserts_.emplace_back([](Node* x) { return x && x->IsOp(); });
return this;
}
PDNode* PDNode::assert_is_op(const std::string& op_type) {
- asserts_.emplace_back([this, op_type](Node* x) {
+ asserts_.emplace_back([op_type](Node* x) {
return x && x->IsOp() && x->Op()->Type() == op_type;
});
return this;
}
PDNode* PDNode::assert_is_var() {
- asserts_.emplace_back([this](Node* x) { return x && x->IsVar(); });
+ asserts_.emplace_back([](Node* x) { return x && x->IsVar(); });
return this;
}
PDNode* PDNode::assert_var_not_persistable() {
assert_is_var();
- asserts_.emplace_back([this](Node* x) { return !x->Var()->Persistable(); });
+ asserts_.emplace_back([](Node* x) { return !x->Var()->Persistable(); });
return this;
}
PDNode* PDNode::assert_is_persistable_var() {
@@ -491,14 +495,16 @@ void GraphSafeRemoveNodes(Graph* graph,
for (auto it = node->inputs.begin(); it != node->inputs.end();) {
if (nodes.count(*it)) {
it = const_cast<Node*>(node)->inputs.erase(it);
- } else
+ } else {
it++;
+ }
}
for (auto it = node->outputs.begin(); it != node->outputs.end();) {
if (nodes.count(*it)) {
it = const_cast<Node*>(node)->outputs.erase(it);
- } else
+ } else {
it++;
+ }
}
}
}
diff --git a/paddle/fluid/framework/ir/graph_pattern_detector.h b/paddle/fluid/framework/ir/graph_pattern_detector.h
index 024ce8ce55616cc5e0eaced4a27a6e1fb004af2c..eacea1750f6f1e86a8fe79637c3bd757a7275398 100644
--- a/paddle/fluid/framework/ir/graph_pattern_detector.h
+++ b/paddle/fluid/framework/ir/graph_pattern_detector.h
@@ -19,6 +19,9 @@
#endif
#include
+#include
+#include
+#include
#include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/ir/node.h"
#include "paddle/fluid/inference/analysis/dot.h"
@@ -245,6 +248,8 @@ class GraphPatternDetector {
void UniquePatterns(std::vector<subgraph_t>* subgraphs);
// Remove overlapped match subgraphs, when overlapped, keep the previous one.
+ // The intermediate PDNodes will be removed, so they can't be shared by
+ // multiple patterns.
void RemoveOverlappedMatch(std::vector<subgraph_t>* subgraphs);
// Validate whether the intermediate nodes are linked by external nodes.
@@ -295,6 +300,10 @@ PDNode* LSTM(PDPattern* pattern, const std::string& name_scope, PDNode* x);
} // namespace patterns
+#define IR_NODE_LINK_TO(a, b) \
+ a->outputs.push_back(b); \
+ b->inputs.push_back(a);
+
} // namespace ir
} // namespace framework
} // namespace paddle
diff --git a/paddle/fluid/framework/ir/graph_pattern_detector_tester.cc b/paddle/fluid/framework/ir/graph_pattern_detector_tester.cc
index 7e5c86b033a7c69a306491cf4bf8d099018c5f19..6c466fb21fb46e09961dc874e9e39655f83d17c6 100644
--- a/paddle/fluid/framework/ir/graph_pattern_detector_tester.cc
+++ b/paddle/fluid/framework/ir/graph_pattern_detector_tester.cc
@@ -140,8 +140,9 @@ TEST(GraphPatternDetecter, MultiSubgraph) {
return node->IsOp() && (node->Name() == "op2" || node->Name() == "op3");
},
"OP0");
- auto* any_var = x.mutable_pattern()->NewNode(
- [](Node* node) { return node->IsVar(); }, "VAR");
+ auto* any_var = x.mutable_pattern()
+ ->NewNode([](Node* node) { return node->IsVar(); }, "VAR")
+ ->AsIntermediate();
auto* any_op1 = x.mutable_pattern()->NewNode(
[](Node* node) { return node->IsOp(); }, "OP1");
diff --git a/paddle/fluid/framework/ir/graph_viz_pass.cc b/paddle/fluid/framework/ir/graph_viz_pass.cc
index 4c7ffe69e933de3d52c8f762a1eeb73de17e0561..31ed98db72c8fd4af8c970861d386687962001ce 100644
--- a/paddle/fluid/framework/ir/graph_viz_pass.cc
+++ b/paddle/fluid/framework/ir/graph_viz_pass.cc
@@ -50,20 +50,37 @@ std::unique_ptr<ir::Graph> GraphVizPass::ApplyImpl(
Dot dot;
- std::vector op_attrs({Dot::Attr("style", "filled"),
- Dot::Attr("shape", "box"),
- Dot::Attr("fillcolor", "red")});
- std::vector var_attrs({Dot::Attr("style", "filled,rounded"),
- // Dot::Attr("shape", "diamond"),
- Dot::Attr("fillcolor", "yellow")});
-
- std::vector marked_op_attrs({Dot::Attr("style", "filled"),
- Dot::Attr("shape", "box"),
- Dot::Attr("fillcolor", "lightgray")});
- std::vector marked_var_attrs(
- {Dot::Attr("style", "filled,rounded"),
- // Dot::Attr("shape", "diamond"),
- Dot::Attr("fillcolor", "lightgray")});
+ const std::vector<Dot::Attr> op_attrs({
+ Dot::Attr("style", "rounded,filled,bold"), //
+ Dot::Attr("shape", "box"), //
+ Dot::Attr("color", "#303A3A"), //
+ Dot::Attr("fontcolor", "#ffffff"), //
+ Dot::Attr("width", "1.3"), //
+ Dot::Attr("height", "0.84"), //
+ Dot::Attr("fontname", "Arial"), //
+ });
+ const std::vector<Dot::Attr> arg_attrs({
+ Dot::Attr("shape", "box"), //
+ Dot::Attr("style", "rounded,filled,bold"), //
+ Dot::Attr("fontname", "Arial"), //
+ Dot::Attr("fillcolor", "#999999"), //
+ Dot::Attr("color", "#dddddd"), //
+ });
+
+ const std::vector<Dot::Attr> param_attrs({
+ Dot::Attr("shape", "box"), //
+ Dot::Attr("style", "rounded,filled,bold"), //
+ Dot::Attr("fontname", "Arial"), //
+ Dot::Attr("color", "#148b97"), //
+ Dot::Attr("fontcolor", "#ffffff"), //
+ });
+
+ const std::vector<Dot::Attr> marked_op_attrs(
+ {Dot::Attr("style", "rounded,filled,bold"), Dot::Attr("shape", "box"),
+ Dot::Attr("fillcolor", "yellow")});
+ const std::vector<Dot::Attr> marked_var_attrs(
+ {Dot::Attr("style", "filled,rounded"), Dot::Attr("shape", "box"),
+ Dot::Attr("fillcolor", "yellow")});
auto marked_nodes = ConsumeMarkedNodes(graph.get());
// Create nodes
@@ -74,9 +91,17 @@ std::unique_ptr<ir::Graph> GraphVizPass::ApplyImpl(
marked_nodes.count(n) ? marked_op_attrs : op_attrs;
dot.AddNode(node_id, attr, node_id);
} else if (n->IsVar()) {
- decltype(op_attrs) attr =
- marked_nodes.count(n) ? marked_var_attrs : var_attrs;
- dot.AddNode(node_id, attr, node_id);
+ decltype(op_attrs)* attr;
+ if (marked_nodes.count(n)) {
+ attr = &marked_var_attrs;
+ } else if (const_cast<Node*>(n)->Var() &&
+ const_cast<Node*>(n)->Var()->Persistable()) {
+ attr = &param_attrs;
+ } else {
+ attr = &arg_attrs;
+ }
+
+ dot.AddNode(node_id, *attr, node_id);
}
node2dot[n] = node_id;
}
diff --git a/paddle/fluid/framework/ir/infer_clean_graph_pass.cc b/paddle/fluid/framework/ir/infer_clean_graph_pass.cc
index f885567da1965b997b2063e06c839af95b43e1e1..7713ed1eab88ee4fa16d52e7425075ae66f721a3 100644
--- a/paddle/fluid/framework/ir/infer_clean_graph_pass.cc
+++ b/paddle/fluid/framework/ir/infer_clean_graph_pass.cc
@@ -13,42 +13,41 @@
// limitations under the License.
#include
+#include "paddle/fluid/framework/ir/fuse_pass_base.h"
#include "paddle/fluid/framework/ir/graph.h"
-#include "paddle/fluid/framework/ir/pass.h"
+#include "paddle/fluid/framework/ir/graph_pattern_detector.h"
namespace paddle {
namespace framework {
namespace ir {
-class InferCleanGraphPass : public Pass {
+class InferCleanGraphPass : public FusePassBase {
public:
virtual ~InferCleanGraphPass() {}
protected:
std::unique_ptr<ir::Graph> ApplyImpl(std::unique_ptr<ir::Graph> graph) const {
+ FusePassBase::Init("original_graph", graph.get());
PADDLE_ENFORCE(graph.get());
auto is_valid_node = [](Node* x) {
return x && IsControlDepVar(*x) && x->IsVar() && !x->Var();
};
- std::unordered_set<Node*> invalid_nodes;
+ std::unordered_set<const Node*> invalid_nodes;
+ int valid_op = 0;
for (auto* node : graph->Nodes()) {
if (is_valid_node(node)) {
invalid_nodes.insert(node);
+ } else if (node->IsOp()) {
+ // Collect all the operators to help tracking number of operators.
+ ++valid_op;
}
}
- // remove nodes from the graph.
- for (auto* node : invalid_nodes) {
- graph->RemoveNode(node);
- }
+ GraphSafeRemoveNodes(graph.get(), invalid_nodes);
- // clean edges.
- for (auto* node : graph->Nodes()) {
- CleanEdges(&node->inputs, invalid_nodes);
- CleanEdges(&node->outputs, invalid_nodes);
- }
+ AddStatis(valid_op);
return graph;
}
diff --git a/paddle/fluid/framework/ir/seq_concat_fc_fuse_pass.cc b/paddle/fluid/framework/ir/seq_concat_fc_fuse_pass.cc
index a776a898a5ee13b4dde12460dce71433268fb9d4..e1a441d09aaa3647c4b2a582210a2c7e2b64e0da 100644
--- a/paddle/fluid/framework/ir/seq_concat_fc_fuse_pass.cc
+++ b/paddle/fluid/framework/ir/seq_concat_fc_fuse_pass.cc
@@ -219,16 +219,13 @@ std::unique_ptr<ir::Graph> SeqConcatFcFusePass::ApplyImpl(
op_desc.SetAttr("fc_activation", act->Op()->Type());
auto* op_node = graph->CreateOpNode(&op_desc);
-// Add links
-#define NODE_LINKS(a, b) \
- a->outputs.push_back(b); \
- b->inputs.push_back(a);
- NODE_LINKS(fc_w, op_node);
- NODE_LINKS(fc_bias, op_node);
- NODE_LINKS(concat_in0, op_node);
- NODE_LINKS(sequence_expand0_in, op_node);
- NODE_LINKS(sequence_expand1_in, op_node);
- NODE_LINKS(op_node, fc_out);
+ // Add links
+ IR_NODE_LINK_TO(fc_w, op_node);
+ IR_NODE_LINK_TO(fc_bias, op_node);
+ IR_NODE_LINK_TO(concat_in0, op_node);
+ IR_NODE_LINK_TO(sequence_expand0_in, op_node);
+ IR_NODE_LINK_TO(sequence_expand1_in, op_node);
+ IR_NODE_LINK_TO(op_node, fc_out);
// Clean nodes.
std::unordered_set<const Node*> marked_nodes;
@@ -241,7 +238,6 @@ std::unique_ptr<ir::Graph> SeqConcatFcFusePass::ApplyImpl(
marked_nodes.erase(sequence_expand0_in);
marked_nodes.erase(sequence_expand1_in);
marked_nodes.erase(fc_out);
-
GraphSafeRemoveNodes(graph, marked_nodes);
});
diff --git a/paddle/fluid/inference/CMakeLists.txt b/paddle/fluid/inference/CMakeLists.txt
index 86392078b356df774fbc47aed9214e9f10fe33be..2006e3b24f71d0ae32b4e2ae34f1a1e4d3a82f91 100644
--- a/paddle/fluid/inference/CMakeLists.txt
+++ b/paddle/fluid/inference/CMakeLists.txt
@@ -10,7 +10,7 @@ set(FLUID_CORE_MODULES proto_desc memory lod_tensor executor)
# TODO(panyx0718): Should this be called paddle_fluid_inference_api_internal?
cc_library(paddle_fluid_api
SRCS io.cc
- DEPS ${FLUID_CORE_MODULES} ${GLOB_OP_LIB} graph_to_program_pass)
+ DEPS ${FLUID_CORE_MODULES} ${GLOB_OP_LIB})
get_property(fluid_modules GLOBAL PROPERTY FLUID_MODULES)
@@ -22,7 +22,7 @@ cc_library(paddle_fluid_origin DEPS ${fluid_modules} paddle_fluid_api)
#endif()
# Create static library
-cc_library(paddle_fluid DEPS ${fluid_modules} paddle_fluid_api paddle_inference_api)
+cc_library(paddle_fluid DEPS ${fluid_modules} paddle_fluid_api paddle_inference_api analysis_predictor)
if(NOT APPLE)
# TODO(liuyiqun): Temporarily disable the link flag because it is not supported on Mac.
set(LINK_FLAGS "-Wl,--retain-symbols-file ${CMAKE_CURRENT_SOURCE_DIR}/paddle_fluid.sym")
@@ -32,6 +32,7 @@ endif()
# Create shared library
cc_library(paddle_fluid_shared SHARED
SRCS io.cc ${CMAKE_CURRENT_SOURCE_DIR}/api/api.cc ${CMAKE_CURRENT_SOURCE_DIR}/api/api_impl.cc
+ ${CMAKE_CURRENT_SOURCE_DIR}/api/analysis_predictor.cc
DEPS ${fluid_modules} paddle_fluid_api)
set_target_properties(paddle_fluid_shared PROPERTIES OUTPUT_NAME paddle_fluid)
diff --git a/paddle/fluid/inference/analysis/CMakeLists.txt b/paddle/fluid/inference/analysis/CMakeLists.txt
index cc0dd0d492d42e9552c9ce081e268330599104f0..226645058e85da55b47e26efe5a199f50aef3847 100644
--- a/paddle/fluid/inference/analysis/CMakeLists.txt
+++ b/paddle/fluid/inference/analysis/CMakeLists.txt
@@ -33,7 +33,7 @@ function (inference_analysis_test TARGET)
endif()
cc_test(${TARGET}
SRCS "${analysis_test_SRCS}"
- DEPS analysis graph fc_fuse_pass graph_viz_pass infer_clean_graph_pass graph_pattern_detector pass ${analysis_test_EXTRA_DEPS}
+ DEPS analysis pass ${GLOB_PASS_LIB} ${analysis_test_EXTRA_DEPS}
ARGS --inference_model_dir=${PYTHON_TESTS_DIR}/book/word2vec.inference.model ${mem_opt} ${analysis_test_ARGS})
set_tests_properties(${TARGET} PROPERTIES DEPENDS test_word2vec)
endif(WITH_TESTING)
@@ -56,25 +56,13 @@ if (NOT EXISTS ${DITU_INSTALL_DIR} AND WITH_TESTING)
endif()
inference_analysis_test(test_analyzer SRCS analyzer_tester.cc
- EXTRA_DEPS paddle_inference_api paddle_fluid_api ir_pass_manager analysis
- analysis_predictor
- # ir
- fc_fuse_pass
- fc_lstm_fuse_pass
- seq_concat_fc_fuse_pass
- graph_viz_pass
- infer_clean_graph_pass
- graph_pattern_detector
- infer_clean_graph_pass
- attention_lstm_fuse_pass
- paddle_inference_api
- pass
+ EXTRA_DEPS paddle_inference_api paddle_fluid_api ir_pass_manager analysis_predictor
ARGS --infer_ditu_rnn_model=${DITU_INSTALL_DIR}/model
- --infer_ditu_rnn_data=${DITU_INSTALL_DIR}/data.txt)
+ --infer_ditu_rnn_data=${DITU_INSTALL_DIR}/data.txt)
inference_analysis_test(test_data_flow_graph SRCS data_flow_graph_tester.cc)
-inference_analysis_test(test_data_flow_graph_to_fluid_pass SRCS data_flow_graph_to_fluid_pass_tester.cc EXTRA_DEPS paddle_inference_api)
-inference_analysis_test(test_fluid_to_ir_pass SRCS fluid_to_ir_pass_tester.cc EXTRA_DEPS paddle_fluid)
+inference_analysis_test(test_data_flow_graph_to_fluid_pass SRCS data_flow_graph_to_fluid_pass_tester.cc)
+inference_analysis_test(test_fluid_to_ir_pass SRCS fluid_to_ir_pass_tester.cc)
inference_analysis_test(test_fluid_to_data_flow_graph_pass SRCS fluid_to_data_flow_graph_pass_tester.cc)
inference_analysis_test(test_subgraph_splitter SRCS subgraph_splitter_tester.cc)
inference_analysis_test(test_dfg_graphviz_draw_pass SRCS dfg_graphviz_draw_pass_tester.cc)
@@ -86,7 +74,7 @@ inference_analysis_test(test_model_store_pass SRCS model_store_pass_tester.cc)
set(CHINESE_NER_MODEL_URL "http://paddle-inference-dist.bj.bcebos.com/chinese_ner_model.tar.gz")
set(CHINESE_NER_DATA_URL "http://paddle-inference-dist.bj.bcebos.com/chinese_ner-data.txt.tar.gz")
set(CHINESE_NER_INSTALL_DIR "${THIRD_PARTY_PATH}/inference_demo/chinese_ner" CACHE PATH "Chinese ner model and data root." FORCE)
-if (NOT EXISTS ${CHINESE_NER_INSTALL_DIR} AND WITH_TESTING)
+if (NOT EXISTS ${CHINESE_NER_INSTALL_DIR} AND WITH_TESTING AND WITH_INFERENCE)
inference_download_and_uncompress(${CHINESE_NER_INSTALL_DIR} ${CHINESE_NER_MODEL_URL} "chinese_ner_model.tar.gz")
inference_download_and_uncompress(${CHINESE_NER_INSTALL_DIR} ${CHINESE_NER_DATA_URL} "chinese_ner-data.txt.tar.gz")
endif()
@@ -99,7 +87,7 @@ inference_analysis_test(test_analyzer_ner SRCS analyzer_ner_tester.cc
set(LAC_MODEL_URL "http://paddle-inference-dist.bj.bcebos.com/lac_model.tar.gz")
set(LAC_DATA_URL "http://paddle-inference-dist.bj.bcebos.com/lac_data.txt.tar.gz")
set(LAC_INSTALL_DIR "${THIRD_PARTY_PATH}/inference_demo/lac" CACHE PATH "LAC model and data root." FORCE)
-if (NOT EXISTS ${LAC_INSTALL_DIR} AND WITH_TESTING)
+if (NOT EXISTS ${LAC_INSTALL_DIR} AND WITH_TESTING AND WITH_INFERENCE)
inference_download_and_uncompress(${LAC_INSTALL_DIR} ${LAC_MODEL_URL} "lac_model.tar.gz")
inference_download_and_uncompress(${LAC_INSTALL_DIR} ${LAC_DATA_URL} "lac_data.txt.tar.gz")
endif()
@@ -108,3 +96,15 @@ inference_analysis_test(test_analyzer_lac SRCS analyzer_lac_tester.cc
EXTRA_DEPS paddle_inference_api paddle_fluid_api
ARGS --infer_model=${LAC_INSTALL_DIR}/model
--infer_data=${LAC_INSTALL_DIR}/data.txt)
+
+
+set(TEXT_CLASSIFICATION_MODEL_URL "http://paddle-inference-dist.bj.bcebos.com/text-classification-Senta.tar.gz")
+set(TEXT_CLASSIFICATION_INSTALL_DIR "${THIRD_PARTY_PATH}/inference_demo/text_classification" CACHE PATH "Text Classification model and data root." FORCE)
+
+if (NOT EXISTS ${TEXT_CLASSIFICATION_INSTALL_DIR} AND WITH_TESTING AND WITH_INFERENCE)
+ inference_download_and_uncompress(${TEXT_CLASSIFICATION_INSTALL_DIR} ${TEXT_CLASSIFICATION_MODEL_URL} "text-classification-Senta.tar.gz")
+endif()
+
+inference_analysis_test(test_text_classification SRCS analyzer_text_classification_tester.cc
+ EXTRA_DEPS paddle_inference_api paddle_fluid_api analysis_predictor
+ ARGS --infer_model=${TEXT_CLASSIFICATION_INSTALL_DIR}/text-classification-Senta)
diff --git a/paddle/fluid/inference/analysis/analyzer.cc b/paddle/fluid/inference/analysis/analyzer.cc
index 192ac2daa6a78efec6db19870f71e80593c62da9..1fd884435d173800563ea37809003ed3aee16c7c 100644
--- a/paddle/fluid/inference/analysis/analyzer.cc
+++ b/paddle/fluid/inference/analysis/analyzer.cc
@@ -14,6 +14,7 @@
#include "paddle/fluid/inference/analysis/analyzer.h"
#include
+#include
#include "paddle/fluid/inference/analysis/data_flow_graph_to_fluid_pass.h"
#include "paddle/fluid/inference/analysis/dfg_graphviz_draw_pass.h"
#include "paddle/fluid/inference/analysis/fluid_to_data_flow_graph_pass.h"
@@ -41,20 +42,16 @@ class DfgPassManagerImpl final : public DfgPassManager {
public:
DfgPassManagerImpl() {
// TODO(Superjomn) set the key with pass reprs.
- LOG(INFO)
- << "-----------------------------------------------------------------";
- if (FLAGS_IA_enable_ir) {
- AddPass("fluid-to-ir-pass", new FluidToIrPass);
- } else {
+ if (!FLAGS_IA_enable_ir) {
AddPass("fluid-to-data-flow-graph", new FluidToDataFlowGraphPass);
+ } else {
+ AddPass("fluid-to-ir-pass", new FluidToIrPass);
}
TryAddTensorRtPass();
AddPass("data-flow-graph-to-fluid", new DataFlowGraphToFluidPass);
if (!FLAGS_IA_output_storage_path.empty()) {
AddPass("model-store-pass", new ModelStorePass);
}
- LOG(INFO)
- << "-----------------------------------------------------------------";
}
std::string repr() const override { return "dfg-pass-manager"; }
@@ -101,19 +98,15 @@ class DfgPassManagerImpl final : public DfgPassManager {
Analyzer::Analyzer() { Register("manager1", new DfgPassManagerImpl); }
void Analyzer::Run(Argument* argument) {
- // Ugly support fluid-to-ir-pass
- argument->Set(kFluidToIrPassesAttr,
- new std::vector<std::string>({
- // Manual update the passes here.
- "graph_viz_pass", //
- "infer_clean_graph_pass", "graph_viz_pass", //
- "attention_lstm_fuse_pass", "graph_viz_pass", //
- "fc_lstm_fuse_pass", "graph_viz_pass", //
- "mul_lstm_fuse_pass", "graph_viz_pass", //
- "seq_concat_fc_fuse_pass", "graph_viz_pass", //
- "fc_fuse_pass", "graph_viz_pass" //
-
- }));
+ std::vector<std::string> passes;
+ for (auto& pass : all_ir_passes_) {
+ if (!disabled_ir_passes_.count(pass)) {
+ passes.push_back(pass);
+ passes.push_back("graph_viz_pass"); // add graphviz for debug.
+ }
+ }
+ passes.push_back("graph_viz_pass");
+ argument->Set(kFluidToIrPassesAttr, new std::vector<std::string>(passes));
for (auto& x : data_) {
PADDLE_ENFORCE(x->Initialize(argument));
@@ -122,6 +115,11 @@ void Analyzer::Run(Argument* argument) {
}
}
+Analyzer& Analyzer::DisableIrPasses(const std::vector<std::string>& passes) {
+ disabled_ir_passes_.insert(passes.begin(), passes.end());
+ return *this;
+}
+
} // namespace analysis
} // namespace inference
} // namespace paddle
diff --git a/paddle/fluid/inference/analysis/analyzer.h b/paddle/fluid/inference/analysis/analyzer.h
index 2e107c82dd50d5cf22797f4c82e69d302514f955..3fdd2b9ec7537c891a04efb3ca9a1d45075ffa5e 100644
--- a/paddle/fluid/inference/analysis/analyzer.h
+++ b/paddle/fluid/inference/analysis/analyzer.h
@@ -36,16 +36,10 @@ limitations under the License. */
*/
#include
+#include "paddle/fluid/inference/analysis/flags.h"
#include "paddle/fluid/inference/analysis/pass.h"
#include "paddle/fluid/inference/analysis/pass_manager.h"
-// TODO(Superjomn) add a definition flag like PADDLE_WITH_TENSORRT and hide this
-// flag if not available.
-DECLARE_bool(IA_enable_tensorrt_subgraph_engine);
-DECLARE_string(IA_graphviz_log_root);
-DECLARE_string(IA_output_storage_path);
-DECLARE_bool(IA_enable_ir);
-
namespace paddle {
namespace inference {
namespace analysis {
class Analyzer : public OrderedRegistry<PassManager> {
void Run(Argument* argument);
+ Analyzer& DisableIrPasses(const std::vector<std::string>& passes);
+
DISABLE_COPY_AND_ASSIGN(Analyzer);
+
+ private:
+ // All available IR passes.
+ // The bigger fusion comes first, so that smaller operators prefer to be
+ // merged into a larger fused op; a small fusion will not break the pattern
+ // of a larger fusion.
+ const std::vector<std::string> all_ir_passes_{{
+ // Manually update the passes here.
+ "infer_clean_graph_pass", //
+ "attention_lstm_fuse_pass", //
+ "fc_lstm_fuse_pass", //
+ "mul_lstm_fuse_pass", //
+ "seq_concat_fc_fuse_pass", //
+ "fc_fuse_pass", //
+ }};
+
+ std::unordered_set<std::string> disabled_ir_passes_;
};
} // namespace analysis
diff --git a/paddle/fluid/inference/analysis/analyzer_tester.cc b/paddle/fluid/inference/analysis/analyzer_tester.cc
index ec1f3979a74bd86ee7402bca441e95d3d177d113..4cf26d3c70eafd951d14c26335416ec2c71c001d 100644
--- a/paddle/fluid/inference/analysis/analyzer_tester.cc
+++ b/paddle/fluid/inference/analysis/analyzer_tester.cc
@@ -16,19 +16,21 @@
#include
#include
+#include <thread> // NOLINT
#include "paddle/fluid/framework/ir/fuse_pass_base.h"
#include "paddle/fluid/framework/ir/pass.h"
#include "paddle/fluid/inference/analysis/ut_helper.h"
#include "paddle/fluid/inference/api/analysis_predictor.h"
#include "paddle/fluid/inference/api/helper.h"
#include "paddle/fluid/inference/api/paddle_inference_api.h"
+#include "paddle/fluid/inference/api/paddle_inference_pass.h"
#include "paddle/fluid/inference/utils/singleton.h"
-#include "paddle/fluid/platform/profiler.h"
DEFINE_string(infer_ditu_rnn_model, "", "model path for ditu RNN");
DEFINE_string(infer_ditu_rnn_data, "", "data path for ditu RNN");
DEFINE_int32(batch_size, 10, "batch size.");
DEFINE_int32(repeat, 1, "Running the inference program repeat times.");
+DEFINE_int32(num_threads, 1, "Running the inference program in multi-threads.");
namespace paddle {
namespace inference {
@@ -219,39 +221,6 @@ void PrepareInputs(std::vector<PaddleTensor> *input_slots, DataRecord *data,
}
}
-std::string DescribeTensor(const PaddleTensor &tensor) {
- std::stringstream os;
- os << "Tensor [" << tensor.name << "]\n";
- os << " - type: ";
- switch (tensor.dtype) {
- case PaddleDType::FLOAT32:
- os << "float32";
- break;
- case PaddleDType::INT64:
- os << "int64";
- break;
- default:
- os << "unset";
- }
- os << '\n';
-
- os << " - shape: " << to_string(tensor.shape) << '\n';
- os << " - lod: ";
- for (auto &l : tensor.lod) {
- os << to_string(l) << "; ";
- }
- os << "\n";
- os << " - data: ";
-
- int dim = std::accumulate(tensor.shape.begin(), tensor.shape.end(), 1,
- [](int a, int b) { return a * b; });
- for (int i = 0; i < dim; i++) {
- os << static_cast(tensor.data.data())[i] << " ";
- }
- os << '\n';
- return os.str();
-}
-
} // namespace
const float ditu_rnn_target_data[] = {
@@ -265,57 +234,97 @@ const float ditu_rnn_target_data[] = {
10.7286, 12.0595, 10.6672, 0, 0, 0, 0, 0,
93.5771, 3.84641, 0, 0, 0, 0, 0, 0,
169.426, 0, 0, 0, 0, 0, 0, 0};
+void CompareResult(const std::vector<PaddleTensor> &outputs,
+ const std::vector<PaddleTensor> &base_outputs) {
+ PADDLE_ENFORCE_GT(outputs.size(), 0);
+ PADDLE_ENFORCE_EQ(outputs.size(), base_outputs.size());
+ for (size_t i = 0; i < outputs.size(); i++) {
+ auto &out = outputs[i];
+ auto &base_out = base_outputs[i];
+ size_t size = std::accumulate(out.shape.begin(), out.shape.end(), 1,
+ [](int a, int b) { return a * b; });
+ size_t size1 = std::accumulate(base_out.shape.begin(), base_out.shape.end(),
+ 1, [](int a, int b) { return a * b; });
+ PADDLE_ENFORCE_EQ(size, size1);
+ PADDLE_ENFORCE_GT(size, 0);
+ float *data = static_cast<float *>(out.data.data());
+ float *base_data = static_cast<float *>(base_out.data.data());
+ for (size_t i = 0; i < size; i++) {
+ EXPECT_NEAR(data[i], base_data[i], 1e-3);
+ }
+ }
+}
// Test with a really complicate model.
-void TestDituRNNPrediction(const std::string &model_path,
- const std::string &data_path, int batch_size,
- bool use_analysis, bool activate_ir,
- int num_times = 1) {
- NativeConfig config;
+void TestDituRNNPrediction(bool use_analysis, bool activate_ir,
+ int num_threads) {
+ AnalysisConfig config;
config.prog_file = FLAGS_infer_ditu_rnn_model + "/__model__";
config.param_file = FLAGS_infer_ditu_rnn_model + "/param";
config.use_gpu = false;
config.device = 0;
config.specify_input_name = true;
+ config.enable_ir_optim = activate_ir;
+ PADDLE_ENFORCE(config.ir_mode ==
+ AnalysisConfig::IrPassMode::kExclude); // default
+ config.ir_passes.clear(); // Do not exclude any pass.
+ int batch_size = FLAGS_batch_size;
+ int num_times = FLAGS_repeat;
auto base_predictor =
CreatePaddlePredictor<NativeConfig, PaddleEngineKind::kNative>(config);
auto predictor =
- CreatePaddlePredictor<NativeConfig, PaddleEngineKind::kAnalysis>(config);
+ CreatePaddlePredictor<AnalysisConfig, PaddleEngineKind::kAnalysis>(
+ config);
std::vector<PaddleTensor> input_slots;
- DataRecord data(data_path, batch_size);
+ DataRecord data(FLAGS_infer_ditu_rnn_data, batch_size);
// Prepare inputs.
PrepareInputs(&input_slots, &data, batch_size);
std::vector<PaddleTensor> outputs, base_outputs;
base_predictor->Run(input_slots, &base_outputs);
- Timer timer;
- timer.tic();
- for (int i = 0; i < num_times; i++) {
- predictor->Run(input_slots, &outputs);
- }
LOG(INFO) << "===========profile result===========";
- LOG(INFO) << "batch_size: " << batch_size << ", repeat: " << num_times
- << ", latency: " << timer.toc() / num_times << "ms";
- LOG(INFO) << "=====================================";
-
- PADDLE_ENFORCE_GT(outputs.size(), 0);
- PADDLE_ENFORCE_EQ(outputs.size(), base_outputs.size());
- for (size_t i = 0; i < outputs.size(); i++) {
- auto &out = outputs[i];
- auto &base_out = base_outputs[i];
- size_t size = std::accumulate(out.shape.begin(), out.shape.end(), 1,
- [](int a, int b) { return a * b; });
- size_t size1 = std::accumulate(base_out.shape.begin(), base_out.shape.end(),
- 1, [](int a, int b) { return a * b; });
- PADDLE_ENFORCE_EQ(size, size1);
- PADDLE_ENFORCE_GT(size, 0);
- float *data = static_cast<float *>(out.data.data());
- float *base_data = static_cast<float *>(base_out.data.data());
- for (size_t j = 0; j < size; j++) {
- EXPECT_NEAR(data[j], base_data[j], 1e-3);
+ if (num_threads == 1) {
+ // Prepare inputs.
+ Timer timer;
+ timer.tic();
+ for (int i = 0; i < num_times; i++) {
+ predictor->Run(input_slots, &outputs);
+ }
+ PrintTime(batch_size, num_times, 1, 0, timer.toc() / num_times);
+ CompareResult(outputs, base_outputs);
+ } else {
+ std::vector<std::thread> threads;
+ std::vector<std::unique_ptr<PaddlePredictor>> predictors;
+ // TODO(yanchunwei): Bug here, the analyzer phase can't be parallelized
+ // because AttentionLSTM's hard-coded node ids would be damaged.
+ for (int tid = 0; tid < num_threads; ++tid) {
+ predictors.emplace_back(
+ CreatePaddlePredictor<AnalysisConfig, PaddleEngineKind::kAnalysis>(
+ config));
+ }
+ for (int tid = 0; tid < num_threads; ++tid) {
+ threads.emplace_back([&, tid]() {
+ // Each thread should have local input_slots and outputs.
+ std::vector<PaddleTensor> input_slots;
+ DataRecord data(FLAGS_infer_ditu_rnn_data, batch_size);
+ PrepareInputs(&input_slots, &data, batch_size);
+ std::vector<PaddleTensor> outputs;
+ Timer timer;
+ timer.tic();
+ for (int i = 0; i < num_times; i++) {
+ predictors[tid]->Run(input_slots, &outputs);
+ }
+ PrintTime(batch_size, num_times, num_threads, tid,
+ timer.toc() / num_times);
+ CompareResult(outputs, base_outputs);
+ });
+ }
+ for (int i = 0; i < num_threads; ++i) {
+ threads[i].join();
}
}
+ LOG(INFO) << "=====================================";
if (use_analysis && activate_ir) {
AnalysisPredictor *analysis_predictor =
@@ -327,40 +336,45 @@ void TestDituRNNPrediction(const std::string &model_path,
LOG(INFO) << "fused " << item.first << " " << item.second;
}
- ASSERT_TRUE(fuse_statis.count("fc"));
- EXPECT_EQ(fuse_statis.at("fc"), 1);
- EXPECT_EQ(fuse_statis.at("fc_nobias_lstm_fuse"), 1);
- }
-}
+ int num_ops = 0;
+ for (auto &node :
+ analysis_predictor->analysis_argument().main_dfg->nodes.nodes()) {
+ if (node->IsFunction()) {
+ ++num_ops;
+ }
+ }
+ LOG(INFO) << "has num ops: " << num_ops;
-// Directly infer with the original model.
-TEST(Analyzer, DituRNN_without_analysis) {
- TestDituRNNPrediction(FLAGS_infer_ditu_rnn_model, FLAGS_infer_ditu_rnn_data,
- FLAGS_batch_size, false, false, FLAGS_repeat);
+ ASSERT_TRUE(fuse_statis.count("fc_fuse"));
+ EXPECT_EQ(fuse_statis.at("fc_fuse"), 1);
+ EXPECT_EQ(fuse_statis.at("fc_nobias_lstm_fuse"), 2); // bi-directional LSTM
+ EXPECT_EQ(num_ops,
+ 13); // After graph optimization, only 13 operators exist.
+ }
}
-// Inference with the original model with the analysis turned on, the analysis
-// module will transform the program to a data flow graph.
-TEST(Analyzer, DituRNN_with_analysis) {
- LOG(INFO) << "ditu rnn with analysis";
- TestDituRNNPrediction(FLAGS_infer_ditu_rnn_model, FLAGS_infer_ditu_rnn_data,
- FLAGS_batch_size, true, false, FLAGS_repeat);
+// Inference with analysis and IR, kept as its own test so it can be profiled independently.
+TEST(Analyzer, DituRNN) {
+ TestDituRNNPrediction(true, true, FLAGS_num_threads);
}
-// Inference with analysis and IR. The IR module will fuse some large kernels.
-TEST(Analyzer, DituRNN_with_analysis_with_IR) {
- LOG(INFO) << "ditu rnn with analysis and IR fuse";
- TestDituRNNPrediction(FLAGS_infer_ditu_rnn_model, FLAGS_infer_ditu_rnn_data,
- FLAGS_batch_size, true, true, FLAGS_repeat);
+// Other DituRNN unit tests, exercising the different combinations of
+// use_analysis, activate_ir and multi-threading.
+TEST(Analyzer, DituRNN_tests) {
+ int num_threads[2] = {1, 4};
+ for (auto i : num_threads) {
+ // Directly infer with the original model.
+ TestDituRNNPrediction(false, false, i);
+ // Inference with the original model with the analysis turned on; the
+ // analysis module will transform the program to a data flow graph.
+ TestDituRNNPrediction(true, false, i);
+ // Inference with analysis and IR. The IR module will fuse some large
+ // kernels.
+ TestDituRNNPrediction(true, true, i);
+ }
}
} // namespace analysis
} // namespace inference
} // namespace paddle
-
-USE_PASS(fc_fuse_pass);
-USE_PASS(seq_concat_fc_fuse_pass);
-USE_PASS(fc_lstm_fuse_pass);
-USE_PASS(graph_viz_pass);
-USE_PASS(infer_clean_graph_pass);
-USE_PASS(attention_lstm_fuse_pass);
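
The multi-threaded branch added above follows a pattern worth spelling out: one predictor per thread, created up front on the main thread; thread-local inputs and outputs; and each thread only ever touching its own predictor. A minimal standalone sketch of that shape (`FakePredictor` is a hypothetical stand-in, not part of the Paddle API):

```cpp
#include <memory>
#include <thread>
#include <vector>

// Hypothetical stand-in for PaddlePredictor, just to show the threading shape.
struct FakePredictor {
  int Run(int input) const { return input * 2; }
};

int main() {
  const int num_threads = 4;
  // One predictor per thread, created before any thread starts.
  std::vector<std::unique_ptr<FakePredictor>> predictors;
  for (int tid = 0; tid < num_threads; ++tid) {
    predictors.emplace_back(new FakePredictor);
  }
  std::vector<std::thread> threads;
  for (int tid = 0; tid < num_threads; ++tid) {
    threads.emplace_back([&, tid] {
      // Inputs and outputs stay thread-local; each thread reads only its own
      // predictors[tid], so no synchronization is needed during Run.
      int output = predictors[tid]->Run(tid);
      (void)output;
    });
  }
  for (auto &t : threads) t.join();
  return 0;
}
```

Creating all predictors before spawning the threads matters here, since the TODO above notes that the analyzer phase itself cannot yet run in parallel.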
diff --git a/paddle/fluid/inference/analysis/analyzer_text_classification_tester.cc b/paddle/fluid/inference/analysis/analyzer_text_classification_tester.cc
new file mode 100644
index 0000000000000000000000000000000000000000..265e814acd594d6185251cbaa4d6880bb9ee7405
--- /dev/null
+++ b/paddle/fluid/inference/analysis/analyzer_text_classification_tester.cc
@@ -0,0 +1,103 @@
+// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "paddle/fluid/inference/analysis/analyzer.h"
+#include <gflags/gflags.h>
+#include <glog/logging.h> // use glog instead of PADDLE_ENFORCE to avoid importing other paddle header files.
+#include <gtest/gtest.h>
+#include "paddle/fluid/framework/ir/pass.h"
+#include "paddle/fluid/inference/analysis/ut_helper.h"
+#include "paddle/fluid/inference/api/paddle_inference_api.h"
+#include "paddle/fluid/inference/api/paddle_inference_pass.h"
+#include "paddle/fluid/inference/api/timer.h"
+
+DEFINE_string(infer_model, "", "Directory of the inference model.");
+DEFINE_string(infer_data, "", "Path of the dataset.");
+DEFINE_int32(batch_size, 1, "batch size.");
+DEFINE_int32(repeat, 1, "How many times to repeat run.");
+
+namespace paddle {
+
+template <typename T>
+std::string to_string(const std::vector<T> &vec) {
+ std::stringstream ss;
+ for (const auto &c : vec) {
+ ss << c << " ";
+ }
+ return ss.str();
+}
+
+void PrintTime(const double latency, const int bs, const int repeat) {
+ LOG(INFO) << "===========profile result===========";
+ LOG(INFO) << "batch_size: " << bs << ", repeat: " << repeat
+ << ", avg latency: " << latency / repeat << "ms";
+ LOG(INFO) << "=====================================";
+}
+
+void Main(int batch_size) {
+ // One input slot, holding a single sequence of three steps.
+ std::vector<PaddleTensor> input_slots(1);
+ // one batch starts
+ // data --
+ int64_t data0[] = {0, 1, 2};
+ for (auto &input : input_slots) {
+ input.data.Reset(data0, sizeof(data0));
+ input.shape = std::vector<int>({3, 1});
+ // dtype --
+ input.dtype = PaddleDType::INT64;
+ // LoD --
+ input.lod = std::vector<std::vector<size_t>>({{0, 3}});
+ }
+
+ // shape --
+ // Create Predictor --
+ AnalysisConfig config;
+ config.model_dir = FLAGS_infer_model;
+ config.use_gpu = false;
+ config.enable_ir_optim = true;
+ config.ir_passes.push_back("fc_lstm_fuse_pass");
+ auto predictor =
+ CreatePaddlePredictor<AnalysisConfig, PaddleEngineKind::kAnalysis>(
+ config);
+
+ inference::Timer timer;
+ double sum = 0;
+ std::vector<PaddleTensor> output_slots;
+ for (int i = 0; i < FLAGS_repeat; i++) {
+ timer.tic();
+ CHECK(predictor->Run(input_slots, &output_slots));
+ sum += timer.toc();
+ }
+ PrintTime(sum, batch_size, FLAGS_repeat);
+
+ // Get output
+ LOG(INFO) << "get outputs " << output_slots.size();
+
+ for (auto &output : output_slots) {
+ LOG(INFO) << "output.shape: " << to_string(output.shape);
+ // no lod ?
+ CHECK_EQ(output.lod.size(), 0UL);
+ LOG(INFO) << "output.dtype: " << output.dtype;
+ std::stringstream ss;
+ for (int i = 0; i < 5; i++) {
+ ss << static_cast<float *>(output.data.data())[i] << " ";
+ }
+ LOG(INFO) << "output.data summary: " << ss.str();
+ // one batch ends
+ }
+}
+
+TEST(text_classification, basic) { Main(FLAGS_batch_size); }
+
+} // namespace paddle
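
The tester above feeds a single LoD (level-of-detail) sequence: the `lod` field stores sequence boundaries as offsets into the rows of `data`, so `{{0, 3}}` means one sequence covering rows 0 to 3. A toy sketch of that layout, using a hypothetical `MiniTensor` in place of the real `PaddleTensor`:

```cpp
#include <cstdint>
#include <iostream>
#include <vector>

// Hypothetical miniature of PaddleTensor, only to illustrate the LoD layout.
struct MiniTensor {
  std::vector<int> shape;
  std::vector<int64_t> data;
  std::vector<std::vector<size_t>> lod;
};

int main() {
  // One batch: a single sequence of three word ids, shaped [3, 1].
  MiniTensor input;
  input.data = {0, 1, 2};
  input.shape = {3, 1};
  // Level-0 LoD {0, 3}: one sequence, spanning rows [0, 3).
  input.lod = {{0, 3}};
  for (size_t s = 0; s + 1 < input.lod[0].size(); ++s) {
    std::cout << "sequence " << s << " covers rows [" << input.lod[0][s]
              << ", " << input.lod[0][s + 1] << ")\n";
  }
  return 0;
}
```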
diff --git a/paddle/fluid/inference/analysis/flags.h b/paddle/fluid/inference/analysis/flags.h
new file mode 100644
index 0000000000000000000000000000000000000000..717e543f01dfa071865a5c14c0b7679e65239daf
--- /dev/null
+++ b/paddle/fluid/inference/analysis/flags.h
@@ -0,0 +1,22 @@
+// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <gflags/gflags.h>
+
+// TODO(Superjomn) add a definition flag like PADDLE_WITH_TENSORRT and hide this
+// flag if not available.
+DECLARE_bool(IA_enable_tensorrt_subgraph_engine);
+DECLARE_string(IA_graphviz_log_root);
+DECLARE_string(IA_output_storage_path);
+DECLARE_bool(IA_enable_ir);
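
The header relies on the usual gflags split: `DEFINE_*` in exactly one .cc file allocates the flag and sets its default, while `DECLARE_*` in a shared header such as this one lets every other translation unit reference `FLAGS_<name>`. A minimal sketch with a made-up flag name:

```cpp
// flags_demo.cc -- sketch of the DECLARE/DEFINE pattern; the flag name is
// illustrative, not one of the real IA_* flags.
#include <gflags/gflags.h>
#include <iostream>

// Lives in exactly one .cc file. Headers that need the flag would instead say:
//   DECLARE_bool(demo_enable_ir);
DEFINE_bool(demo_enable_ir, true, "Whether to run IR passes (demo flag).");

int main(int argc, char* argv[]) {
  gflags::ParseCommandLineFlags(&argc, &argv, true);
  std::cout << "demo_enable_ir = " << FLAGS_demo_enable_ir << "\n";
  return 0;
}
```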
diff --git a/paddle/fluid/inference/analysis/fluid_to_ir_pass.h b/paddle/fluid/inference/analysis/fluid_to_ir_pass.h
index 6731b1f759363eec5dd8645783212a72ace67b2f..3086085710d6e850ed27e82d2323690dfdd3ef19 100644
--- a/paddle/fluid/inference/analysis/fluid_to_ir_pass.h
+++ b/paddle/fluid/inference/analysis/fluid_to_ir_pass.h
@@ -15,6 +15,7 @@
#pragma once
#include "paddle/fluid/framework/ir/fuse_pass_base.h"
+#include "paddle/fluid/inference/analysis/flags.h"
#include "paddle/fluid/inference/analysis/ir_pass_manager.h"
#include "paddle/fluid/inference/analysis/pass.h"
@@ -85,9 +86,11 @@ class FluidToIrPass final : public DataFlowGraphPass {
new Scope *(&argument_->Get<Scope>(ir::kParamScopeAttr)));
}
- const auto &ir_passes_to_apply =
- argument_->Get<std::vector<std::string>>(kFluidToIrPassesAttr);
- ir_passes.Apply(ir_passes_to_apply);
+ if (FLAGS_IA_enable_ir) {
+ const auto &ir_passes_to_apply =
+ argument_->Get<std::vector<std::string>>(kFluidToIrPassesAttr);
+ ir_passes.Apply(ir_passes_to_apply);
+ }
PADDLE_ENFORCE(argument_->main_dfg.get());
argument_->main_dfg->Build(ir_passes.graph());
diff --git a/paddle/fluid/inference/analysis/fluid_to_ir_pass_tester.cc b/paddle/fluid/inference/analysis/fluid_to_ir_pass_tester.cc
index 6a13c60e7b2ebf645b12d5ddf83ef6ab3a2e83bd..367c25805d05f8d10fb8341158760ac6356a5c48 100644
--- a/paddle/fluid/inference/analysis/fluid_to_ir_pass_tester.cc
+++ b/paddle/fluid/inference/analysis/fluid_to_ir_pass_tester.cc
@@ -16,6 +16,7 @@
#include <gtest/gtest.h>
#include "paddle/fluid/inference/analysis/ut_helper.h"
+#include "paddle/fluid/inference/api/paddle_inference_pass.h"
namespace paddle {
namespace inference {
@@ -33,10 +34,3 @@ TEST(FluidToIrPass, Test) {
} // namespace analysis
} // namespace inference
} // namespace paddle
-
-USE_PASS(graph_viz_pass);
-USE_PASS(infer_clean_graph_pass);
-USE_PASS(attention_lstm_fuse_pass);
-USE_PASS(fc_lstm_fuse_pass);
-USE_PASS(seq_concat_fc_fuse_pass);
-USE_PASS(fc_fuse_pass);
diff --git a/paddle/fluid/inference/api/CMakeLists.txt b/paddle/fluid/inference/api/CMakeLists.txt
index adfe4392448557a30cd834022b9a5d21d9086b95..6b8278a0395c9ae71e32337d9735409de7ba0c96 100644
--- a/paddle/fluid/inference/api/CMakeLists.txt
+++ b/paddle/fluid/inference/api/CMakeLists.txt
@@ -18,10 +18,7 @@ if(APPLE)
endif(APPLE)
-set(inference_deps paddle_inference_api paddle_fluid_api analysis pass ir_pass_manager
- graph_viz_pass fc_fuse_pass
- infer_clean_graph_pass
- )
+set(inference_deps paddle_inference_api paddle_fluid_api analysis pass ir_pass_manager ${GLOB_PASS_LIB})
if(WITH_GPU AND TENSORRT_FOUND)
set(inference_deps ${inference_deps} paddle_inference_tensorrt_subgraph_engine)
@@ -47,7 +44,7 @@ function(inference_api_test TARGET_NAME)
endfunction(inference_api_test)
cc_library(paddle_inference_api SRCS api.cc api_impl.cc helper.cc DEPS lod_tensor)
-cc_library(analysis_predictor SRCS analysis_predictor.cc DEPS paddle_inference_api)
+cc_library(analysis_predictor SRCS analysis_predictor.cc DEPS paddle_inference_api analysis)
cc_test(test_paddle_inference_api
SRCS api_tester.cc
diff --git a/paddle/fluid/inference/api/analysis_predictor.cc b/paddle/fluid/inference/api/analysis_predictor.cc
index 33862232bdaae817b9ca72879605386c32ed3e8b..79eeea88ea83ad862b5e2ac1390dae377b676685 100644
--- a/paddle/fluid/inference/api/analysis_predictor.cc
+++ b/paddle/fluid/inference/api/analysis_predictor.cc
@@ -14,10 +14,13 @@
#include "paddle/fluid/inference/api/analysis_predictor.h"
#include <memory>
+#include <string>
+#include <vector>
#include "paddle/fluid/framework/ir/fuse_pass_base.h"
#include "paddle/fluid/framework/ir/pass.h"
#include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/inference/api/paddle_inference_api.h"
+#include "paddle/fluid/inference/api/paddle_inference_pass.h"
#include "paddle/fluid/inference/utils/singleton.h"
namespace paddle {
@@ -27,10 +30,11 @@ bool AnalysisPredictor::Init(
VLOG(3) << "Predictor::init()";
if (config_.use_gpu) {
place_ = paddle::platform::CUDAPlace(config_.device);
+ LOG(WARNING) << "IR optimization only supports CPU currently";
+ config_.enable_ir_optim = false;
} else {
place_ = paddle::platform::CPUPlace();
}
- PADDLE_ENFORCE(!parent_scope);
if (parent_scope) {
scope_ = parent_scope;
sub_scope_ = &(parent_scope->NewScope());
@@ -72,7 +76,7 @@ bool AnalysisPredictor::Init(
void AnalysisPredictor::OptimizeInferenceProgram() {
LOG(INFO) << "optimize begin";
- FLAGS_IA_enable_ir = true;
+ FLAGS_IA_enable_ir = config_.enable_ir_optim;
FLAGS_IA_enable_tensorrt_subgraph_engine = false;
FLAGS_IA_output_storage_path = ""; // Don't output the model.
// Analyze inference_program
@@ -89,24 +93,26 @@ void AnalysisPredictor::OptimizeInferenceProgram() {
}
argument_.origin_program_desc.reset(
new ProgramDesc(*inference_program_->Proto()));
- Analyzer().Run(&argument_);
+ PADDLE_ENFORCE(config_.ir_mode == AnalysisConfig::IrPassMode::kExclude,
+ "Only kExclude is supported yet.");
+ Analyzer().DisableIrPasses(config_.ir_passes).Run(&argument_);
+
CHECK(argument_.transformed_program_desc);
VLOG(5) << "to prepare executor";
- // LOG(INFO) << "transformed_parogram_desc " <<
- // argument.transformed_program_desc->DebugString();
inference_program_.reset(
new framework::ProgramDesc(*argument_.transformed_program_desc));
- PADDLE_ENFORCE(argument_.Has(framework::ir::kParamScopeAttr));
- // Update scope.
- scope_.reset(
- argument_.Release<framework::Scope>(framework::ir::kParamScopeAttr));
- LOG(INFO) << "optimize end ==";
+ if (argument_.Has(framework::ir::kParamScopeAttr)) {
+ // Update scope.
+ scope_.reset(
+ argument_.Release<framework::Scope>(framework::ir::kParamScopeAttr));
+ }
+ LOG(INFO) << "== optimize end ==";
}
template <>
std::unique_ptr<PaddlePredictor> CreatePaddlePredictor<
- NativeConfig, PaddleEngineKind::kAnalysis>(const NativeConfig& config) {
- VLOG(3) << "create NativePredictor";
+ AnalysisConfig, PaddleEngineKind::kAnalysis>(const AnalysisConfig& config) {
+ VLOG(3) << "create AnalysisConfig";
if (config.use_gpu) {
// 1. GPU memory
PADDLE_ENFORCE_GT(
@@ -133,7 +139,3 @@ std::unique_ptr CreatePaddlePredictor<
}
} // namespace paddle
-
-USE_PASS(fc_fuse_pass);
-USE_PASS(graph_viz_pass);
-USE_PASS(infer_clean_graph_pass);
diff --git a/paddle/fluid/inference/api/analysis_predictor.h b/paddle/fluid/inference/api/analysis_predictor.h
index e32b6185f6044ab3577bde0a8f8dcf2391688aa8..e53925366e9214cd60422efe56884751297c15e5 100644
--- a/paddle/fluid/inference/api/analysis_predictor.h
+++ b/paddle/fluid/inference/api/analysis_predictor.h
@@ -12,6 +12,8 @@
// See the License for the specific language governing permissions and
// limitations under the License.
+#include <string>
+#include <vector>
#include "paddle/fluid/inference/analysis/analyzer.h"
#include "paddle/fluid/inference/api/api_impl.h"
#include "paddle/fluid/inference/api/paddle_inference_api.h"
@@ -28,7 +30,7 @@ using framework::proto::ProgramDesc;
*/
class AnalysisPredictor : public NativePaddlePredictor {
public:
- explicit AnalysisPredictor(const NativeConfig& config)
+ explicit AnalysisPredictor(const AnalysisConfig& config)
: NativePaddlePredictor(config), config_(config) {}
bool Init(const std::shared_ptr<framework::Scope>& parent_scope);
@@ -44,7 +46,7 @@ class AnalysisPredictor : public NativePaddlePredictor {
Argument& analysis_argument() { return argument_; }
private:
- NativeConfig config_;
+ AnalysisConfig config_;
Argument argument_;
};
diff --git a/paddle/fluid/inference/api/api_impl.cc b/paddle/fluid/inference/api/api_impl.cc
index 38b11d9113e4b03f8365b969009f7a385a683a70..bd9b4b1a814f995e3979105f5b9830b95fd8ea7d 100644
--- a/paddle/fluid/inference/api/api_impl.cc
+++ b/paddle/fluid/inference/api/api_impl.cc
@@ -176,7 +176,8 @@ bool NativePaddlePredictor::SetFeed(const std::vector<PaddleTensor> &inputs,
framework::Scope *scope) {
VLOG(3) << "Predictor::set_feed";
if (inputs.size() != feeds_.size()) {
- LOG(ERROR) << "wrong feed input size.";
+ LOG(ERROR) << "wrong feed input size, need " << feeds_.size() << " but get "
+ << inputs.size();
return false;
}
for (size_t i = 0; i < inputs.size(); ++i) {
diff --git a/paddle/fluid/inference/api/demo_ci/run.sh b/paddle/fluid/inference/api/demo_ci/run.sh
index 7824ef2649af81a2390ff3bc537eb7c93c70e402..0f7d541c5edfc62e80cf50f83b491f06dcb42644 100755
--- a/paddle/fluid/inference/api/demo_ci/run.sh
+++ b/paddle/fluid/inference/api/demo_ci/run.sh
@@ -14,7 +14,7 @@ else
fi
PREFIX=inference-vis-demos%2F
-URL_ROOT=http://paddlemodels.bj.bcebos.com/${PREFIX}
+URL_ROOT=http://paddlemodels.cdn.bcebos.com/${PREFIX}
# download vis_demo data
function download() {
diff --git a/paddle/fluid/inference/api/helper.h b/paddle/fluid/inference/api/helper.h
index bdc9a15d543818da94ac2acf34ecabbbbae3291e..2c2ac656e8005369bb0e9033236a431cb09caa15 100644
--- a/paddle/fluid/inference/api/helper.h
+++ b/paddle/fluid/inference/api/helper.h
@@ -14,8 +14,10 @@
#pragma once
+#include <glog/logging.h>
#include <sys/time.h>
#include <algorithm>
+#include <numeric>
#include <sstream>
#include <string>
#include <vector>
@@ -87,5 +89,45 @@ static void TensorAssignData(PaddleTensor *tensor,
}
}
+std::string DescribeTensor(const PaddleTensor &tensor) {
+ std::stringstream os;
+ os << "Tensor [" << tensor.name << "]\n";
+ os << " - type: ";
+ switch (tensor.dtype) {
+ case PaddleDType::FLOAT32:
+ os << "float32";
+ break;
+ case PaddleDType::INT64:
+ os << "int64";
+ break;
+ default:
+ os << "unset";
+ }
+ os << '\n';
+
+ os << " - shape: " << to_string(tensor.shape) << '\n';
+ os << " - lod: ";
+ for (auto &l : tensor.lod) {
+ os << to_string(l) << "; ";
+ }
+ os << "\n";
+ os << " - data: ";
+
+ int dim = std::accumulate(tensor.shape.begin(), tensor.shape.end(), 1,
+ [](int a, int b) { return a * b; });
+ for (int i = 0; i < dim; i++) {
+ os << static_cast<float *>(tensor.data.data())[i] << " ";
+ }
+ os << '\n';
+ return os.str();
+}
+
+void PrintTime(int batch_size, int repeat, int num_threads, int tid,
+ double latency) {
+ LOG(INFO) << "batch_size: " << batch_size << ", repeat: " << repeat
+ << ", threads: " << num_threads << ", thread id: " << tid
+ << ", latency: " << latency << "ms";
+}
+
} // namespace inference
} // namespace paddle
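
`DescribeTensor` derives the element count by folding the shape with multiplication, the same idiom the testers above use for output comparison. A tiny standalone version of just that step:

```cpp
#include <iostream>
#include <numeric>
#include <vector>

int main() {
  // Fold the dims with *, starting from 1 so an empty shape yields 1 element.
  std::vector<int> shape = {2, 3, 4};
  int numel = std::accumulate(shape.begin(), shape.end(), 1,
                              [](int a, int b) { return a * b; });
  std::cout << "numel = " << numel << "\n";  // prints: numel = 24
  return 0;
}
```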
diff --git a/paddle/fluid/inference/api/paddle_inference_api.h b/paddle/fluid/inference/api/paddle_inference_api.h
index 1baa64c249f291ec1bc874be5031abe6d4368274..995da11e4a30eca72a91a53d3293aa8b033b012b 100644
--- a/paddle/fluid/inference/api/paddle_inference_api.h
+++ b/paddle/fluid/inference/api/paddle_inference_api.h
@@ -150,6 +150,21 @@ struct TensorRTConfig : public NativeConfig {
int workspace_size{1 << 30};
};
+// NOTE WIP, not stable yet.
+struct AnalysisConfig : public NativeConfig {
+ //
+ enum class IrPassMode {
+ kSystem,  // Use the system default passes, without customization.
+ kInclude, // Specify the passes in `ir_passes`.
+ kExclude // Specify the disabled passes in `ir_passes`.
+ };
+
+ bool enable_ir_optim = true;
+ IrPassMode ir_mode{IrPassMode::kExclude};
+ // attention lstm fuse works only on some specific models, disabled by default.
+ std::vector<std::string> ir_passes{"attention_lstm_fuse_pass"};
+};
+
// A factory to help create different predictors.
//
// FOR EXTENSION DEVELOPER:
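
A rough usage sketch assembled only from the fields declared above; the model path is a placeholder, and the excluded pass is simply the struct's default:

```cpp
#include "paddle/fluid/inference/api/paddle_inference_api.h"

int main() {
  paddle::AnalysisConfig config;
  config.model_dir = "/path/to/model";  // placeholder
  config.use_gpu = false;               // IR optimization is CPU-only for now
  config.enable_ir_optim = true;
  // kExclude: run the default pass list, minus everything listed in ir_passes.
  config.ir_mode = paddle::AnalysisConfig::IrPassMode::kExclude;
  auto predictor =
      paddle::CreatePaddlePredictor<paddle::AnalysisConfig,
                                    paddle::PaddleEngineKind::kAnalysis>(
          config);
  return predictor != nullptr ? 0 : 1;
}
```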
diff --git a/paddle/fluid/inference/paddle_fluid.map b/paddle/fluid/inference/paddle_fluid.map
index 5203784dc1fcb672eb6a26d9dfd3ffbe02e08038..7e5cae04b81e6ce759b92f6c4b921ecf974e8260 100644
--- a/paddle/fluid/inference/paddle_fluid.map
+++ b/paddle/fluid/inference/paddle_fluid.map
@@ -1,6 +1,7 @@
{
global:
*paddle*;
+ *Pass*;
local:
*;
};
diff --git a/paddle/fluid/operators/auc_op.cc b/paddle/fluid/operators/auc_op.cc
index 5edecd18e673da326ec119cf9a383f24f8045089..dfaa7456f917c1308984b361afed752f96ea6f59 100644
--- a/paddle/fluid/operators/auc_op.cc
+++ b/paddle/fluid/operators/auc_op.cc
@@ -13,7 +13,6 @@ See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/auc_op.h"
-#include <string>
namespace paddle {
namespace operators {
@@ -36,15 +35,12 @@ class AucOp : public framework::OperatorWithKernel {
PADDLE_ENFORCE_EQ(predict_height, label_height,
"Out and Label should have same height.");
- int num_thres = ctx->Attrs().Get<int>("num_thresholds");
+ int num_pred_buckets = ctx->Attrs().Get<int>("num_thresholds") + 1;
ctx->SetOutputDim("AUC", {1});
- ctx->SetOutputDim("TPOut", {num_thres});
- ctx->SetOutputDim("TNOut", {num_thres});
- ctx->SetOutputDim("FPOut", {num_thres});
- ctx->SetOutputDim("FNOut", {num_thres});
-
- ctx->ShareLoD("Predict", /*->*/ "AUC");
+ ctx->SetOutputDim("BatchAUC", {1});
+ ctx->SetOutputDim("StatPosOut", {num_pred_buckets});
+ ctx->SetOutputDim("StatNegOut", {num_pred_buckets});
}
protected:
@@ -66,25 +62,24 @@ class AucOpMaker : public framework::OpProtoAndCheckerMaker {
AddInput("Label",
"A 2D int tensor indicating the label of the training data. "
"shape: [batch_size, 1]");
- AddInput("TP", "True-Positive value.");
- AddInput("FP", "False-Positive value.");
- AddInput("TN", "True-Negative value.");
- AddInput("FN", "False-Negative value.");
// TODO(typhoonzero): support weight input
+ AddInput("StatPos", "Statistic value when label = 1");
+ AddInput("StatNeg", "Statistic value when label = 0");
+
AddOutput("AUC",
"A scalar representing the "
"current area-under-the-curve.");
- AddOutput("TPOut", "True-Positive value.");
- AddOutput("FPOut", "False-Positive value.");
- AddOutput("TNOut", "True-Negative value.");
- AddOutput("FNOut", "False-Negative value.");
+ AddOutput("BatchAUC", "The AUC for current batch");
+ AddOutput("StatPosOut", "Statistic value when label = 1");
+ AddOutput("StatNegOut", "Statistic value when label = 0");
AddAttr("curve", "Curve type, can be 'ROC' or 'PR'.")
.SetDefault("ROC");
+
AddAttr("num_thresholds",
"The number of thresholds to use when discretizing the"
" roc curve.")
- .SetDefault(200);
+ .SetDefault((2 << 12) - 1);
AddComment(R"DOC(
Area Under The Curve (AUC) Operator.
diff --git a/paddle/fluid/operators/auc_op.h b/paddle/fluid/operators/auc_op.h
index 0a18585edb54a76aff5ae72ecc71e0eebb9f9361..fb0517d70635e090f8c5b59ff9d8420fc34c747b 100644
--- a/paddle/fluid/operators/auc_op.h
+++ b/paddle/fluid/operators/auc_op.h
@@ -13,9 +13,9 @@ See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
+
#include <string>
#include <vector>
-#include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/framework/op_registry.h"
namespace paddle {
@@ -23,106 +23,85 @@ namespace operators {
using Tensor = framework::Tensor;
-template <typename T, int MajorType = Eigen::RowMajor,
- typename IndexType = Eigen::DenseIndex>
-using EigenVector = framework::EigenVector<T, MajorType, IndexType>;
-
template <typename DeviceContext, typename T>
class AucKernel : public framework::OpKernel<T> {
public:
- void Compute(const framework::ExecutionContext& ctx) const override {
- auto* predict = ctx.Input<Tensor>("Predict");
- auto* label = ctx.Input<Tensor>("Label");
- auto* auc = ctx.Output<Tensor>("AUC");
+ void Compute(const framework::ExecutionContext &ctx) const override {
+ auto *predict = ctx.Input<Tensor>("Predict");
+ auto *label = ctx.Input<Tensor>("Label");
+
+ std::string curve = ctx.Attr<std::string>("curve");
+ int num_thresholds = ctx.Attr<int>("num_thresholds");
+ int num_pred_buckets = num_thresholds + 1;
+
// Only use output var for now, make sure it's persistable and
// not cleaned up for each batch.
- auto* true_positive = ctx.Output<Tensor>("TPOut");
- auto* false_positive = ctx.Output<Tensor>("FPOut");
- auto* true_negative = ctx.Output<Tensor>("TNOut");
- auto* false_negative = ctx.Output<Tensor>("FNOut");
+ auto *auc = ctx.Output<Tensor>("AUC");
+ auto *stat_pos = ctx.Output<Tensor>("StatPosOut");
+ auto *stat_neg = ctx.Output<Tensor>("StatNegOut");
- auto* auc_data = auc->mutable_data<double>(ctx.GetPlace());
+ auto *stat_pos_data = stat_pos->mutable_data<int64_t>(ctx.GetPlace());
+ auto *stat_neg_data = stat_neg->mutable_data<int64_t>(ctx.GetPlace());
+ calcAuc(ctx, label, predict, stat_pos_data, stat_neg_data, num_thresholds,
+ auc);
- std::string curve = ctx.Attr<std::string>("curve");
- int num_thresholds = ctx.Attr<int>("num_thresholds");
- std::vector<double> thresholds_list;
- thresholds_list.reserve(num_thresholds);
- for (int i = 1; i < num_thresholds - 1; i++) {
- thresholds_list[i] = static_cast(i) / (num_thresholds - 1);
- }
- const double kEpsilon = 1e-7;
- thresholds_list[0] = 0.0f - kEpsilon;
- thresholds_list[num_thresholds - 1] = 1.0f + kEpsilon;
+ auto *batch_auc = ctx.Output("BatchAUC");
+ std::vector<int64_t> stat_pos_batch(num_pred_buckets, 0);
+ std::vector<int64_t> stat_neg_batch(num_pred_buckets, 0);
+ calcAuc(ctx, label, predict, stat_pos_batch.data(), stat_neg_batch.data(),
+ num_thresholds, batch_auc);
+ }
+ private:
+ inline static double trapezoidArea(double X1, double X2, double Y1,
+ double Y2) {
+ return (X1 > X2 ? (X1 - X2) : (X2 - X1)) * (Y1 + Y2) / 2.0;
+ }
+
+ inline static void calcAuc(const framework::ExecutionContext &ctx,
+ const framework::Tensor *label,
+ const framework::Tensor *predict,
+ int64_t *stat_pos, int64_t *stat_neg,
+ int num_thresholds,
+ framework::Tensor *auc_tensor) {
size_t batch_size = predict->dims()[0];
size_t inference_width = predict->dims()[1];
+ const T *inference_data = predict->data<T>();
+ const auto *label_data = label->data<int64_t>();
+
+ auto *auc = auc_tensor->mutable_data<double>(ctx.GetPlace());
- const T* inference_data = predict->data<T>();
- const auto* label_data = label->data<int64_t>();
-
- auto* tp_data = true_positive->mutable_data<int64_t>(ctx.GetPlace());
- auto* fn_data = false_negative->mutable_data<int64_t>(ctx.GetPlace());
- auto* tn_data = true_negative->mutable_data<int64_t>(ctx.GetPlace());
- auto* fp_data = false_positive->mutable_data<int64_t>(ctx.GetPlace());
-
- for (int idx_thresh = 0; idx_thresh < num_thresholds; idx_thresh++) {
- // calculate TP, FN, TN, FP for current thresh
- int64_t tp = 0, fn = 0, tn = 0, fp = 0;
- for (size_t i = 0; i < batch_size; i++) {
- // NOTE: label_data used as bool, labels > 0 will be treated as true.
- if (label_data[i]) {
- if (inference_data[i * inference_width + 1] >=
- (thresholds_list[idx_thresh])) {
- tp++;
- } else {
- fn++;
- }
- } else {
- if (inference_data[i * inference_width + 1] >=
- (thresholds_list[idx_thresh])) {
- fp++;
- } else {
- tn++;
- }
- }
+ for (size_t i = 0; i < batch_size; i++) {
+ uint32_t binIdx = static_cast<uint32_t>(
+ inference_data[i * inference_width + 1] * num_thresholds);
+ if (label_data[i]) {
+ stat_pos[binIdx] += 1.0;
+ } else {
+ stat_neg[binIdx] += 1.0;
}
- // store rates
- tp_data[idx_thresh] += tp;
- fn_data[idx_thresh] += fn;
- tn_data[idx_thresh] += tn;
- fp_data[idx_thresh] += fp;
}
- // epsilon to avoid divide by zero.
- double epsilon = 1e-6;
- // Riemann sum to calculate auc.
- Tensor tp_rate, fp_rate, rec_rate;
- tp_rate.Resize({num_thresholds});
- fp_rate.Resize({num_thresholds});
- rec_rate.Resize({num_thresholds});
- auto* tp_rate_data = tp_rate.mutable_data<double>(ctx.GetPlace());
- auto* fp_rate_data = fp_rate.mutable_data<double>(ctx.GetPlace());
- auto* rec_rate_data = rec_rate.mutable_data<double>(ctx.GetPlace());
- for (int i = 0; i < num_thresholds; i++) {
- tp_rate_data[i] = (static_cast(tp_data[i]) + epsilon) /
- (tp_data[i] + fn_data[i] + epsilon);
- fp_rate_data[i] =
- static_cast(fp_data[i]) / (fp_data[i] + tn_data[i] + epsilon);
- rec_rate_data[i] = (static_cast(tp_data[i]) + epsilon) /
- (tp_data[i] + fp_data[i] + epsilon);
+
+ *auc = 0.0f;
+
+ double totPos = 0.0;
+ double totNeg = 0.0;
+ double totPosPrev = 0.0;
+ double totNegPrev = 0.0;
+
+ int idx = num_thresholds;
+
+ while (idx >= 0) {
+ totPosPrev = totPos;
+ totNegPrev = totNeg;
+ totPos += stat_pos[idx];
+ totNeg += stat_neg[idx];
+ *auc += trapezoidArea(totNeg, totNegPrev, totPos, totPosPrev);
+
+ --idx;
}
- *auc_data = 0.0f;
- if (curve == "ROC") {
- for (int i = 0; i < num_thresholds - 1; i++) {
- auto dx = fp_rate_data[i] - fp_rate_data[i + 1];
- auto y = (tp_rate_data[i] + tp_rate_data[i + 1]) / 2.0f;
- *auc_data = *auc_data + dx * y;
- }
- } else if (curve == "PR") {
- for (int i = 1; i < num_thresholds; i++) {
- auto dx = tp_rate_data[i] - tp_rate_data[i - 1];
- auto y = (rec_rate_data[i] + rec_rate_data[i - 1]) / 2.0f;
- *auc_data = *auc_data + dx * y;
- }
+
+ if (totPos > 0.0 && totNeg > 0.0) {
+ *auc = *auc / totPos / totNeg;
}
}
};
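
The rewritten kernel drops the per-threshold TP/FP loops in favor of a histogram: each prediction falls into one of `num_thresholds + 1` score buckets, and a single high-to-low sweep over the buckets accumulates trapezoid slices of the curve. A self-contained sketch of the same arithmetic on toy data, with the bucket count shrunk for readability:

```cpp
#include <cstdint>
#include <iostream>
#include <vector>

// Trapezoid slice between two consecutive points of the sweep.
static double TrapezoidArea(double x1, double x2, double y1, double y2) {
  return (x1 > x2 ? x1 - x2 : x2 - x1) * (y1 + y2) / 2.0;
}

int main() {
  // stat_pos[b] / stat_neg[b]: positive / negative samples whose predicted
  // score landed in bucket b, where b = score * num_thresholds.
  const int num_thresholds = 4;  // tiny demo; the op defaults to 2^13 - 1
  std::vector<int64_t> stat_pos = {0, 0, 0, 2, 3};
  std::vector<int64_t> stat_neg = {3, 2, 0, 0, 0};

  double auc = 0.0, tot_pos = 0.0, tot_neg = 0.0;
  // Sweep from the highest-score bucket down; each step moves one bucket of
  // samples to the "predicted positive" side and adds one trapezoid.
  for (int idx = num_thresholds; idx >= 0; --idx) {
    double prev_pos = tot_pos, prev_neg = tot_neg;
    tot_pos += stat_pos[idx];
    tot_neg += stat_neg[idx];
    auc += TrapezoidArea(tot_neg, prev_neg, tot_pos, prev_pos);
  }
  if (tot_pos > 0.0 && tot_neg > 0.0) auc /= tot_pos * tot_neg;
  std::cout << "AUC = " << auc << "\n";  // perfectly separated data -> 1
  return 0;
}
```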
diff --git a/paddle/fluid/operators/distributed/request_handler_impl.cc b/paddle/fluid/operators/distributed/request_handler_impl.cc
index 66784f0b5149a7c479a90a407709d993f4a40a8b..31159a02592a2aff75f7ecf5be924989f0f47071 100644
--- a/paddle/fluid/operators/distributed/request_handler_impl.cc
+++ b/paddle/fluid/operators/distributed/request_handler_impl.cc
@@ -39,19 +39,6 @@ bool RequestSendHandler::Handle(const std::string& varname,
const std::string& out_var_name) {
VLOG(4) << "RequestSendHandler:" << varname;
- // Async
- if (!sync_mode_) {
- rpc_server_->Profiler().OneStep();
- try {
- executor_->RunPreparedContext((*grad_to_prepared_ctx_)[varname].get(),
- scope);
- } catch (std::exception& e) {
- LOG(ERROR) << "async: run sub program error " << e.what();
- return false;
- }
- return true;
- }
-
// Sync
if (varname == BATCH_BARRIER_MESSAGE) {
VLOG(3) << "sync: recv BATCH_BARRIER_MESSAGE";
@@ -60,17 +47,31 @@ bool RequestSendHandler::Handle(const std::string& varname,
VLOG(3) << "sync: recv complete message";
rpc_server_->Complete();
} else {
- VLOG(3) << "sync: received var_name: " << varname;
- rpc_server_->WaitCond(kRequestSend);
- VLOG(3) << "sync: processing received var: " << varname;
-
- if (invar == nullptr) {
- LOG(FATAL) << "sync: Can not find server side var: " << varname;
- return false;
- }
- if (invar->IsType()) {
- std::unique_lock lock(mutex_sparse_vars_);
- sparse_vars_.push_back(invar);
+ // Async
+ if (!sync_mode_) {
+ VLOG(3) << "async process var: " << varname;
+ rpc_server_->Profiler().OneStep();
+ try {
+ executor_->RunPreparedContext((*grad_to_prepared_ctx_)[varname].get(),
+ scope);
+ } catch (std::exception& e) {
+ LOG(ERROR) << "async: run sub program error " << e.what();
+ return false;
+ }
+ return true;
+ } else { // sync
+ rpc_server_->WaitCond(kRequestSend);
+ VLOG(3) << "sync: processing received var: " << varname;
+
+ if (invar == nullptr) {
+ LOG(FATAL) << "sync: Can not find server side var: " << varname;
+ return false;
+ }
+
+ if (invar->IsType<framework::SelectedRows>()) {
+ std::unique_lock lock(mutex_sparse_vars_);
+ sparse_vars_.push_back(invar);
+ }
}
}
return true;
diff --git a/paddle/fluid/operators/fake_quantize_op.cu b/paddle/fluid/operators/fake_quantize_op.cu
index 7c65d6dba7d67b5d31720bae1f4877dd22210138..a0ff6396210c2b3a7f8bd6b9f274b875d7fd4933 100644
--- a/paddle/fluid/operators/fake_quantize_op.cu
+++ b/paddle/fluid/operators/fake_quantize_op.cu
@@ -119,7 +119,8 @@ struct FindRangeAbsMaxFunctor {
const framework::Tensor& last_scale,
const framework::Tensor& iter, const int window_size,
framework::Tensor* scales_arr, framework::Tensor* out_scale) {
- auto& gpu_place = boost::get<platform::CUDAPlace>(ctx.GetPlace());
+ const auto gpu_place = boost::get<platform::CUDAPlace>(ctx.GetPlace());
+
T* scale_arr = scales_arr->mutable_data<T>(gpu_place);
T* out_scale_data = out_scale->mutable_data<T>(gpu_place);
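
The `boost::get` above pulls the concrete place out of `platform::Place`, which in this codebase is a `boost::variant` over the concrete place types. A minimal sketch of that idiom with stand-in types (the real definitions live in paddle/fluid/platform/place.h):

```cpp
#include <boost/variant.hpp>
#include <iostream>

// Stand-ins for the real place types, just to show the variant idiom.
struct CPUPlace {};
struct CUDAPlace { int device = 0; };
using Place = boost::variant<CPUPlace, CUDAPlace>;

int main() {
  Place place = CUDAPlace{1};
  // boost::get<T> extracts the alternative, throwing boost::bad_get if the
  // variant currently holds a different type. Taken by value, as in the diff.
  const auto gpu_place = boost::get<CUDAPlace>(place);
  std::cout << "device = " << gpu_place.device << "\n";
  return 0;
}
```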
diff --git a/paddle/fluid/operators/flatten_op.cc b/paddle/fluid/operators/flatten_op.cc
index fdda01381e117cecffb2a05f8399f3ad82a46339..8e80dc0e641c443923076c31e269689b5bc134a7 100644
--- a/paddle/fluid/operators/flatten_op.cc
+++ b/paddle/fluid/operators/flatten_op.cc
@@ -157,6 +157,116 @@ class FlattenGradOp : public framework::OperatorBase {
}
};
+// FIXME(zcd): flatten2 adds an intermediate output (XShape) on top of flatten.
+// XShape carries the shape and lod of X for use in flatten2_grad, so the
+// framework can reuse the memory of X as soon as flatten2_op finishes.
+// For compatibility reasons the original flatten_op could not be changed in
+// place, hence this separate flatten2 operator.
+class Flatten2OpInferShape : public FlattenOpInferShape {
+ public:
+ void operator()(framework::InferShapeContext *ctx) const override {
+ FlattenOpInferShape::operator()(ctx);
+ PADDLE_ENFORCE(ctx->HasOutput("XShape"),
+ "Output (XShape) of Flatten op should not be null.");
+ const auto &in_dims = ctx->GetInputDim("X");
+ std::vector<int64_t> xshape_dims(in_dims.size() + 1);
+ xshape_dims[0] = 0;
+ for (int i = 0; i < in_dims.size(); ++i) {
+ xshape_dims[i + 1] = in_dims[i];
+ }
+ ctx->SetOutputDim("XShape", framework::make_ddim(xshape_dims));
+ ctx->ShareLoD("X", "XShape");
+ }
+};
+
+class Flatten2Op : public framework::OperatorBase {
+ public:
+ using OperatorBase::OperatorBase;
+
+ private:
+ void RunImpl(const framework::Scope &scope,
+ const platform::Place &place) const override {
+ auto &axis = Attr<int>("axis");
+ auto in_dims =
+ scope.FindVar(Input("X"))->Get<framework::LoDTensor>().dims();
+ const auto &out_dims = FlattenOpInferShape::GetOutputShape(axis, in_dims);
+
+ framework::AttributeMap attrs;
+ attrs["shape"] = out_dims;
+ attrs["inplace"] = false;
+ // Invoke Reshape Op
+ auto reshape_op = framework::OpRegistry::CreateOp(
+ "reshape2", {{"X", {Input("X")}}, {"Shape", {}}},
+ {{"Out", {Output("Out")}}, {"XShape", {Output("XShape")}}}, attrs);
+ reshape_op->Run(scope, place);
+ }
+};
+
+class Flatten2OpMaker : public FlattenOpMaker {
+ public:
+ void Make() override {
+ FlattenOpMaker::Make();
+ AddOutput("XShape",
+ "XShape is just used to store the shape and lod of X, which will "
+ "be used in FlattenGradOp.")
+ .AsIntermediate();
+ }
+};
+
+class Flatten2GradOpMaker : public framework::SingleGradOpDescMaker {
+ public:
+ using framework::SingleGradOpDescMaker::SingleGradOpDescMaker;
+
+ std::unique_ptr<framework::OpDesc> Apply() const override {
+ auto *grad_op = new framework::OpDesc();
+ grad_op->SetType("flatten2_grad");
+ grad_op->SetInput("XShape", Output("XShape"));
+ grad_op->SetInput(framework::GradVarName("Out"), OutputGrad("Out"));
+ grad_op->SetOutput(framework::GradVarName("X"), InputGrad("X"));
+ grad_op->SetAttrMap(Attrs());
+ return std::unique_ptr<framework::OpDesc>(grad_op);
+ }
+};
+
+class Flatten2GradInferShape : public framework::InferShapeBase {
+ public:
+ void operator()(framework::InferShapeContext *context) const override {
+ PADDLE_ENFORCE(context->HasInput("XShape"),
+ "Input(XShape) shouldn't be null.");
+ PADDLE_ENFORCE(context->HasInput(framework::GradVarName("Out")),
+ "Input(Out@GRAD) shouldn't be null.");
+ auto xshape_dims = context->GetInputDim("XShape");
+ auto x_dims = framework::slice_ddim(xshape_dims, 1, xshape_dims.size());
+ context->SetOutputDim(framework::GradVarName("X"), x_dims);
+ context->ShareLoD("XShape", framework::GradVarName("X"));
+ }
+};
+
+class Flatten2GradOp : public framework::OperatorBase {
+ public:
+ using OperatorBase::OperatorBase;
+
+ private:
+ void RunImpl(const framework::Scope &scope,
+ const platform::Place &place) const override {
+ auto dx_name = Output(framework::GradVarName("X"));
+ auto dout_name = Input(framework::GradVarName("Out"));
+ auto xshape_name = Input("XShape");
+ auto xshape_dims =
+ scope.FindVar(xshape_name)->Get<framework::LoDTensor>().dims();
+ auto x_dims = framework::slice_ddim(xshape_dims, 1, xshape_dims.size());
+
+ framework::AttributeMap attrs;
+ attrs["shape"] = framework::vectorize2int(x_dims);
+ attrs["inplace"] = false;
+
+ auto reshape_op = framework::OpRegistry::CreateOp(
+ "reshape2", {{"X", {dout_name}}, {"Shape", {}}},
+ {{"Out", {dx_name}}, {"XShape", {xshape_name}}}, attrs);
+ reshape_op->Run(scope, place);
+ }
+};
+
} // namespace operators
} // namespace paddle
@@ -167,3 +277,8 @@ REGISTER_OPERATOR(flatten, ops::FlattenOp, ops::FlattenOpMaker,
ops::FlattenOpInferShape,
paddle::framework::DefaultGradOpDescMaker<true>);
REGISTER_OPERATOR(flatten_grad, ops::FlattenGradOp, ops::FlattenGradInferShape);
+
+REGISTER_OPERATOR(flatten2, ops::Flatten2Op, ops::Flatten2OpMaker,
+ ops::Flatten2OpInferShape, ops::Flatten2GradOpMaker);
+REGISTER_OPERATOR(flatten2_grad, ops::Flatten2GradOp,
+ ops::Flatten2GradInferShape);
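
The `XShape` trick is easiest to see in isolation: InferShape prepends a 0 (a batch-size placeholder) to the dims of X, and the grad op slices that leading slot off again to recover X's shape without keeping X itself alive. A standalone sketch of the forward half:

```cpp
#include <cstdint>
#include <iostream>
#include <vector>

int main() {
  // Example dims of X; slot 0 of XShape stays 0 as a batch-size placeholder.
  std::vector<int64_t> in_dims = {8, 3, 32, 32};
  std::vector<int64_t> xshape_dims(in_dims.size() + 1);
  xshape_dims[0] = 0;
  for (size_t i = 0; i < in_dims.size(); ++i) {
    xshape_dims[i + 1] = in_dims[i];
  }
  // flatten2_grad later does the inverse: slice_ddim(xshape_dims, 1, size).
  for (auto d : xshape_dims) std::cout << d << " ";  // prints: 0 8 3 32 32
  std::cout << "\n";
  return 0;
}
```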
diff --git a/paddle/fluid/operators/fusion_lstm_op.cc b/paddle/fluid/operators/fusion_lstm_op.cc
index f91236975d0cf0c89a464188bd6ea1b5b01e0f6d..104e160e2d7069ec247cc51e927ce8824f1b69e8 100644
--- a/paddle/fluid/operators/fusion_lstm_op.cc
+++ b/paddle/fluid/operators/fusion_lstm_op.cc
@@ -89,12 +89,12 @@ void FusionLSTMOp::InferShape(framework::InferShapeContext* ctx) const {
PADDLE_ENFORCE_EQ(b_dims[0], 1,
"The first dimension of Input(Bias) should be 1.");
- PADDLE_ENFORCE(!ctx->Attrs().Get<bool>("use_peepholes"),
- "Do not support peephole yet.");
- PADDLE_ENFORCE_EQ(b_dims[1], 4 * frame_size,
+ auto use_peepholes = ctx->Attrs().Get<bool>("use_peepholes");
+ PADDLE_ENFORCE_EQ(b_dims[1], (use_peepholes ? 7 : 4) * frame_size,
"The second dimension of Input(Bias) should be "
- "4 * %d if disable peepholes connection",
- frame_size);
+ "7 * %d if enable peepholes connection or"
+ "4 * %d if disable peepholes",
+ frame_size, frame_size);
framework::DDim out_dims({x_dims[0], frame_size});
ctx->SetOutputDim("Hidden", out_dims);
@@ -232,16 +232,17 @@ class FuisonLSTMKernel : public framework::OpKernel<T> {
act_cand = act_functor(act_cand_str); \
}
-#define INIT_BASE_INPUT_OUTPUT \
- auto* x = ctx.Input<LoDTensor>("X"); \
- auto* h0 = ctx.Input<Tensor>("H0"); \
- auto* c0 = ctx.Input<Tensor>("C0"); \
- auto* wx = ctx.Input<Tensor>("WeightX"); \
- auto* wh = ctx.Input<Tensor>("WeightH"); \
- auto* bias = ctx.Input<Tensor>("Bias"); \
- auto* xx = ctx.Output<LoDTensor>("XX"); \
- auto* hidden_out = ctx.Output<LoDTensor>("Hidden"); \
- auto* cell_out = ctx.Output<LoDTensor>("Cell"); \
+#define INIT_BASE_INPUT_OUTPUT \
+ auto* x = ctx.Input<LoDTensor>("X"); \
+ auto* h0 = ctx.Input<Tensor>("H0"); \
+ auto* c0 = ctx.Input<Tensor>("C0"); \
+ auto* wx = ctx.Input<Tensor>("WeightX"); \
+ auto* wh = ctx.Input<Tensor>("WeightH"); \
+ auto* bias = ctx.Input<Tensor>("Bias"); \
+ auto* xx = ctx.Output<LoDTensor>("XX"); \
+ auto* hidden_out = ctx.Output<LoDTensor>("Hidden"); \
+ auto* cell_out = ctx.Output<LoDTensor>("Cell"); \
+ bool use_peepholes = ctx.Attr<bool>("use_peepholes"); \
bool is_reverse = ctx.Attr<bool>("is_reverse");
#define INIT_BASE_SIZES \
@@ -266,12 +267,21 @@ class FuisonLSTMKernel : public framework::OpKernel {
const T* x_data = x->data<T>();
const T* h0_data = h0 ? h0->data<T>() : nullptr;
const T* c0_data = c0 ? c0->data<T>() : nullptr;
+ const T* bias_data = bias->data<T>();
+ const T* wc_data = bias_data + D4; // w_ic, w_fc, w_oc
const T* wx_data = wx->data<T>();
const T* wh_data = wh->data<T>();
+
T* xx_data = xx->mutable_data<T>(ctx.GetPlace());
T* hidden_out_data = hidden_out->mutable_data<T>(ctx.GetPlace());
T* cell_out_data = cell_out->mutable_data<T>(ctx.GetPlace());
+ // use local variable
+ framework::DDim check_dims({3, D});
+ Tensor checked_cell; // w_ic * Ct-1, w_fc * Ct-1, w_oc * Ct
+ auto checked_cell_data =
+ checked_cell.mutable_data<T>(check_dims, ctx.GetPlace());
+
auto blas = math::GetBlas<DeviceContext, T>(ctx);
math::FCCompute<DeviceContext, T>(blas, total_T, D4, M, x_data, wx_data,
xx_data, bias->data<T>());
@@ -297,46 +307,86 @@ class FuisonLSTMKernel : public framework::OpKernel {
int seq_len = x_lod[0][bid + 1] - x_lod[0][bid];
const T* prev_c_data = nullptr;
const T* prev_h_data = nullptr;
+
int tstart = 0;
if (h0_data) {
prev_h_data = h0_data + bid * D;
prev_c_data = c0_data + bid * D;
} else {
- // W_ch, W_ih, W_fh, W_oh
- act_gate(D3, xx_data + D, xx_data + D);
+ // If step == 0 and there is no initialized hidden state, H0 is all zeros,
+ // so the W_h * H_t-1 term can be skipped.
+
+ // ~C_t
act_cand(D, xx_data, xx_data);
- // cell out= input*tilde
+ if (use_peepholes) {
+ // I_t, F_t
+ act_gate(D2, xx_data + D, xx_data + D);
+ } else {
+ // I_t, F_t, O_t
+ act_gate(D3, xx_data + D, xx_data + D);
+ }
+ // C_t = I_t * ~C_t
blas.VMUL(D, xx_data, xx_data + D, cell_out_data);
+
+ if (use_peepholes) {
+ // + W_oc * C_t for peephole connection
+ blas.VMUL(D, wc_data + D2, cell_out_data, checked_cell_data + D2);
+ blas.VADD(D, xx_data + D3, checked_cell_data + D2, xx_data + D3);
+ // O_t
+ act_gate(D, xx_data + D3, xx_data + D3);
+ }
+
// hidden out= act_state(cellout) * outgate
act_cell(D, cell_out_data, xx_data + D2);
+ // H_t = O_t * act_state(C_t)
blas.VMUL(D, xx_data + D2, xx_data + D3, hidden_out_data);
// prev
prev_h_data = hidden_out_data;
prev_c_data = cell_out_data;
- tstart = 1;
+ tstart = 1;
move_step();
}
+
for (int step = tstart; step < seq_len; ++step) {
+ // + W_h * H_t-1
blas.GEMM(CblasNoTrans, CblasNoTrans, 1, D4, D, static_cast<T>(1),
prev_h_data, D, wh_data, D4, static_cast<T>(1), xx_data, D4);
- // W_ch, W_ih, W_fh, W_oh
- act_gate(D3, xx_data + D, xx_data + D);
+ // ~C_t
act_cand(D, xx_data, xx_data);
- // a = forget * prev_cell
+ if (use_peepholes) {
+ // + W_ic|W_fc * C_t-1 for peephole connection
+ blas.VMUL(D, wc_data, prev_c_data, checked_cell_data);
+ blas.VMUL(D, wc_data + D, prev_c_data, checked_cell_data + D);
+ blas.VADD(D2, xx_data + D, checked_cell_data, xx_data + D);
+ // I_t, F_t
+ act_gate(D2, xx_data + D, xx_data + D);
+ } else {
+ // I_t, F_t, O_t
+ act_gate(D3, xx_data + D, xx_data + D);
+ }
+
+ // F_t * C_t-1
blas.VMUL(D, xx_data + D2, prev_c_data, xx_data + D2);
-
- // b = input * tilde
+ // I_t * ~C_t
blas.VMUL(D, xx_data, xx_data + D, xx_data + D);
-
- // cell out= a+b
+ // C_t = F_t * C_t-1 + I_t * ~C_t
blas.VADD(D, xx_data + D, xx_data + D2, cell_out_data);
+ if (use_peepholes) {
+ // + W_oc * C_t for peephole connection
+ blas.VMUL(D, wc_data + D2, cell_out_data, checked_cell_data + D2);
+ blas.VADD(D, xx_data + D3, checked_cell_data + D2, xx_data + D3);
+ // O_t
+ act_gate(D, xx_data + D3, xx_data + D3);
+ }
+
// hidden out= act_state(cellout) * outgate
act_cell(D, cell_out_data, xx_data + D2);
+ // H_t = O_t * act_state(C_t)
blas.VMUL(D, xx_data + D2, xx_data + D3, hidden_out_data);
// prev
@@ -344,14 +394,14 @@ class FuisonLSTMKernel : public framework::OpKernel {
prev_c_data = cell_out_data;
move_step();
- }
- }
+ } // for each step in batch
+ } // for each batch
}
void BatchCompute(const framework::ExecutionContext& ctx) const {
using DeviceContext = platform::CPUDeviceContext;
INIT_BASE_INPUT_OUTPUT
- if (x->lod()[0].size() == 2) {
+ if (x->lod()[0].size() == 2) { // batch size == 1
SeqCompute(ctx);
return;
}
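
For reference, the peephole branches above implement the standard peephole LSTM: the input and forget gates additionally look at `$c_{t-1}$`, the output gate at the fresh `$c_t$`, and `$W_{ic}, W_{fc}, W_{oc}$` are the diagonal peephole weights the kernel reads from the tail of `Bias` (a sketch of the equations, written with the kernel's `act_gate`/`act_cand`/`act_cell` names):

$$i_t = act_{gate}\left( W_{ix}x_t + W_{ih}h_{t-1} + W_{ic} \odot c_{t-1} + b_i \right)$$
$$f_t = act_{gate}\left( W_{fx}x_t + W_{fh}h_{t-1} + W_{fc} \odot c_{t-1} + b_f \right)$$
$$\tilde{c}_t = act_{cand}\left( W_{cx}x_t + W_{ch}h_{t-1} + b_c \right)$$
$$c_t = f_t \odot c_{t-1} + i_t \odot \tilde{c}_t$$
$$o_t = act_{gate}\left( W_{ox}x_t + W_{oh}h_{t-1} + W_{oc} \odot c_t + b_o \right)$$
$$h_t = o_t \odot act_{cell}(c_t)$$

Without peepholes the `$W_{\cdot c}$` terms drop out, which is exactly the `use_peepholes` branching in the kernel above.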
@@ -367,6 +417,8 @@ class FuisonLSTMKernel : public framework::OpKernel {
const T* x_data = x->data<T>();
const T* wx_data = wx->data<T>();
const T* wh_data = wh->data<T>();
+ const T* bias_data = bias->data<T>();
+ const T* wc_data = bias_data + D4; // w_ic, w_fc, w_oc
auto place = ctx.GetPlace();
T* xx_data = xx->mutable_data<T>(place);
T* batched_input_data = batched_input->mutable_data<T>(place);
@@ -375,6 +427,12 @@ class FuisonLSTMKernel : public framework::OpKernel {
hidden_out->mutable_data<T>(place);
cell_out->mutable_data<T>(place);
+ // use local variable
+ framework::DDim check_dims({3, D});
+ Tensor checked_cell; // w_ic * Ct-1, w_fc * Ct-1, w_oc * Ct
+ auto checked_cell_data =
+ checked_cell.mutable_data<T>(check_dims, ctx.GetPlace());
+
math::LoDTensor2BatchFunctor<DeviceContext, T> to_batch;
auto& dev_ctx = ctx.template device_context<DeviceContext>();