diff --git a/.copyright.hook b/.copyright.hook
index dc1b096a0ad28db732b794fa856efed71917c5e8..09afff2072df3384a429d01d06188218ae6e85d1 100644
--- a/.copyright.hook
+++ b/.copyright.hook
@@ -9,7 +9,7 @@ import subprocess
import platform
COPYRIGHT = '''
- Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
+ Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
diff --git a/cmake/external/boost.cmake b/cmake/external/boost.cmake
index 137f11da7f2f1c46eebf6590d93402786ef543c9..c70d83b3f4bb24740ed67b4e2f98a3ced26d1648 100644
--- a/cmake/external/boost.cmake
+++ b/cmake/external/boost.cmake
@@ -15,9 +15,9 @@
include(ExternalProject)
set(BOOST_PROJECT "extern_boost")
-set(BOOST_VER "1.66.0")
-set(BOOST_TAR "boost_1_66_0")
-set(BOOST_URL "https://dl.bintray.com/boostorg/release/${BOOST_VER}/source/${BOOST_TAR}.tar.gz")
+set(BOOST_VER "1.41.0")
+set(BOOST_TAR "boost_1_41_0")
+set(BOOST_URL "http://sourceforge.net/projects/boost/files/boost/${BOOST_VER}/${BOOST_TAR}.tar.gz")
set(BOOST_SOURCES_DIR ${THIRD_PARTY_PATH}/boost)
set(BOOST_DOWNLOAD_DIR "${BOOST_SOURCES_DIR}/src/${BOOST_PROJECT}")
set(BOOST_INCLUDE_DIR "${BOOST_DOWNLOAD_DIR}/${BOOST_TAR}" CACHE PATH "boost include directory." FORCE)
diff --git a/doc/api/v2/fluid/layers.rst b/doc/api/v2/fluid/layers.rst
index 875094601a5abb6953b973c5f09f90b739e39752..231ec2d4ba102a5d31c47cbc7a5d484ef17a7f3a 100644
--- a/doc/api/v2/fluid/layers.rst
+++ b/doc/api/v2/fluid/layers.rst
@@ -18,6 +18,11 @@ dynamic_lstm
.. autofunction:: paddle.v2.fluid.layers.dynamic_lstm
:noindex:
+dynamic_lstmp
+-------------
+.. autofunction:: paddle.v2.fluid.layers.dynamic_lstmp
+ :noindex:
+
dynamic_gru
-----------
.. autofunction:: paddle.v2.fluid.layers.dynamic_gru
@@ -529,3 +534,13 @@ sequence_reshape
----------------
.. autofunction:: paddle.v2.fluid.layers.sequence_reshape
:noindex:
+
+row_conv
+--------
+.. autofunction:: paddle.v2.fluid.layers.row_conv
+ :noindex:
+
+multiplex
+---------
+.. autofunction:: paddle.v2.fluid.layers.multiplex
+ :noindex:
diff --git a/doc/api/v2/fluid/nets.rst b/doc/api/v2/fluid/nets.rst
index f6b1cb4ba10659fb336899f08376c265c67290f1..500019bc507f859c4c91de5d322a82eb1e78e2de 100644
--- a/doc/api/v2/fluid/nets.rst
+++ b/doc/api/v2/fluid/nets.rst
@@ -26,8 +26,8 @@ glu
:noindex:
-dot_product_attention
----------------------
-.. autofunction:: paddle.v2.fluid.nets.dot_product_attention
+scaled_dot_product_attention
+----------------------------
+.. autofunction:: paddle.v2.fluid.nets.scaled_dot_product_attention
:noindex:
diff --git a/doc/design/dist_refactor/distributed_architecture.md b/doc/design/dist_refactor/distributed_architecture.md
index 3a741f95866fb6c301ca9097af7916281f2278cf..9368c5780dc922953f38bf0f86d9f797a4a8a6fe 100644
--- a/doc/design/dist_refactor/distributed_architecture.md
+++ b/doc/design/dist_refactor/distributed_architecture.md
@@ -152,12 +152,12 @@ for data in train_reader():
`JobDesc` object describe the distributed job resource specification to run on
Cluster environment.
-
+
`RemoteExecutor.run` sends the `ProgramDesc` and
[TrainingJob](https://github.com/PaddlePaddle/cloud/blob/develop/doc/autoscale/README.md#training-job-resource)
to a server in the cluster which executes `RemoteExecutor.listen`. This server is responsible
-to start the final Kubernetes Jobs to run the different role of `ProgramDesc`.
+to start the final Kubernetes Jobs to run the different role of `ProgramDesc` from `ConfigMap`.
### Placement Algorithm
diff --git a/doc/design/dist_refactor/src/remote_executor.graffle b/doc/design/dist_refactor/src/remote_executor.graffle
index ce2c18fee5687732053c48af9c8c290a994a8090..41b2067311694b56d211a4f32d1b76884eeffd2d 100644
Binary files a/doc/design/dist_refactor/src/remote_executor.graffle and b/doc/design/dist_refactor/src/remote_executor.graffle differ
diff --git a/doc/design/dist_refactor/src/remote_executor.png b/doc/design/dist_refactor/src/remote_executor.png
index 6be4b1841b99efdb59557975485d0387f422308c..744e2fb2e0f1bbe058e991ba7b2a09000965ee79 100644
Binary files a/doc/design/dist_refactor/src/remote_executor.png and b/doc/design/dist_refactor/src/remote_executor.png differ
diff --git a/doc/design/support_new_device.md b/doc/design/support_new_device.md
index 4c5f10e2ecb9ec09b78926ca27552741d02d7cc9..8983df900460127fc130043c52373dab505363ba 100644
--- a/doc/design/support_new_device.md
+++ b/doc/design/support_new_device.md
@@ -2,9 +2,9 @@
## Background
-Deep learning has a high demand for computing resources. New high-performance devices and computing libraries are appearing very frequently. Deep learning frameworks have to integrate these high-performance devices and computing libraries flexibly and efficiently.
+Deep learning has a high demand for computing resources. New high-performance devices and computing libraries are appearing very frequently. Deep learning frameworks have to integrate these high-performance devices and computing libraries in a flexible and efficient manner.
-On one hand, hardware and computing libraries usually do not have a one-to-one correspondence. For example,Intel CPUs support Eigen and MKL computing libraries while Nvidia GPUs support Eigen and cuDNN computing libraries. We have to implement operator specific kernels for each computing library.
+On one hand, hardware and computing libraries usually do not have a one-to-one correspondence. For example, Intel CPUs support Eigen and MKL computing libraries while Nvidia GPUs support Eigen and cuDNN computing libraries. We have to implement operator specific kernels for each computing library.
On the other hand, users usually do not want to care about the low-level hardware and computing libraries when writing a neural network configuration. In Fluid, `Layer` is exposed in `Python`, and `Operator` is exposed in `C++`. Both `Layer` and `Operator` are hardware independent.
@@ -17,7 +17,7 @@ For a general overview of fluid, please refer to the [overview doc](https://gith
There are mainly three parts that we have to consider while integrating a new device/library:
-- Place and DeviceContext: indicates the device id and manages hardware resources
+- Place and DeviceContext: indicate the device id and manage hardware resources
- Memory and Tensor: malloc/free data on certain device
@@ -25,10 +25,10 @@ There are mainly three parts that we have to consider while integrating a new de
### Place and DeviceContext
-Please remind that device and computing library are not one-to-one corresponding. A device can have a lot of computing libraries and a computing library can also support several devices.
+Please note that device and computing library are not one-to-one corresponding. A device can have a lot of computing libraries and a computing library can also support several devices.
#### Place
-Fluid uses class [Place](https://github.com/PaddlePaddle/Paddle/blob/develop/paddle/platform/place.h#L55) to represent the device memory where data is located. If we add another device, we have to add corresponding `DevicePlace`.
+Fluid uses class [Place](https://github.com/PaddlePaddle/Paddle/blob/develop/paddle/platform/place.h#L55) to represent the device memory where data is located. If we add another device, we have to add the corresponding `DevicePlace`.
```
| CPUPlace
@@ -144,7 +144,7 @@ class Tensor {
};
```
-`Placeholder` is used to delay memory allocation; that is, we can first define a tensor, using `Resize` to configure its shape, and then call `mutuable_data` to allocate the actual memory.
+`Placeholder` is used to delay memory allocation; that is, we can first define a tensor, using `Resize` to configurate its shape, and then call `mutuable_data` to allocate the actual memory.
```cpp
paddle::framework::Tensor t;
@@ -163,7 +163,7 @@ Fluid implements computing units based on different DeviceContexts. Some computi
Let's take [MaxOutFunctor](https://github.com/PaddlePaddle/Paddle/blob/develop/paddle/operators/math/maxouting.h#L27) as an example:
-The interface is defined in header file.
+The interface is defined in the header file.
```
template <typename DeviceContext, typename T>
@@ -174,7 +174,7 @@ class MaxOutFunctor {
};
```
-CPU implemention is in .cc file
+CPU implementation is in .cc file
```
template <typename DeviceContext, typename T>
@@ -188,7 +188,7 @@ class MaxOutFunctor {
};
```
-CUDA implemention is in .cu file
+CUDA implementation is in .cu file
```
template <typename DeviceContext, typename T>
@@ -203,9 +203,9 @@ class MaxOutFunctor {
```
-We get computing handle from a concrete DeviceContext, and make compution on tensors.
+We first obtain the computing handle from a concrete DeviceContext and then compute on tensors.
-The implemention of `OpKernel` is similar to math functors, the extra thing we need to do is to register the OpKernel in a global map.
+The implementation of `OpKernel` is similar to math functors, the extra thing we need to do is to register the OpKernel in a global map.
Fluid provides different register interfaces in op_registry.h
@@ -231,7 +231,7 @@ REGISTER_OP_CUDA_KERNEL(
## Advanced topics: How to switch between different Device/Library
-Generally, we will impelement OpKernel for all Device/Library of an Operator. We can easily train a Convolutional Neural Network in GPU. However, some OpKernel is not sutibale on a specific Device. For example, crf operator can only run on CPU, whereas most other operators can run at GPU. To achieve high performance in such circumstance, we have to switch between different Device/Library.
+Generally, we will implement OpKernel for all Device/Library of an Operator. We can easily train a Convolutional Neural Network in GPU. However, some OpKernel is not suitable on a specific Device. For example, crf operator can only run on CPU, whereas most other operators can run on GPU. To achieve high performance in such circumstance, we have to switch between different Device/Library.
For more details, please refer to following docs:
diff --git a/doc/getstarted/build_and_install/pip_install_cn.rst b/doc/getstarted/build_and_install/pip_install_cn.rst
index 0c741e936b46eda5e7165e4ee54b545b14a28a19..8e4165da6b8135d083766c650f1092158f9d01c2 100644
--- a/doc/getstarted/build_and_install/pip_install_cn.rst
+++ b/doc/getstarted/build_and_install/pip_install_cn.rst
@@ -39,6 +39,7 @@ PaddlePaddle可以使用常用的Python包管理工具
"cpu_avx_mkl", "`paddlepaddle-0.11.0-cp27-cp27mu-linux_x86_64.whl `_", "`paddlepaddle-0.11.0-cp27-cp27m-linux_x86_64.whl `_", "`paddle.tgz `_"
"cpu_avx_openblas", "`paddlepaddle-0.11.0-cp27-cp27mu-linux_x86_64.whl `_", "`paddlepaddle-0.11.0-cp27-cp27m-linux_x86_64.whl `_", "暂无"
+ "cpu_noavx_openblas", "`paddlepaddle-0.11.0-cp27-cp27mu-linux_x86_64.whl `_", "`paddlepaddle-0.11.0-cp27-cp27m-linux_x86_64.whl `_", "暂无"
"cuda7.5_cudnn5_avx_mkl", "`paddlepaddle_gpu-0.11.0-cp27-cp27mu-linux_x86_64.whl `_", "`paddlepaddle_gpu-0.11.0-cp27-cp27m-linux_x86_64.whl `_", "`paddle.tgz `_"
"cuda8.0_cudnn5_avx_mkl", "`paddlepaddle_gpu-0.11.0-cp27-cp27mu-linux_x86_64.whl `_", "`paddlepaddle_gpu-0.11.0-cp27-cp27m-linux_x86_64.whl `_", "`paddle.tgz `_"
"cuda8.0_cudnn7_avx_mkl", "`paddlepaddle_gpu-0.11.0-cp27-cp27mu-linux_x86_64.whl `_", "`paddlepaddle_gpu-0.11.0-cp27-cp27m-linux_x86_64.whl `_", "`paddle.tgz `_"
diff --git a/doc/getstarted/build_and_install/pip_install_en.rst b/doc/getstarted/build_and_install/pip_install_en.rst
index 285ed09805b09790beaef014f6813c227aff33ac..c1e806c0fe5f03139c0dff985f9ae0856eaa2e98 100644
--- a/doc/getstarted/build_and_install/pip_install_en.rst
+++ b/doc/getstarted/build_and_install/pip_install_en.rst
@@ -42,6 +42,7 @@ If the links below shows up the login form, just click "Log in as guest" to star
"cpu_avx_mkl", "`paddlepaddle-0.11.0-cp27-cp27mu-linux_x86_64.whl `_", "`paddlepaddle-0.11.0-cp27-cp27m-linux_x86_64.whl `_", "`paddle.tgz `_"
"cpu_avx_openblas", "`paddlepaddle-0.11.0-cp27-cp27mu-linux_x86_64.whl `_", "`paddlepaddle-0.11.0-cp27-cp27m-linux_x86_64.whl `_", "Not Available"
+ "cpu_noavx_openblas", "`paddlepaddle-0.11.0-cp27-cp27mu-linux_x86_64.whl `_", "`paddlepaddle-0.11.0-cp27-cp27m-linux_x86_64.whl `_", "Not Available"
"cuda7.5_cudnn5_avx_mkl", "`paddlepaddle_gpu-0.11.0-cp27-cp27mu-linux_x86_64.whl `_", "`paddlepaddle_gpu-0.11.0-cp27-cp27m-linux_x86_64.whl `_", "`paddle.tgz `_"
"cuda8.0_cudnn5_avx_mkl", "`paddlepaddle_gpu-0.11.0-cp27-cp27mu-linux_x86_64.whl `_", "`paddlepaddle_gpu-0.11.0-cp27-cp27m-linux_x86_64.whl `_", "`paddle.tgz `_"
"cuda8.0_cudnn7_avx_mkl", "`paddlepaddle_gpu-0.11.0-cp27-cp27mu-linux_x86_64.whl `_", "`paddlepaddle_gpu-0.11.0-cp27-cp27m-linux_x86_64.whl `_", "`paddle.tgz `_"
diff --git a/doc/howto/usage/cluster/fluid_cluster_train_en.md b/doc/howto/usage/cluster/fluid_cluster_train_en.md
index 11904a6f71bb6ce37417aeffb8e408ec65961b12..ae825d9a517c7e9005d4e32f8f34b3f6a79be0c9 100644
--- a/doc/howto/usage/cluster/fluid_cluster_train_en.md
+++ b/doc/howto/usage/cluster/fluid_cluster_train_en.md
@@ -16,6 +16,12 @@ PaddlePaddle must be installed on all nodes. If you have GPU cards on your nodes
PaddlePaddle build and installation guide can be found [here](http://www.paddlepaddle.org/docs/develop/documentation/en/getstarted/build_and_install/index_en.html).
+In addition to above, the `cmake` command should be run with the option `WITH_DISTRIBUTE` set to on. An example bare minimum `cmake` command would look as follows:
+
+``` bash
+cmake .. -DWITH_DOC=OFF -DWITH_GPU=OFF -DWITH_DISTRIBUTE=ON -DWITH_SWIG_PY=ON -DWITH_PYTHON=ON
+```
+
### Update the training script
#### Non-cluster training script
@@ -119,7 +125,14 @@ for pass_id in range(100):
### E2E demo
-Please find the complete demo from [here](https://github.com/PaddlePaddle/Paddle/blob/develop/python/paddle/v2/fluid/tests/book_distribute/notest_dist_fit_a_line.py). In parameter server node run the following in the command line:
+Please find the complete demo from [here](https://github.com/PaddlePaddle/Paddle/blob/develop/python/paddle/v2/fluid/tests/book_distribute/notest_dist_fit_a_line.py).
+First `cd` into the folder that contains the `python` files. In this case:
+
+```bash
+cd /paddle/python/paddle/v2/fluid/tests/book_distribute
+```
+
+In parameter server node run the following in the command line:
``` bash
PSERVERS=192.168.1.2:6174 SERVER_ENDPOINT=192.168.1.2:6174 TRAINING_ROLE=PSERVER python notest_dist_fit_a_line.py
diff --git a/paddle/framework/CMakeLists.txt b/paddle/framework/CMakeLists.txt
index b83007ac3bc6ed8713ca65fddabccfd292a2732f..318661af8bd04880577222fdc82cc1b6e79a457f 100644
--- a/paddle/framework/CMakeLists.txt
+++ b/paddle/framework/CMakeLists.txt
@@ -74,7 +74,10 @@ cc_library(backward SRCS backward.cc DEPS net_op)
cc_test(backward_test SRCS backward_test.cc DEPS backward recurrent_op device_context fill_constant_op)
cc_library(lod_rank_table SRCS lod_rank_table.cc DEPS lod_tensor)
-cc_library(executor SRCS executor.cc DEPS op_registry device_context scope framework_proto backward glog lod_rank_table)
+cc_library(feed_fetch_method SRCS feed_fetch_method.cc DEPS lod_tensor scope glog)
+
+cc_library(executor SRCS executor.cc DEPS op_registry device_context scope
+framework_proto backward glog lod_rank_table profiler feed_fetch_method)
cc_library(prune SRCS prune.cc DEPS framework_proto)
cc_test(prune_test SRCS prune_test.cc DEPS op_info prune recurrent_op device_context)
@@ -95,3 +98,5 @@ if(NOT WITH_C_API AND WITH_FLUID)
install(FILES ${CMAKE_CURRENT_BINARY_DIR}/framework.pb.h DESTINATION include/paddle/framework)
install(FILES details/cow_ptr.h details/op_registry.h DESTINATION include/paddle/framework/details)
endif()
+
+cc_test(channel_test SRCS channel_test.cc)
diff --git a/paddle/framework/channel.h b/paddle/framework/channel.h
new file mode 100644
index 0000000000000000000000000000000000000000..70ecccc1a1078374f3190b3956103ed8000c4fc5
--- /dev/null
+++ b/paddle/framework/channel.h
@@ -0,0 +1,64 @@
+/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#pragma once
+
+#include <stddef.h>  // for size_t
+
+namespace paddle {
+namespace framework {
+
+// Channel is the abstract class of buffered and un-buffered channels.
+template <typename T>
+class Channel {
+ public:
+ virtual void Send(T*) = 0;
+ virtual void Receive(T*) = 0;
+ virtual size_t Cap() = 0;
+
+ // Don't delete channels; instead, call Channel::Close.
+ protected:
+ virtual ~Channel() {}
+};
+
+// Forward declaration of channel implementations.
+namespace details {
+template <typename T>
+class Buffered;
+template <typename T>
+class UnBuffered;
+} // namespace details
+
+template <typename T>
+Channel<T>* MakeChannel(size_t buffer_size) {
+ if (buffer_size > 0) {
+ return new details::Buffered<T>(buffer_size);
+ }
+ return new details::UnBuffered<T>();
+}
+
+template <typename T>
+void CloseChannel(Channel<T>* ch) {
+ if (ch->Cap() > 0) {
+ delete dynamic_cast<details::Buffered<T>*>(ch);
+ } else {
+ delete dynamic_cast<details::UnBuffered<T>*>(ch);
+ }
+}
+
+} // namespace framework
+} // namespace paddle
+
+#include "paddle/framework/details/buffered_channel.h"
+#include "paddle/framework/details/unbuffered_channel.h"
diff --git a/paddle/framework/channel_test.cc b/paddle/framework/channel_test.cc
new file mode 100644
index 0000000000000000000000000000000000000000..9efc0172658c800d14102531332dbef68fa392f4
--- /dev/null
+++ b/paddle/framework/channel_test.cc
@@ -0,0 +1,26 @@
+/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#include "paddle/framework/channel.h"
+
+#include "gtest/gtest.h"
+
+TEST(Channel, MakeAndClose) {
+ using paddle::framework::Channel;
+ using paddle::framework::MakeChannel;
+ using paddle::framework::CloseChannel;
+
+ Channel* ch = MakeChannel(10);
+ CloseChannel(ch);
+}
diff --git a/paddle/framework/data_type.h b/paddle/framework/data_type.h
index 6a372ac32e48131eed28e2d42125feb5b92a11c7..98eb3e857d1943e71f1d41f24ecbedbe09e85b7b 100644
--- a/paddle/framework/data_type.h
+++ b/paddle/framework/data_type.h
@@ -79,5 +79,33 @@ inline void VisitDataType(proto::DataType type, Visitor visitor) {
}
}
+inline std::string DataTypeToString(const proto::DataType type) {
+ using namespace paddle::framework::proto;
+ switch (type) {
+ case DataType::FP16:
+ return "float16";
+ case DataType::FP32:
+ return "float32";
+ case DataType::FP64:
+ return "float64";
+ case DataType::INT16:
+ return "int16";
+ case DataType::INT32:
+ return "int32";
+ case DataType::INT64:
+ return "int64";
+ case DataType::BOOL:
+ return "bool";
+ default:
+ PADDLE_THROW("Not support type %d", type);
+ }
+}
+
+inline std::ostream& operator<<(std::ostream& out,
+ const proto::DataType& type) {
+ out << DataTypeToString(type);
+ return out;
+}
+
} // namespace framework
} // namespace paddle
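For reference, the `DataTypeToString` helper added above can be exercised roughly as follows; a minimal sketch (not part of the patch), assuming it is compiled inside the Paddle source tree:

```cpp
#include <iostream>

#include "paddle/framework/data_type.h"

// Print the readable names of a couple of proto::DataType values using the
// helper introduced in this change.
int main() {
  using paddle::framework::DataTypeToString;
  using paddle::framework::proto::DataType;

  std::cout << DataTypeToString(DataType::FP32) << "\n";   // "float32"
  std::cout << DataTypeToString(DataType::INT64) << "\n";  // "int64"
  return 0;
}
```

The `operator<<` overload serves the same purpose for framework code that streams a `proto::DataType` directly, which is what produces the `data_type[float32]` string checked in `op_kernel_type_test.cc` below.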
diff --git a/paddle/framework/details/buffered_channel.h b/paddle/framework/details/buffered_channel.h
new file mode 100644
index 0000000000000000000000000000000000000000..572e29d44a3baec84a029d87f9b0874784aa761b
--- /dev/null
+++ b/paddle/framework/details/buffered_channel.h
@@ -0,0 +1,82 @@
+/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#pragma once
+#include <condition_variable>
+#include <deque>
+#include <mutex>
+
+#include "paddle/framework/channel.h"
+
+namespace paddle {
+namespace framework {
+namespace details {
+
+template <typename T>
+class Buffered : public paddle::framework::Channel<T> {
+ friend Channel<T>* paddle::framework::MakeChannel<T>(size_t);
+ friend void paddle::framework::CloseChannel<T>(Channel<T>*);
+
+ public:
+ virtual void Send(T*);
+ virtual void Receive(T*);
+ virtual size_t Cap() { return cap_; }
+
+ private:
+ size_t cap_;
+ std::mutex mu_;
+ std::condition_variable empty_cond_var_;
+ std::condition_variable full_cond_var_;
+ std::deque<T> channel_;
+
+ Buffered(size_t cap) : cap_(cap) {}
+ virtual ~Buffered();
+
+ void NotifyAllSenders(std::unique_lock<std::mutex>*);
+};
+
+template <typename T>
+void Buffered<T>::Send(T* item) {
+ std::unique_lock<std::mutex> lock(mu_);
+ full_cond_var_.wait(lock, [this]() { return channel_.size() < cap_; });
+ channel_.push_back(std::move(*item));
+ lock.unlock();
+ empty_cond_var_.notify_one();
+}
+
+template <typename T>
+void Buffered<T>::Receive(T* item) {
+ std::unique_lock<std::mutex> lock(mu_);
+ empty_cond_var_.wait(lock, [this]() { return !channel_.empty(); });
+ *item = std::move(channel_.front());
+ channel_.pop_front();
+ NotifyAllSenders(&lock);
+}
+
+template <typename T>
+Buffered<T>::~Buffered() {
+ std::unique_lock<std::mutex> lock(mu_);
+ channel_.clear();
+ NotifyAllSenders(&lock);
+}
+
+template <typename T>
+void Buffered<T>::NotifyAllSenders(std::unique_lock<std::mutex>* lock) {
+ lock->unlock();
+ full_cond_var_.notify_one();
+}
+
+} // namespace details
+} // namespace framework
+} // namespace paddle
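To illustrate the blocking semantics of `Buffered<T>`, here is a small usage sketch (not part of the patch); the element type `int`, the capacity, and the item counts are arbitrary:

```cpp
#include <thread>

#include "paddle/framework/channel.h"

// One producer thread pushes integers through a buffered channel while the
// main thread drains it.
int main() {
  using paddle::framework::Channel;
  using paddle::framework::CloseChannel;
  using paddle::framework::MakeChannel;

  Channel<int>* ch = MakeChannel<int>(2);  // capacity > 0 selects Buffered<int>

  std::thread producer([ch] {
    for (int i = 0; i < 6; ++i) {
      ch->Send(&i);  // blocks on full_cond_var_ while the deque is full
    }
  });

  for (int i = 0; i < 6; ++i) {
    int item = 0;
    ch->Receive(&item);  // blocks on empty_cond_var_ until an item arrives
  }

  producer.join();
  CloseChannel(ch);  // channels are closed via CloseChannel, never deleted
  return 0;
}
```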
diff --git a/paddle/framework/details/unbuffered_channel.h b/paddle/framework/details/unbuffered_channel.h
new file mode 100644
index 0000000000000000000000000000000000000000..7ecced1fba88fea781fc342091bc71e5aa496d3a
--- /dev/null
+++ b/paddle/framework/details/unbuffered_channel.h
@@ -0,0 +1,52 @@
+/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#pragma once
+#include
+#include
+#include
+
+#include "paddle/framework/channel.h"
+
+namespace paddle {
+namespace framework {
+namespace details {
+
+template <typename T>
+class UnBuffered : public paddle::framework::Channel<T> {
+ friend Channel<T>* paddle::framework::MakeChannel<T>(size_t);
+ friend void paddle::framework::CloseChannel<T>(Channel<T>*);
+
+ public:
+ virtual void Send(T*);
+ virtual void Receive(T*);
+ virtual size_t Cap() { return 0; }
+
+ private:
+ UnBuffered() {}
+ virtual ~UnBuffered();
+};
+
+template <typename T>
+void UnBuffered<T>::Send(T* channel_element) {}
+
+template <typename T>
+void UnBuffered<T>::Receive(T*) {}
+
+template <typename T>
+UnBuffered<T>::~UnBuffered() {}
+
+} // namespace details
+} // namespace framework
+} // namespace paddle
diff --git a/paddle/framework/executor.cc b/paddle/framework/executor.cc
index bd58c0a7f8161f6b45f2b500f3685e4028d97e96..cbf3ec75265fa74aaacffee684b7b7d5f73b7c02 100644
--- a/paddle/framework/executor.cc
+++ b/paddle/framework/executor.cc
@@ -17,11 +17,13 @@ limitations under the License. */
#include
#include "gflags/gflags.h"
+#include "paddle/framework/feed_fetch_method.h"
#include "paddle/framework/feed_fetch_type.h"
#include "paddle/framework/lod_rank_table.h"
#include "paddle/framework/lod_tensor_array.h"
#include "paddle/framework/op_registry.h"
#include "paddle/platform/place.h"
+#include "paddle/platform/profiler.h"
DECLARE_bool(do_memory_benchmark);
DEFINE_bool(check_nan_inf, false,
@@ -117,6 +119,10 @@ void Executor::Run(const ProgramDesc& pdesc, Scope* scope, int block_id,
for (auto& op_desc : block.AllOps()) {
auto op = paddle::framework::OpRegistry::CreateOp(*op_desc);
VLOG(4) << op->DebugStringEx(local_scope);
+
+ platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance();
+ platform::RecordEvent record_event(op->Type(), pool.Get(place_));
+
op->Run(*local_scope, place_);
VLOG(3) << op->DebugStringEx(local_scope);
if (FLAGS_do_memory_benchmark) {
@@ -144,5 +150,164 @@ void Executor::Run(const ProgramDesc& pdesc, Scope* scope, int block_id,
}
}
+// Check whether the block already has feed operators and feed_holder.
+// Return false if the block does not have any feed operators.
+// If some feed operators have been prepended to the block, check that
+// the info contained in these feed operators matches the feed_targets
+// and feed_holder_name. Raise exception when any mismatch is found.
+// Return true if the block has feed operators and holder of matching info.
+static bool has_feed_operators(
+ BlockDesc* block, std::map<std::string, const LoDTensor*>& feed_targets,
+ const std::string& feed_holder_name) {
+ size_t feed_count = 0;
+ for (auto* op : block->AllOps()) {
+ if (op->Type() == kFeedOpType) {
+ feed_count++;
+ PADDLE_ENFORCE_EQ(op->Input("X")[0], feed_holder_name,
+ "Input to feed op should be '%s'", feed_holder_name);
+ std::string feed_target_name = op->Output("Out")[0];
+ PADDLE_ENFORCE(
+ feed_targets.find(feed_target_name) != feed_targets.end(),
+ "Feed operator output name '%s' cannot be found in 'feed_targets'",
+ feed_target_name);
+ }
+ }
+
+ if (feed_count > 0) {
+ PADDLE_ENFORCE_EQ(
+ feed_count, feed_targets.size(),
+ "The number of feed operators should match 'feed_targets'");
+
+ // When feed operator are present, so should be feed_holder
+ auto var = block->FindVar(feed_holder_name);
+ PADDLE_ENFORCE_NOT_NULL(var, "Block should already have a '%s' variable",
+ feed_holder_name);
+ PADDLE_ENFORCE_EQ(var->GetType(), proto::VarDesc::FEED_MINIBATCH,
+ "'%s' variable should be 'FEED_MINIBATCH' type",
+ feed_holder_name);
+ }
+
+ return feed_count > 0;
+}
+
+// Check whether the block already has fetch operators and fetch_holder.
+// Return false if the block does not have any fetch operators.
+// If some fetch operators have been appended to the block, check that
+// the info contained in these fetch operators matches the fetch_targets
+// and fetch_holder_name. Raise exception when any mismatch is found.
+// Return true if the block has fetch operators and holder of matching info.
+static bool has_fetch_operators(
+ BlockDesc* block, std::map<std::string, LoDTensor*>& fetch_targets,
+ const std::string& fetch_holder_name) {
+ size_t fetch_count = 0;
+ for (auto* op : block->AllOps()) {
+ if (op->Type() == kFetchOpType) {
+ fetch_count++;
+ PADDLE_ENFORCE_EQ(op->Output("Out")[0], fetch_holder_name,
+ "Output of fetch op should be '%s'", fetch_holder_name);
+ std::string fetch_target_name = op->Input("X")[0];
+ PADDLE_ENFORCE(
+ fetch_targets.find(fetch_target_name) != fetch_targets.end(),
+ "Fetch operator input name '%s' cannot be found in 'fetch_targets'",
+ fetch_target_name);
+ }
+ }
+
+ if (fetch_count > 0) {
+ PADDLE_ENFORCE_EQ(
+ fetch_count, fetch_targets.size(),
+ "The number of fetch operators should match 'fetch_targets'");
+
+ // When fetch operator are present, so should be fetch_holder
+ auto var = block->FindVar(fetch_holder_name);
+ PADDLE_ENFORCE_NOT_NULL(var, "Block should already have a '%s' variable",
+ fetch_holder_name);
+ PADDLE_ENFORCE_EQ(var->GetType(), proto::VarDesc::FETCH_LIST,
+ "'%s' variable should be 'FETCH_LIST' type",
+ fetch_holder_name);
+ }
+
+ return fetch_count > 0;
+}
+
+void Executor::Run(const ProgramDesc& program, Scope* scope,
+ std::map<std::string, const LoDTensor*>& feed_targets,
+ std::map<std::string, LoDTensor*>& fetch_targets,
+ const std::string& feed_holder_name,
+ const std::string& fetch_holder_name) {
+ auto* copy_program = new ProgramDesc(program);
+ auto* global_block = copy_program->MutableBlock(0);
+
+ if (!has_feed_operators(global_block, feed_targets, feed_holder_name)) {
+ // create feed_holder variable
+ auto* feed_holder = global_block->Var(feed_holder_name);
+ feed_holder->SetType(proto::VarDesc::FEED_MINIBATCH);
+ feed_holder->SetPersistable(true);
+
+ int i = 0;
+ for (auto& feed_target : feed_targets) {
+ std::string var_name = feed_target.first;
+ VLOG(3) << "feed target's name: " << var_name;
+
+ // prepend feed op
+ auto* op = global_block->PrependOp();
+ op->SetType(kFeedOpType);
+ op->SetInput("X", {feed_holder_name});
+ op->SetOutput("Out", {var_name});
+ op->SetAttr("col", {static_cast(i)});
+ op->CheckAttrs();
+
+ i++;
+ }
+ }
+
+ // map the data of feed_targets to feed_holder
+ for (auto* op : global_block->AllOps()) {
+ if (op->Type() == kFeedOpType) {
+ std::string feed_target_name = op->Output("Out")[0];
+ int idx = boost::get<int>(op->GetAttr("col"));
+ SetFeedVariable(scope, *feed_targets[feed_target_name], feed_holder_name,
+ idx);
+ }
+ }
+
+ if (!has_fetch_operators(global_block, fetch_targets, fetch_holder_name)) {
+ // create fetch_holder variable
+ auto* fetch_holder = global_block->Var(fetch_holder_name);
+ fetch_holder->SetType(proto::VarDesc::FETCH_LIST);
+ fetch_holder->SetPersistable(true);
+
+ int i = 0;
+ for (auto& fetch_target : fetch_targets) {
+ std::string var_name = fetch_target.first;
+ VLOG(3) << "fetch target's name: " << var_name;
+
+ // append fetch op
+ auto* op = global_block->AppendOp();
+ op->SetType(kFetchOpType);
+ op->SetInput("X", {var_name});
+ op->SetOutput("Out", {fetch_holder_name});
+ op->SetAttr("col", {static_cast(i)});
+ op->CheckAttrs();
+
+ i++;
+ }
+ }
+
+ Run(*copy_program, scope, 0, true, true);
+
+ // obtain the data of fetch_targets from fetch_holder
+ for (auto* op : global_block->AllOps()) {
+ if (op->Type() == kFetchOpType) {
+ std::string fetch_target_name = op->Input("X")[0];
+ int idx = boost::get<int>(op->GetAttr("col"));
+ *fetch_targets[fetch_target_name] =
+ GetFetchVariable(*scope, fetch_holder_name, idx);
+ }
+ }
+
+ delete copy_program;
+}
+
} // namespace framework
} // namespace paddle
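A hypothetical caller of the new feed/fetch-aware `Run` overload might look like this; a sketch only, where the variable names "x" and "y" and the helper function are made up for illustration:

```cpp
#include <map>
#include <string>

#include "paddle/framework/executor.h"
#include "paddle/framework/lod_tensor.h"
#include "paddle/framework/program_desc.h"
#include "paddle/framework/scope.h"
#include "paddle/platform/place.h"

// Run `program` on CPU, feeding one tensor into "x" and fetching "y".
void RunWithFeedFetch(const paddle::framework::ProgramDesc& program,
                      const paddle::framework::LoDTensor& x_input) {
  paddle::framework::Scope scope;
  paddle::framework::Executor executor(paddle::platform::CPUPlace());

  paddle::framework::LoDTensor y_output;
  std::map<std::string, const paddle::framework::LoDTensor*> feed_targets;
  std::map<std::string, paddle::framework::LoDTensor*> fetch_targets;
  feed_targets["x"] = &x_input;
  fetch_targets["y"] = &y_output;

  // Missing feed/fetch ops are prepended/appended to a copy of the program,
  // and the default holder names "feed"/"fetch" are used.
  executor.Run(program, &scope, feed_targets, fetch_targets);
}
```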
diff --git a/paddle/framework/executor.h b/paddle/framework/executor.h
index d869e18901b82959a40cc296aa0844c20ea63ac1..035ff48a52bd2fc4b1a46b48b1fbf1fbcb2ac70b 100644
--- a/paddle/framework/executor.h
+++ b/paddle/framework/executor.h
@@ -41,6 +41,12 @@ class Executor {
void Run(const ProgramDesc&, Scope*, int, bool create_local_scope = true,
bool create_vars = true);
+ void Run(const ProgramDesc& program, Scope* scope,
+ std::map<std::string, const LoDTensor*>& feed_targets,
+ std::map<std::string, LoDTensor*>& fetch_targets,
+ const std::string& feed_holder_name = "feed",
+ const std::string& fetch_holder_name = "fetch");
+
private:
const platform::Place place_;
};
diff --git a/paddle/framework/feed_fetch_method.cc b/paddle/framework/feed_fetch_method.cc
new file mode 100644
index 0000000000000000000000000000000000000000..21201b675519e34b11e9f1f3a6f2a135c06d63a7
--- /dev/null
+++ b/paddle/framework/feed_fetch_method.cc
@@ -0,0 +1,56 @@
+/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#include "paddle/framework/feed_fetch_method.h"
+#include "glog/logging.h"
+#include "paddle/framework/variable.h"
+
+namespace paddle {
+namespace framework {
+
+void SetFeedVariable(Scope* scope, const LoDTensor& input,
+ const std::string& var_name, size_t index) {
+ // If var_name Variable is not found in GlobalScope, a new variable will
+ // be created.
+ VLOG(3) << "SetFeedVariable name=" << var_name << " index=" << index;
+ Variable* g_feed_value = scope->Var(var_name);
+ auto& feed_inputs =
+ *(g_feed_value->GetMutable<std::vector<paddle::framework::LoDTensor>>());
+ if (index >= feed_inputs.size()) {
+ feed_inputs.resize(index + 1);
+ }
+ // shared data with input tensor
+ feed_inputs[index].ShareDataWith(input);
+ // set lod
+ feed_inputs[index].set_lod(input.lod());
+}
+
+LoDTensor& GetFetchVariable(const Scope& scope, const std::string& var_name,
+ size_t index) {
+ // Since we want to fetch LodTensor from a variable, the variable must
+ // be created already.
+ Variable* g_fetch_value = scope.FindVar(var_name);
+ PADDLE_ENFORCE(g_fetch_value->IsType<FeedFetchList>(),
+ "Only %s can be invoked by GetFetchVariable",
+ typeid(FeedFetchList).name());
+ auto& fetch_outputs = *g_fetch_value->GetMutable<FeedFetchList>();
+ auto& tensor = fetch_outputs[index];
+ VLOG(3) << "Fetch " << var_name << " with index " << index
+ << " shape= " << tensor.dims();
+ PADDLE_ENFORCE_LT(index, fetch_outputs.size());
+ return tensor;
+}
+
+} // namespace framework
+} // namespace paddle
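The two helpers can also be used directly; a sketch (not part of the patch), reusing the default holder names "feed" and "fetch" and column 0:

```cpp
#include "paddle/framework/feed_fetch_method.h"
#include "paddle/framework/lod_tensor.h"
#include "paddle/framework/scope.h"

// Column 0 of the "feed" variable ends up sharing its data buffer (and LoD)
// with `input`, exactly as the feed op expects.
void FeedInput(paddle::framework::Scope* scope,
               const paddle::framework::LoDTensor& input) {
  paddle::framework::SetFeedVariable(scope, input, "feed", 0);
}

// Only meaningful after a program containing a fetch op has filled the
// "fetch" FeedFetchList in this scope.
paddle::framework::LoDTensor FetchOutput(
    const paddle::framework::Scope& scope) {
  return paddle::framework::GetFetchVariable(scope, "fetch", 0);
}
```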
diff --git a/paddle/framework/feed_fetch_method.h b/paddle/framework/feed_fetch_method.h
index 7feacb1e24708411e7fbb610f9909447cba9e291..b71945fcc8834d2e5fe21151e1e88788b4acd5c1 100644
--- a/paddle/framework/feed_fetch_method.h
+++ b/paddle/framework/feed_fetch_method.h
@@ -13,46 +13,18 @@ See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
-#include "glog/logging.h"
+
#include "paddle/framework/feed_fetch_type.h"
#include "paddle/framework/scope.h"
-#include "paddle/framework/variable.h"
namespace paddle {
namespace framework {
void SetFeedVariable(Scope* scope, const LoDTensor& input,
- const std::string& var_name, size_t index) {
- // If var_name Variable is not found in GlobalScope, a new variable will
- // be created.
- VLOG(3) << "SetFeedVariable name=" << var_name << " index=" << index;
- Variable* g_feed_value = scope->Var(var_name);
- auto& feed_inputs =
- *(g_feed_value->GetMutable>());
- if (index >= feed_inputs.size()) {
- feed_inputs.resize(index + 1);
- }
- // shared data with input tensor
- feed_inputs[index].ShareDataWith(input);
- // set lod
- feed_inputs[index].set_lod(input.lod());
-}
+ const std::string& var_name, size_t index);
LoDTensor& GetFetchVariable(const Scope& scope, const std::string& var_name,
- size_t index) {
- // Since we want to fetch LodTensor from a variable, the variable must
- // be created alreadly.
- Variable* g_fetch_value = scope.FindVar(var_name);
- PADDLE_ENFORCE(g_fetch_value->IsType(),
- "Only %s can be invoked by GetFetchVariable",
- typeid(FeedFetchList).name());
- auto& fetch_outputs = *g_fetch_value->GetMutable();
- auto& tensor = fetch_outputs[index];
- VLOG(3) << "Fetch " << var_name << " with index " << index
- << " shape= " << tensor.dims();
- PADDLE_ENFORCE_LT(index, fetch_outputs.size());
- return tensor;
-}
+ size_t index);
} // namespace framework
} // namespace paddle
diff --git a/paddle/framework/op_kernel_type_test.cc b/paddle/framework/op_kernel_type_test.cc
index 649afeee8a846b0579545f2edff77e9dbe3b4dd8..cb23bbde01493d1a3b5845e77d6160a75f409c7a 100644
--- a/paddle/framework/op_kernel_type_test.cc
+++ b/paddle/framework/op_kernel_type_test.cc
@@ -26,9 +26,9 @@ TEST(OpKernelType, ToString) {
OpKernelType op_kernel_type(DataType::FP32, CPUPlace(), DataLayout::kNCHW,
LibraryType::kCUDNN);
- ASSERT_EQ(
- paddle::framework::KernelTypeToString(op_kernel_type),
- "data_type[5]:data_layout[NCHW]:place[CPUPlace]:library_type[CUDNN]");
+ ASSERT_EQ(paddle::framework::KernelTypeToString(op_kernel_type),
+ "data_type[float32]:data_layout[NCHW]:place[CPUPlace]:library_type["
+ "CUDNN]");
}
TEST(OpKernelType, Hash) {
diff --git a/paddle/framework/threadpool.cc b/paddle/framework/threadpool.cc
index 109a7e7dc440d91e8223f2c0924f489f54a06f64..b2f5ae4a96593fde1623dd10d3b63c984ae228db 100644
--- a/paddle/framework/threadpool.cc
+++ b/paddle/framework/threadpool.cc
@@ -1,24 +1,93 @@
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
- http://www.apache.org/licenses/LICENSE-2.0
+ http://www.apache.org/licenses/LICENSE-2.0
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License. */
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License. */
#include "paddle/framework/threadpool.h"
namespace paddle {
namespace framework {
-std::unique_ptr<ThreadPool> ThreadPool::threadpool(nullptr);
-std::once_flag ThreadPool::init_flag;
+std::unique_ptr<ThreadPool> ThreadPool::threadpool_(nullptr);
+std::once_flag ThreadPool::init_flag_;
+
+ThreadPool* ThreadPool::GetInstance() {
+ std::call_once(init_flag_, &ThreadPool::Init);
+ return threadpool_.get();
+}
+
+void ThreadPool::Init() {
+ if (threadpool_.get() == nullptr) {
+ // TODO(Yancey1989): specify the max threads number
+ int num_threads = std::thread::hardware_concurrency();
+ PADDLE_ENFORCE_GT(num_threads, 0);
+ threadpool_.reset(new ThreadPool(num_threads));
+ }
+}
+
+ThreadPool::ThreadPool(int num_threads)
+ : total_threads_(num_threads), idle_threads_(num_threads), running_(true) {
+ threads_.resize(num_threads);
+ for (auto& thread : threads_) {
+ // TODO(Yancey1989): binding the thread on the specify CPU number
+ thread.reset(new std::thread(std::bind(&ThreadPool::TaskLoop, this)));
+ }
+}
+
+ThreadPool::~ThreadPool() {
+ {
+ // notify all threads to stop running
+ running_ = false;
+ scheduled_.notify_all();
+ }
+
+ for (auto& t : threads_) {
+ t->join();
+ t.reset(nullptr);
+ }
+}
+
+void ThreadPool::Wait() {
+ std::unique_lock<std::mutex> lock(mutex_);
+ completed_.wait(lock, [=] { return Done() == true; });
+}
+
+void ThreadPool::TaskLoop() {
+ while (running_) {
+ std::unique_lock<std::mutex> lock(mutex_);
+ scheduled_.wait(lock, [=] { return !tasks_.empty() || !running_; });
+
+ if (!running_) {
+ break;
+ }
+ // pop a task from the task queue
+ auto task = std::move(tasks_.front());
+ tasks_.pop();
+
+ --idle_threads_;
+ lock.unlock();
+
+ // run the task
+ task();
+
+ {
+ std::unique_lock<std::mutex> lock(mutex_);
+ ++idle_threads_;
+ if (Done()) {
+ completed_.notify_all();
+ }
+ }
+ }
+}
} // namespace framework
} // namespace paddle
diff --git a/paddle/framework/threadpool.h b/paddle/framework/threadpool.h
index 3ac345851c38557f82698786dd3bc8e1202a4256..8912b1a43a26f9df662d3b5ddf68bfb2b87f4a20 100644
--- a/paddle/framework/threadpool.h
+++ b/paddle/framework/threadpool.h
@@ -20,52 +20,36 @@ limitations under the License. */
#include
#include
#include
+#include
#include "paddle/platform/enforce.h"
namespace paddle {
namespace framework {
+// ThreadPool maintains a queue of tasks, and runs them using a fixed
+// number of threads.
class ThreadPool {
public:
typedef std::packaged_task<void()> Task;
- /**
- * @brief Get a instance of threadpool, the thread number will
- * be specified as the number of hardware thread contexts
- */
- static ThreadPool* GetInstance() {
- std::call_once(init_flag, &ThreadPool::Init);
- return threadpool.get();
- }
+ // Returns the singleton of ThreadPool.
+ static ThreadPool* GetInstance();
- ~ThreadPool() {
- {
- // notify all threads to stop running
- running_ = false;
- scheduled_.notify_all();
- }
-
- for (auto& t : threads_) {
- t->join();
- t.reset(nullptr);
- }
- }
+ ~ThreadPool();
- int GetNumThreads() const { return num_threads_; }
+ // Returns the number of threads created by the constructor.
+ size_t Threads() const { return total_threads_; }
- int GetAvailable() {
+ // Returns the number of currently idle threads.
+ size_t IdleThreads() {
std::unique_lock<std::mutex> lock(mutex_);
- return available_;
+ return idle_threads_;
}
- /**
- * @brief Push a function to the queue, and will be scheduled and
- * executed if a thread is available.
- * @param[in] Task, will be pushed to the task queue.
- * @return std::future, we could wait for the task finished by
- * f.wait().
- */
+ // Run pushes a function to the task queue and returns a std::future
+ // object. To wait for the completion of the task, call
+ // std::future::wait().
template <typename Callback>
std::future<void> Run(Callback fn) {
std::unique_lock<std::mutex> lock(mutex_);
@@ -77,84 +61,40 @@ class ThreadPool {
return f;
}
- /**
- * @brief Wait until all the tasks are completed.
- */
- void Wait() {
- std::unique_lock lock(mutex_);
- completed_.wait(lock, [=] { return Done() == true; });
- }
+ // Wait until all the tasks are completed.
+ void Wait();
private:
DISABLE_COPY_AND_ASSIGN(ThreadPool);
- explicit ThreadPool(int num_threads)
- : num_threads_(num_threads), available_(num_threads), running_(true) {
- threads_.resize(num_threads);
- for (auto& thread : threads_) {
- // TODO(Yancey1989): binding the thread on the specify CPU number
- thread.reset(new std::thread(std::bind(&ThreadPool::TaskLoop, this)));
- }
- }
+ explicit ThreadPool(int num_threads);
- /**
- * @brief If the task queue is empty and avaialbe
- * is equal to the number of threads, means that
- * all tasks are completed.
- *
- * Note: this function is not thread-safe.
- *
- * @return true if all tasks are completed.
- */
- bool Done() { return tasks_.empty() && available_ == num_threads_; }
-
- void TaskLoop() {
- while (running_) {
- std::unique_lock lock(mutex_);
- scheduled_.wait(lock, [=] { return !tasks_.empty() || !running_; });
-
- if (!running_) {
- break;
- }
- // pop a task from the task queue
- auto task = std::move(tasks_.front());
- tasks_.pop();
-
- --available_;
- lock.unlock();
-
- // run the task
- task();
-
- {
- std::unique_lock lock(mutex_);
- ++available_;
- if (Done()) {
- completed_.notify_all();
- }
- }
- }
- }
+ // If the task queue is empty and avaialbe is equal to the number of
+ // threads, means that all tasks are completed. Note: this function
+ // is not thread-safe. Returns true if all tasks are completed.
+ // Note: don't delete the data member total_threads_ and use
+ // threads_.size() instead; because you'd need to lock the mutex
+ // before accessing threads_.
+ bool Done() { return tasks_.empty() && idle_threads_ == total_threads_; }
- static void Init() {
- if (threadpool.get() == nullptr) {
- // TODO(Yancey1989): specify the max threads number
- int num_threads = std::thread::hardware_concurrency();
- PADDLE_ENFORCE_GT(num_threads, 0);
- threadpool.reset(new ThreadPool(num_threads));
- }
- }
+ // The constructor starts threads to run TaskLoop, which retrieves
+ // and runs tasks from the queue.
+ void TaskLoop();
+
+ // Init is called by GetInstance.
+ static void Init();
private:
- static std::unique_ptr<ThreadPool> threadpool;
- static std::once_flag init_flag;
+ static std::unique_ptr<ThreadPool> threadpool_;
+ static std::once_flag init_flag_;
- int num_threads_;
- int available_;
- bool running_;
- std::queue<Task> tasks_;
std::vector<std::unique_ptr<std::thread>> threads_;
+ const size_t total_threads_;
+ size_t idle_threads_;
+
+ std::queue<Task> tasks_;
std::mutex mutex_;
+ bool running_;
std::condition_variable scheduled_;
std::condition_variable completed_;
};
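The renamed interface is used the same way as before; a minimal sketch (not part of the patch) of scheduling work on the singleton pool:

```cpp
#include <atomic>

#include "paddle/framework/threadpool.h"

// Enqueue a few tasks and block until the queue drains and all workers are
// idle again.
void RunSomeTasks() {
  std::atomic<int> counter(0);

  paddle::framework::ThreadPool* pool =
      paddle::framework::ThreadPool::GetInstance();
  for (int i = 0; i < 4; ++i) {
    // Run() returns a std::future<void> that can be kept to wait on a single
    // task; here we only rely on Wait() below.
    pool->Run([&counter] { counter.fetch_add(1); });
  }

  pool->Wait();
}
```

The `framework::Async` calls that appear in `threadpool_test.cc` and `grpc_client.cc` below presumably route through this same singleton.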
diff --git a/paddle/framework/threadpool_test.cc b/paddle/framework/threadpool_test.cc
index 50b6238cd8786be9d8cf2d5f821daadea12bd208..3fbfe7efc867144dbd0dd2613c824c6a3c41b7d8 100644
--- a/paddle/framework/threadpool_test.cc
+++ b/paddle/framework/threadpool_test.cc
@@ -22,11 +22,7 @@ namespace framework = paddle::framework;
void do_sum(framework::ThreadPool* pool, std::atomic<int>& sum, int cnt) {
std::vector<std::future<void>> fs;
for (int i = 0; i < cnt; ++i) {
- auto f = pool->Run([&sum]() { sum.fetch_add(1); });
- fs.push_back(std::move(f));
- }
- for (auto& f : fs) {
- f.wait();
+ fs.push_back(framework::Async([&sum]() { sum.fetch_add(1); }));
}
}
diff --git a/paddle/gserver/tests/test_LayerGrad.cpp b/paddle/gserver/tests/test_LayerGrad.cpp
index ba83667ebc9a89c37f77a7f71e6df90b54723cc0..aab02f16849582db4b41087046b810463a855e1a 100644
--- a/paddle/gserver/tests/test_LayerGrad.cpp
+++ b/paddle/gserver/tests/test_LayerGrad.cpp
@@ -991,8 +991,10 @@ TEST(Layer, SequenceLastInstanceLayer) {
"seqlastins",
"non-seq",
-1); // hasSubseq seqlastins to non-seq
- testDegradeLayer(
- true, "seqlastins", "seq", -1); // hasSubseq seqlastins to seq
+ testDegradeLayer(true,
+ "seqlastins",
+ "seq",
+ -1); // hasSubseq seqlastins to seq
}
TEST(Layer, AverageLayer) {
@@ -1001,8 +1003,10 @@ TEST(Layer, AverageLayer) {
"average",
"non-seq",
5); // seq average to a shorten seq, stride window = 5
- testDegradeLayer(
- true, "average", "non-seq", -1); // hasSubseq average to non-seq
+ testDegradeLayer(true,
+ "average",
+ "non-seq",
+ -1); // hasSubseq average to non-seq
testDegradeLayer(true, "average", "seq", -1); // hasSubseq average to seq
}
@@ -1287,8 +1291,9 @@ TEST(Layer, PoolLayer) {
testPoolLayer("cudnn-avg-pool", /* trans= */ false, /* useGpu= */ true);
testPoolLayer2("cudnn-max-pool", /* trans= */ false, /* useGpu= */ true);
testPoolLayer2("cudnn-avg-pool", /* trans= */ false, /* useGpu= */ true);
- testPoolLayer2(
- "cudnn-avg-incl-pad-pool", /* trans= */ false, /* useGpu= */ true);
+ testPoolLayer2("cudnn-avg-incl-pad-pool",
+ /* trans= */ false,
+ /* useGpu= */ true);
testPoolLayer("max-pool-with-mask", /* trans= */ false, /* useGpu= */ true);
#endif
}
@@ -2431,18 +2436,21 @@ TEST(Layer, test3DDeConvLayer) {
}
TEST(Layer, ScaleShiftLayer) {
- const size_t batchSize = 16;
- const size_t size = 32;
- TestConfig config;
- config.layerConfig.set_type("scale_shift");
- config.layerConfig.set_size(size);
- config.biasSize = 1;
- config.inputDefs.push_back(
- {INPUT_DATA, "input", /* dim= */ size, /* paraSize= */ 1});
- config.layerConfig.add_inputs();
- for (auto useGpu : {false, true}) {
- testLayerGrad(config, "scale_shift", batchSize, false, useGpu, false);
- }
+ // FIXME: Disable ScaleShiftLayer because it is not stable.
+ // https://github.com/PaddlePaddle/Paddle/issues/7781
+ return;
+ // const size_t batchSize = 16;
+ // const size_t size = 32;
+ // TestConfig config;
+ // config.layerConfig.set_type("scale_shift");
+ // config.layerConfig.set_size(size);
+ // config.biasSize = 1;
+ // config.inputDefs.push_back(
+ // {INPUT_DATA, "input", /* dim= */ size, /* paraSize= */ 1});
+ // config.layerConfig.add_inputs();
+ // for (auto useGpu : {false, true}) {
+ // testLayerGrad(config, "scale_shift", batchSize, false, useGpu, false);
+ // }
}
TEST(Layer, ScaleSubRegionLayer) {
diff --git a/paddle/inference/inference.cc b/paddle/inference/inference.cc
index 49001778808173b82865a4b6632a6b175ef96242..b43c359ed1787143403336e8c1cb4c7f85b1d7a2 100644
--- a/paddle/inference/inference.cc
+++ b/paddle/inference/inference.cc
@@ -15,18 +15,13 @@ limitations under the License. */
#include "inference.h"
#include
#include "paddle/framework/executor.h"
-#include "paddle/framework/feed_fetch_method.h"
#include "paddle/framework/init.h"
#include "paddle/framework/scope.h"
-#ifdef PADDLE_USE_PTOOLS
-#include "chooseser.h"
-#endif
-
namespace paddle {
void InferenceEngine::LoadInferenceModel(const std::string& dirname) {
- std::string model_filename = dirname + "/__model__.dat";
+ std::string model_filename = dirname + "/__model__";
LOG(INFO) << "loading model from " << model_filename;
std::ifstream inputfs(model_filename, std::ios::in | std::ios::binary);
std::string program_desc_str;
@@ -52,39 +47,15 @@ void InferenceEngine::LoadInferenceModel(const std::string& dirname) {
}
}
-void InferenceEngine::LoadInferenceModel(
- const std::string& dirname,
- const std::vector& feed_var_names,
- const std::vector& fetch_var_names) {
- std::string model_filename = dirname + "/__model__.dat";
- LOG(INFO) << "loading model from " << model_filename;
- std::ifstream inputfs(model_filename, std::ios::in | std::ios::binary);
- std::string program_desc_str;
- inputfs.seekg(0, std::ios::end);
- program_desc_str.resize(inputfs.tellg());
- inputfs.seekg(0, std::ios::beg);
- LOG(INFO) << "program_desc_str's size: " << program_desc_str.size();
- inputfs.read(&program_desc_str[0], program_desc_str.size());
- inputfs.close();
-
- program_ = new framework::ProgramDesc(program_desc_str);
- GenerateLoadProgram(dirname);
-
- if (feed_var_names.empty() || fetch_var_names.empty()) {
- LOG(FATAL) << "Please specify the feed_var_names and fetch_var_names.";
- }
- feed_var_names_ = feed_var_names;
- fetch_var_names_ = fetch_var_names;
- PrependFeedOp();
- AppendFetchOp();
-}
-
bool InferenceEngine::IsParameter(const framework::VarDesc* var) {
- if (var->Persistable() && var->Name() != "feed" && var->Name() != "fetch") {
+ if (var->Persistable()) {
// There are many unreachable variables in the program
for (size_t i = 0; i < program_->Size(); ++i) {
const framework::BlockDesc& block = program_->Block(i);
for (auto* op : block.AllOps()) {
+ if (op->Type() == "feed") {
+ continue;
+ }
for (auto input_argument_name : op->InputArgumentNames()) {
if (input_argument_name == var->Name()) {
return true;
@@ -182,7 +153,7 @@ void InferenceEngine::Execute(const std::vector& feeds,
LOG(FATAL) << "Please initialize the program_ and load_program_ first.";
}
- if (feeds.size() < feed_var_names_.size()) {
+ if (feeds.size() != feed_var_names_.size()) {
LOG(FATAL) << "Please feed " << feed_var_names_.size() << " input Tensors.";
}
@@ -193,19 +164,22 @@ void InferenceEngine::Execute(const std::vector& feeds,
executor->Run(*load_program_, scope, 0, true, true);
+ std::map<std::string, const framework::LoDTensor*> feed_targets;
+ std::map<std::string, framework::LoDTensor*> fetch_targets;
+
// set_feed_variable
for (size_t i = 0; i < feed_var_names_.size(); ++i) {
- framework::SetFeedVariable(scope, feeds[i], "feed", i);
+ feed_targets[feed_var_names_[i]] = &feeds[i];
}
- executor->Run(*program_, scope, 0, true, true);
-
// get_fetch_variable
fetchs.resize(fetch_var_names_.size());
for (size_t i = 0; i < fetch_var_names_.size(); ++i) {
- fetchs[i] = framework::GetFetchVariable(*scope, "fetch", i);
+ fetch_targets[fetch_var_names_[i]] = &fetchs[i];
}
+ executor->Run(*program_, scope, feed_targets, fetch_targets);
+
delete place;
delete scope;
delete executor;
diff --git a/paddle/inference/inference.h b/paddle/inference/inference.h
index 7fc09cb9e539a65a8cd3cceb1543bc7d111c22b3..26f259824b945e260b370ced9d065842264075d5 100644
--- a/paddle/inference/inference.h
+++ b/paddle/inference/inference.h
@@ -29,9 +29,6 @@ public:
}
void LoadInferenceModel(const std::string& dirname);
- void LoadInferenceModel(const std::string& dirname,
- const std::vector& feed_var_names,
- const std::vector& fetch_var_names);
void Execute(const std::vector<framework::LoDTensor>& feeds,
std::vector<framework::LoDTensor>& fetchs);
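A hypothetical driver for the trimmed-down `InferenceEngine` interface; a sketch only, where `model_dir` and the single input tensor are placeholders:

```cpp
#include <string>
#include <vector>

#include "paddle/framework/lod_tensor.h"
#include "paddle/inference/inference.h"

// Load a saved model and run it on one input tensor.
void RunInference(const std::string& model_dir,
                  const paddle::framework::LoDTensor& input) {
  paddle::InferenceEngine engine;
  engine.LoadInferenceModel(model_dir);  // reads <model_dir>/__model__

  std::vector<paddle::framework::LoDTensor> feeds;
  feeds.push_back(input);
  std::vector<paddle::framework::LoDTensor> fetchs;

  // After this change, feeds.size() must exactly match the number of feed
  // variables in the model.
  engine.Execute(feeds, fetchs);
}
```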
diff --git a/paddle/operators/CMakeLists.txt b/paddle/operators/CMakeLists.txt
index 15f7cb6b560590f55e276fde4900d2e3c0045fb8..48cf5816cce4bb5ee8e66e72c5b1acea8535ab10 100644
--- a/paddle/operators/CMakeLists.txt
+++ b/paddle/operators/CMakeLists.txt
@@ -147,6 +147,7 @@ op_library(max_sequence_len_op DEPS lod_rank_table)
op_library(sequence_conv_op DEPS context_project)
op_library(sequence_pool_op DEPS sequence_pooling)
op_library(lstm_op DEPS sequence2batch lstm_compute)
+op_library(lstmp_op DEPS sequence2batch lstm_compute)
op_library(gru_op DEPS sequence2batch gru_compute)
op_library(recurrent_op DEPS executor)
op_library(warpctc_op DEPS dynload_warpctc sequence_padding sequence_scale math_function)
diff --git a/paddle/operators/activation_op.h b/paddle/operators/activation_op.h
index 88c3d1c597a853abdee7753a5110be4a1726e905..c0809abc05104c1e8c1f42331c0530724dd1472f 100644
--- a/paddle/operators/activation_op.h
+++ b/paddle/operators/activation_op.h
@@ -323,7 +323,7 @@ template
struct FloorFunctor : public BaseActivationFunctor<T> {
template <typename Device, typename X, typename Out>
void operator()(Device d, X x, Out out) const {
- out.device(d) = x.ceil();
+ out.device(d) = x.floor();
}
};
diff --git a/paddle/operators/detail/grpc_client.cc b/paddle/operators/detail/grpc_client.cc
index 1e41587c418fb0ce4e452d5c6735c54e2d42f798..9b5f7afc6a48f13ff999f635efeb9e7bf0a76fb5 100644
--- a/paddle/operators/detail/grpc_client.cc
+++ b/paddle/operators/detail/grpc_client.cc
@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License. */
#include "grpc_client.h"
+#include "paddle/framework/threadpool.h"
namespace paddle {
namespace operators {
namespace detail {
@@ -22,25 +23,32 @@ bool RPCClient::AsyncSendVariable(const std::string& ep,
const framework::Scope& scope,
const std::string& var_name,
int64_t time_out) {
- sendrecv::VariableMessage req;
- auto* var = scope.FindVar(var_name);
- SerializeToMessage(var_name, var, ctx, &req);
-
- // varhandle
- VarHandle var_h;
- var_h.ep = ep;
- var_h.scope = &scope;
- var_h.name = var_name;
- var_h.ctx = &ctx;
-
- // stub context
- auto ch = GetChannel(ep);
- SendProcessor* s = new SendProcessor(ch);
- s->Prepare(var_h, time_out);
- s->response_call_back_ = NULL;
-
- auto rpc = s->stub_->AsyncSendVariable(s->context_.get(), req, &cq_);
- rpc->Finish(&s->reply_, &s->status_, (void*)s);
+ const platform::DeviceContext* p_ctx = &ctx;
+ const std::string ep_val = ep;
+ const std::string var_name_val = var_name;
+ const framework::Scope* p_scope = &scope;
+ const auto ch = GetChannel(ep_val);
+
+ framework::Async([var_name_val, p_ctx, ep_val, p_scope, time_out, ch, this] {
+ auto* var = p_scope->FindVar(var_name_val);
+ sendrecv::VariableMessage req;
+ SerializeToMessage(var_name_val, var, *p_ctx, &req);
+
+ // varhandle
+ VarHandle var_h;
+ var_h.ep = ep_val;
+ var_h.scope = p_scope;
+ var_h.name = var_name_val;
+ var_h.ctx = p_ctx;
+
+ // stub context
+ SendProcessor* s = new SendProcessor(ch);
+ s->Prepare(var_h, time_out);
+ s->response_call_back_ = NULL;
+
+ auto rpc = s->stub_->AsyncSendVariable(s->context_.get(), req, &cq_);
+ rpc->Finish(&s->reply_, &s->status_, (void*)s);
+ });
req_count_++;
@@ -50,8 +58,6 @@ bool RPCClient::AsyncSendVariable(const std::string& ep,
void ProcGetResponse(const VarHandle& var_h,
const sendrecv::VariableMessage& ret_msg) {
auto* outvar = var_h.scope->FindVar(var_h.name);
-
- std::istringstream iss(ret_msg.serialized());
DeserializeFromMessage(ret_msg, *var_h.ctx, outvar);
}
@@ -60,44 +66,78 @@ bool RPCClient::AsyncGetVariable(const std::string& ep,
const framework::Scope& scope,
const std::string& var_name,
int64_t time_out) {
+ const platform::DeviceContext* p_ctx = &ctx;
+ const std::string ep_val = ep;
+ const std::string var_name_val = var_name;
+ const framework::Scope* p_scope = &scope;
+ const auto ch = GetChannel(ep_val);
+
+ framework::Async([var_name_val, ep_val, p_scope, p_ctx, time_out, ch, this] {
+ sendrecv::VariableMessage req;
+ req.set_varname(var_name_val);
+
+ // varhandle
+ VarHandle var_h;
+ var_h.ep = ep_val;
+ var_h.scope = p_scope;
+ var_h.name = var_name_val;
+ var_h.ctx = p_ctx;
+
+ // stub context
+ GetProcessor* s = new GetProcessor(ch);
+ s->Prepare(var_h, time_out);
+ s->response_call_back_ = ProcGetResponse;
+
+ auto rpc = s->stub_->AsyncGetVariable(s->context_.get(), req, &cq_);
+ rpc->Finish(&s->reply_, &s->status_, (void*)s);
+ });
+
+ req_count_++;
+
+ return true;
+}
+
+bool RPCClient::AsyncSendBatchBarrier(const std::string& ep, int64_t time_out) {
+ const auto ch = GetChannel(ep);
+
+ BatchBarrierProcessor* s = new BatchBarrierProcessor(ch);
+ s->Prepare(time_out);
+
sendrecv::VariableMessage req;
- req.set_varname(var_name);
-
- // varhandle
- VarHandle var_h;
- var_h.ep = ep;
- var_h.scope = &scope;
- var_h.name = var_name;
- var_h.ctx = &ctx;
-
- // stub context
- auto ch = GetChannel(ep);
- GetProcessor* s = new GetProcessor(ch);
- s->Prepare(var_h, time_out);
- s->response_call_back_ = ProcGetResponse;
-
- auto rpc = s->stub_->AsyncGetVariable(s->context_.get(), req, &cq_);
+ req.set_varname(BATCH_BARRIER_MESSAGE);
+ auto rpc = s->stub_->AsyncSendVariable(s->context_.get(), req, &cq_);
rpc->Finish(&s->reply_, &s->status_, (void*)s);
-
req_count_++;
return true;
}
bool RPCClient::Wait() {
- bool ok = true;
+ if (req_count_ <= 0) {
+ return true;
+ }
+ const size_t kReqCnt = req_count_;
+ bool a[kReqCnt];
+ std::vector<std::future<void>> waits(req_count_);
- while (true) {
- if (req_count_ <= 0) {
- break;
- }
+ for (int i = 0; i < req_count_; i++) {
+ waits[i] = framework::Async([i, &a, this] { a[i] = Proceed(); });
+ }
+
+ for (int i = 0; i < req_count_; i++) {
+ waits[i].wait();
+ }
- if (!Proceed()) {
+ int last_req_count = req_count_;
+ req_count_ = 0;
+
+ for (int i = 0; i < last_req_count; i++) {
+ if (!a[i]) {
return false;
}
}
- return ok;
+ return true;
}
bool RPCClient::Proceed() {
@@ -124,7 +164,6 @@ bool RPCClient::Proceed() {
c->Process();
delete c;
- req_count_--;
return true;
}
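
Note on the reworked `RPCClient::Wait()` above: instead of draining the completion queue serially, it now launches one `Proceed()` task per outstanding request through `framework::Async` and joins them all before inspecting the results. A minimal standalone sketch of the same fan-out/join pattern, with `std::async` standing in for `framework::Async` (an assumption made only for this illustration):

```cpp
// Fan-out/join sketch mirroring the new RPCClient::Wait(): one worker per
// pending request, all joined before any result is inspected.
#include <future>
#include <iostream>
#include <vector>

// Stand-in for RPCClient::Proceed(); always succeeds in this sketch.
bool Proceed(int /*request_id*/) { return true; }

bool WaitAll(int req_count) {
  if (req_count <= 0) return true;

  std::vector<int> ok(req_count, 0);  // per-request results
  std::vector<std::future<void>> waits;
  waits.reserve(req_count);

  for (int i = 0; i < req_count; ++i) {
    // framework::Async is replaced by std::async here.
    waits.emplace_back(std::async(std::launch::async,
                                  [i, &ok] { ok[i] = Proceed(i) ? 1 : 0; }));
  }
  for (auto& w : waits) w.wait();  // join every task

  for (int i = 0; i < req_count; ++i) {
    if (!ok[i]) return false;  // any failed request fails Wait
  }
  return true;
}

int main() {
  std::cout << std::boolalpha << WaitAll(4) << std::endl;  // prints: true
  return 0;
}
```

A plain `std::vector<int>` holds the per-task results because concurrent writes to distinct elements of `std::vector<bool>` are not thread-safe.
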
diff --git a/paddle/operators/detail/grpc_client.h b/paddle/operators/detail/grpc_client.h
index a62e70a2533ae52d84d010504b19fed5aeb15dc0..f9499f6dc70c541c214e0b659f10b2ed1e8e8581 100644
--- a/paddle/operators/detail/grpc_client.h
+++ b/paddle/operators/detail/grpc_client.h
@@ -71,6 +71,15 @@ class ClientBase {
context_->set_deadline(deadline);
}
+ virtual void Prepare(int64_t time_out) {
+ context_.reset(new grpc::ClientContext());
+
+ std::chrono::system_clock::time_point deadline =
+ std::chrono::system_clock::now() + std::chrono::milliseconds(time_out);
+
+ context_->set_deadline(deadline);
+ }
+
virtual void Process() = 0;
  std::unique_ptr<sendrecv::SendRecvService::Stub> stub_;
@@ -117,6 +126,17 @@ class GetProcessor : public ClientBase {
RequestGetCallBack response_call_back_ = ProcGetResponse;
};
+class BatchBarrierProcessor : public ClientBase {
+ public:
+  explicit BatchBarrierProcessor(std::shared_ptr<grpc::Channel> ch)
+ : ClientBase(ch) {}
+
+ virtual ~BatchBarrierProcessor() {}
+
+ virtual void Process() {}
+ sendrecv::VoidMessage reply_;
+};
+
class RPCClient {
public:
bool AsyncSendVariable(const std::string& ep,
@@ -130,6 +150,10 @@ class RPCClient {
const framework::Scope& scope,
const std::string& var_name,
int64_t time_out = 600 * 1000);
+
+ bool AsyncSendBatchBarrier(const std::string& ep,
+ int64_t time_out = 600 * 1000);
+
bool Wait();
private:
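
With the new `AsyncSendBatchBarrier` entry point, a trainer can queue all of its variable sends, append one barrier message per endpoint, and then block once. A hypothetical caller-side sketch; the `endpoints` list, `ctx`, `scope`, and the variable name are assumptions for illustration, not part of this patch:

```cpp
// Hypothetical usage of the batch-barrier API (surrounding setup omitted).
RPCClient client;
for (const auto& ep : endpoints) {
  client.AsyncSendVariable(ep, ctx, scope, "w@GRAD");  // "w@GRAD" is illustrative
}
for (const auto& ep : endpoints) {
  client.AsyncSendBatchBarrier(ep);  // default time_out is 600 * 1000 ms
}
client.Wait();  // blocks until all sends and barriers complete
```
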
diff --git a/paddle/operators/detail/grpc_server.cc b/paddle/operators/detail/grpc_server.cc
index 3ddcd839bdd23547216465dfaf44a3cd8285fe6d..4f94e1315fbd2810a05354f7c3fc54ea30967e8a 100644
--- a/paddle/operators/detail/grpc_server.cc
+++ b/paddle/operators/detail/grpc_server.cc
@@ -132,6 +132,7 @@ void AsyncGRPCServer::RunSyncUpdate() {
cq_send_ = builder.AddCompletionQueue();
cq_get_ = builder.AddCompletionQueue();
+
server_ = builder.BuildAndStart();
LOG(INFO) << "Server listening on " << address_ << std::endl;
@@ -141,11 +142,11 @@ void AsyncGRPCServer::RunSyncUpdate() {
std::bind(&AsyncGRPCServer::TryToRegisterNewGetOne, this);
t_send_.reset(
- new std::thread(std::bind(&AsyncGRPCServer::HandleRequest, this, false,
+ new std::thread(std::bind(&AsyncGRPCServer::HandleRequest, this,
cq_send_.get(), "cq_send", send_register)));
t_get_.reset(
- new std::thread(std::bind(&AsyncGRPCServer::HandleRequest, this, true,
+ new std::thread(std::bind(&AsyncGRPCServer::HandleRequest, this,
cq_get_.get(), "cq_get", get_register)));
// wait server
@@ -174,7 +175,7 @@ void AsyncGRPCServer::TryToRegisterNewSendOne() {
}
RequestSend* send =
new RequestSend(&service_, cq_send_.get(), &var_recv_queue_);
- VLOG(4) << "create RequestSend status:" << send->Status();
+ VLOG(4) << "Create RequestSend status:" << send->Status();
}
void AsyncGRPCServer::TryToRegisterNewGetOne() {
@@ -184,11 +185,11 @@ void AsyncGRPCServer::TryToRegisterNewGetOne() {
}
RequestGet* get = new RequestGet(&service_, cq_get_.get(), scope_, dev_ctx_,
&var_get_queue_);
- VLOG(4) << "create Requestget status:" << get->Status();
+ VLOG(4) << "Create RequestGet status:" << get->Status();
}
-// FIXME(typhoonzero): remove wait argument and change cq_name to enum.
-void AsyncGRPCServer::HandleRequest(bool wait, grpc::ServerCompletionQueue* cq,
+// FIXME(typhoonzero): change cq_name to enum.
+void AsyncGRPCServer::HandleRequest(grpc::ServerCompletionQueue* cq,
std::string cq_name,
                                    std::function<void()> TryToRegisterNewOne) {
TryToRegisterNewOne();
diff --git a/paddle/operators/detail/grpc_server.h b/paddle/operators/detail/grpc_server.h
index 1ca9086c744c558fd05fb4fc1a7280729afbec28..3f8b9d93176148619d6820f6a365d9da2e73b10d 100644
--- a/paddle/operators/detail/grpc_server.h
+++ b/paddle/operators/detail/grpc_server.h
@@ -57,8 +57,7 @@ class AsyncGRPCServer final : public sendrecv::SendRecvService::Service {
void ShutDown();
protected:
- void HandleRequest(bool wait, grpc::ServerCompletionQueue *cq,
- std::string cq_name,
+ void HandleRequest(grpc::ServerCompletionQueue *cq, std::string cq_name,
                     std::function<void()> TryToRegisterNewOne);
void TryToRegisterNewSendOne();
void TryToRegisterNewGetOne();
diff --git a/paddle/operators/detail/sendrecvop_utils.h b/paddle/operators/detail/sendrecvop_utils.h
index bc6581afab93c626c7c2439d699c6c2d858df9fa..8e66f7299c7b4d30bc5a6fe6a18b7cb3ae3827a5 100644
--- a/paddle/operators/detail/sendrecvop_utils.h
+++ b/paddle/operators/detail/sendrecvop_utils.h
@@ -30,6 +30,9 @@ namespace paddle {
namespace operators {
namespace detail {
+#define LISTEN_TERMINATE_MESSAGE "TERMINATE@RECV"
+#define BATCH_BARRIER_MESSAGE "BATCH_BARRIER@RECV"
+
void SerializeToMessage(const std::string& name, const framework::Variable* var,
const platform::DeviceContext& ctx,
sendrecv::VariableMessage* msg);
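
Both sentinel names above travel through the same `varname` field as ordinary variables, so a receiver can branch on them before doing any deserialization. A hypothetical, self-contained sketch of that dispatch; the enum and the `Classify` helper are illustrative and not part of this patch:

```cpp
// Hypothetical receiver-side dispatch on the sentinel variable names.
#include <cassert>
#include <string>

#define LISTEN_TERMINATE_MESSAGE "TERMINATE@RECV"
#define BATCH_BARRIER_MESSAGE "BATCH_BARRIER@RECV"

enum class RecvAction { kTerminate, kBatchBarrier, kRegularVariable };

RecvAction Classify(const std::string& varname) {
  if (varname == LISTEN_TERMINATE_MESSAGE) return RecvAction::kTerminate;
  if (varname == BATCH_BARRIER_MESSAGE) return RecvAction::kBatchBarrier;
  return RecvAction::kRegularVariable;  // a normal variable to deserialize
}

int main() {
  assert(Classify("BATCH_BARRIER@RECV") == RecvAction::kBatchBarrier);
  assert(Classify("w@GRAD") == RecvAction::kRegularVariable);
  return 0;
}
```
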
diff --git a/paddle/operators/gru_op.cc b/paddle/operators/gru_op.cc
index 76f2adefede3b4bc4035f86f8f8663eed29343ae..fb901b639492a179925ff852f9030fc6674d1f63 100644
--- a/paddle/operators/gru_op.cc
+++ b/paddle/operators/gru_op.cc
@@ -135,14 +135,14 @@ class GRUOpMaker : public framework::OpProtoAndCheckerMaker {
AddComment(R"DOC(
GRU Operator implements part calculations of the complete GRU as following:
-\f[
-update \ gate: u_t = actGate(xu_t + W_u * h_{t-1} + b_u) \\
-reset \ gate: r_t = actGate(xr_t + W_r * h_{t-1} + b_r) \\
-output \ candidate: {h}_t = actNode(xc_t + W_c * dot(r_t, h_{t-1}) + b_c) \\
+$$
+update\_gate: u_t = actGate(xu_t + W_u * h_{t-1} + b_u) \\
+reset\_gate: r_t = actGate(xr_t + W_r * h_{t-1} + b_r) \\
+output\_candidate: {h}_t = actNode(xc_t + W_c * dot(r_t, h_{t-1}) + b_c) \\
output: h_t = dot((1 - u_t), h_{t-1}) + dot(u_t, {h}_t)
-\f]
+$$
-@note To implement the complete GRU, fully-connected operator must be used
+@note To implement the complete GRU, fully-connected operator must be used
before to feed xu, xr and xc as the Input of GRU operator.
)DOC");
}
diff --git a/paddle/operators/im2sequence_op.h b/paddle/operators/im2sequence_op.h
index aeb810015134babc132909b3e820fa8391233b1c..f33aec71a92a65ec0e4114530d70e36c9dc1be04 100644
--- a/paddle/operators/im2sequence_op.h
+++ b/paddle/operators/im2sequence_op.h
@@ -79,7 +79,7 @@ class Im2SequenceKernel : public framework::OpKernel<T> {
framework::LoD lod(1);
lod[0].reserve(batch_size + 1);
for (int i = 0, offset = 0; i < batch_size + 1; ++i) {
- lod[0][i] = offset;
+ lod[0].push_back(offset);
offset += output_height * output_width;
}
out->set_lod(lod);
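
The one-line im2sequence fix above matters because `reserve()` only changes capacity, never size, so the old indexed assignment wrote past the end of an empty vector. A minimal standalone sketch of the difference; `step` stands in for `output_height * output_width`:

```cpp
// reserve() vs. push_back(): why the indexed write in im2sequence was a bug.
#include <cassert>
#include <vector>

int main() {
  const int batch_size = 3;
  const int step = 4;               // stands in for output_height * output_width

  std::vector<size_t> offsets;
  offsets.reserve(batch_size + 1);  // capacity grows, size() is still 0
  // offsets[0] = 0;                // undefined behavior: writes past the end

  for (int i = 0, offset = 0; i < batch_size + 1; ++i) {
    offsets.push_back(offset);      // grows the vector correctly
    offset += step;
  }
  assert(static_cast<int>(offsets.size()) == batch_size + 1);
  assert(offsets.back() == static_cast<size_t>(batch_size * step));
  return 0;
}
```
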
diff --git a/paddle/operators/lstmp_op.cc b/paddle/operators/lstmp_op.cc
new file mode 100644
index 0000000000000000000000000000000000000000..c96b30ba353fabc48630258ea8f88f741b8c415e
--- /dev/null
+++ b/paddle/operators/lstmp_op.cc
@@ -0,0 +1,331 @@
+/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#include "paddle/operators/lstmp_op.h"
+
+namespace paddle {
+namespace operators {
+
+class LSTMPOp : public framework::OperatorWithKernel {
+ public:
+ using framework::OperatorWithKernel::OperatorWithKernel;
+
+ void InferShape(framework::InferShapeContext* ctx) const override {
+ PADDLE_ENFORCE(ctx->HasInput("Input"),
+ "Input(Input) of LSTMP operator should not be null.");
+ PADDLE_ENFORCE(ctx->HasInput("Weight"),
+ "Input(Weight) of LSTMP operator should not be null.");
+ PADDLE_ENFORCE(ctx->HasInput("ProjWeight"),
+ "Input(ProjWeight) of LSTMP operator should not be null.");
+ PADDLE_ENFORCE(ctx->HasInput("Bias"),
+ "Input(Bias) of LSTMP operator should not be null.");
+
+ PADDLE_ENFORCE(ctx->HasOutput("Projection"),
+ "Output(Projection) of LSTMP operator should not be null.");
+ PADDLE_ENFORCE(ctx->HasOutput("Cell"),
+ "Output(Cell) of LSTMP operator should not be null.");
+ PADDLE_ENFORCE(ctx->HasOutput("BatchGate"),
+ "Output(BatchGate) of LSTMP operator should not be null.");
+ PADDLE_ENFORCE(ctx->HasOutput("BatchCellPreAct"),
+ "Output(BatchCellPreAct) of LSTMP operator should not be "
+ "null.");
+ PADDLE_ENFORCE(ctx->HasOutput("BatchHidden"),
+ "Output(BatchHidden) of LSTMP operator should not be null.");
+
+ auto in_dims = ctx->GetInputDim("Input");
+ PADDLE_ENFORCE_EQ(in_dims.size(), 2,
+ "Input(X)'s rank of LSTMP operator must be 2.");
+
+ int frame_size = in_dims[1] / 4;
+ auto w_dims = ctx->GetInputDim("Weight");
+ auto proj_dims = ctx->GetInputDim("ProjWeight");
+ PADDLE_ENFORCE_EQ(w_dims.size(), 2,
+ "The rank of Input(Weight) should be 2.");
+ PADDLE_ENFORCE_EQ(w_dims[0], proj_dims[1],
+ "The first dimension of Input(Weight) "
+ "should be %d.",
+ proj_dims[1]);
+ PADDLE_ENFORCE_EQ(w_dims[1], 4 * frame_size,
+ "The second dimension of Input(Weight) "
+ "should be 4 * %d.",
+ frame_size);
+
+ PADDLE_ENFORCE_EQ(proj_dims.size(), 2,
+ "The rank of Input(ProjWeight) should be 2.");
+ PADDLE_ENFORCE_EQ(proj_dims[0], frame_size,
+ "The first dimension of Input(ProjWeight) "
+ "should be %d.",
+ frame_size);
+
+ if (ctx->HasInput("H0")) {
+ PADDLE_ENFORCE(ctx->HasInput("C0"),
+ "Input(C0) of LSTMP operator should not be null after "
+ "Input(H0) provided.");
+ auto h_dims = ctx->GetInputDim("H0");
+ auto c_dims = ctx->GetInputDim("C0");
+ PADDLE_ENFORCE(h_dims == c_dims,
+ "The dimension of Input(H0) and Input(C0) "
+ "should be the same.");
+ ctx->SetOutputDim("OrderedP0", {h_dims[0], proj_dims[1]});
+ }
+
+ auto b_dims = ctx->GetInputDim("Bias");
+ PADDLE_ENFORCE_EQ(b_dims.size(), 2, "The rank of Input(Bias) should be 2.");
+ PADDLE_ENFORCE_EQ(b_dims[0], 1,
+ "The first dimension of Input(Bias) should be 1.");
+
+    if (ctx->Attrs().Get<bool>("use_peepholes")) {
+ PADDLE_ENFORCE_EQ(b_dims[1], 7 * frame_size,
+ "The second dimension of Input(Bias) should be "
+ "7 * %d if enable peepholes connection",
+ frame_size);
+ } else {
+ PADDLE_ENFORCE_EQ(b_dims[1], 4 * frame_size,
+ "The second dimension of Input(Bias) should be "
+ "4 * %d if disable peepholes connection",
+ frame_size);
+ }
+
+ framework::DDim out_dims({in_dims[0], frame_size});
+ framework::DDim proj_out_dims({in_dims[0], proj_dims[1]});
+ ctx->SetOutputDim("Projection", proj_out_dims);
+ ctx->SetOutputDim("Cell", out_dims);
+ ctx->SetOutputDim("BatchGate", in_dims);
+ ctx->SetOutputDim("BatchCellPreAct", out_dims);
+ ctx->SetOutputDim("BatchHidden", out_dims);
+ ctx->ShareLoD("Input", "Projection");
+ ctx->ShareLoD("Input", "Cell");
+ }
+
+ protected:
+ framework::OpKernelType GetExpectedKernelType(
+ const framework::ExecutionContext& ctx) const override {
+ return framework::OpKernelType(
+        framework::ToDataType(ctx.Input<framework::LoDTensor>("Input")->type()),
+ ctx.device_context());
+ }
+};
+
+class LSTMPOpMaker : public framework::OpProtoAndCheckerMaker {
+ public:
+ LSTMPOpMaker(OpProto* proto, OpAttrChecker* op_checker)
+ : OpProtoAndCheckerMaker(proto, op_checker) {
+ AddInput("Input",
+ "(LoDTensor) the input for sequence data, which supports "
+ "variable-time length input sequence. The underlying tensor in "
+ "this LoDTensor is a matrix with shape (T X 4D), where T is the "
+ "total time steps in this mini-batch, D is the hidden size.");
+ AddInput("H0",
+ "(Tensor, optional) the initial hidden state is an optional "
+ "input. This is a tensor with shape (N x D), where N is the "
+ "batch size and D is the hidden size.")
+ .AsDispensable();
+ AddInput("C0",
+ "(Tensor, optional) the initial cell state is an optional "
+ "input. This is a tensor with shape (N x D), where N is the "
+ "batch size. `C0` should not be null if `H0` provided.")
+ .AsDispensable();
+ AddInput("Weight",
+ "(Tensor) the learnable hidden-hidden weights."
+ " - The shape is (P x 4D), where P is the projection layer size "
+ "and D is the hidden size."
+ " - Weight = {W_cr, W_ir, W_fr, W_or}");
+ AddInput("ProjWeight",
+ "(Tensor) the learnable weight of the projection layer."
+ " - The shape is (D x P), where P is the recurrent projection "
+ "layer size and D is the hidden size."
+ " - ProjWeight = {W_rh}");
+ AddInput("Bias",
+ "(Tensor) the learnable biases, which contains two parts: "
+ "input-hidden biases and peephole connections weights if "
+ "setting `use_peepholes` to `True`. "
+ "1. `use_peepholes = False` "
+ " - The shape is (1 x 4D). "
+ " - Bias = {b_c, b_i, b_f, b_o}."
+ "2. `use_peepholes = True` "
+ " - The shape is (1 x 7D). "
+ " - Bias = {b_c, b_i, b_f, b_o, W_ic, W_fc, W_oc}.");
+    AddOutput("Projection",
+              "(LoDTensor) the projection of the hidden state of the LSTMP "
+              "operator. The shape is (T x P), and the LoD is the same as the "
+              "`Input`.");
+    AddOutput("Cell",
+              "(LoDTensor) the cell state of the LSTMP operator. "
+              "The shape is (T x D), and the LoD is the same as the `Input`.");
+    AddOutput("BatchGate",
+              "(LoDTensor) This LoDTensor contains the input gate, forget gate "
+              "and output gate after the activations. This LoDTensor has the "
+              "same shape as the reorganized input, which is also called "
+              "batch input. The LoD size is 2. The first-level LoD is the "
+              "batch offsets and the second contains the indices, which "
+              "denote the position of the reorganized sequence in the raw input.")
+ .AsIntermediate();
+ AddOutput("BatchCellPreAct",
+ "(LoDTensor) the pre-activation cell state reorganized in batch. "
+ "This LoDTensor is obtained in the forward and used in the "
+ "backward.")
+ .AsIntermediate();
+ AddOutput("BatchHidden",
+ "(LoDTensor) the hidden state reorganized in batch. "
+ "This LoDTensor is obtained in the forward and used in the "
+ "backward.")
+ .AsIntermediate();
+    AddOutput("OrderedP0",
+              "(Tensor) the projection of the initial hidden state "
+              "H0. This is a tensor with shape (N x P), where N is the "
+              "batch size and P is the projection size.")
+ .AsIntermediate();
+    AddAttr<bool>("use_peepholes",
+                  "(bool, default: True) "
+                  "whether to enable diagonal/peephole connections.")
+        .SetDefault(true);
+    AddAttr<bool>("is_reverse",
+                  "(bool, default: False) "
+                  "whether to compute reversed LSTMP.")
+        .SetDefault(false);
+    AddAttr<std::string>(
+        "gate_activation",
+        "(string, default: sigmoid) "
+        "The activation for input gate, forget gate and output "
+        "gate, `sigmoid` by default.")
+        .SetDefault("sigmoid")
+        .InEnum({"sigmoid", "tanh", "relu", "identity"});
+    AddAttr<std::string>("cell_activation",
+                         "(string, default: tanh) "
+                         "The activation for cell output, `tanh` by default.")
+        .SetDefault("tanh")
+        .InEnum({"sigmoid", "tanh", "relu", "identity"});
+    AddAttr<std::string>("candidate_activation",
+                         "(string, default: tanh) "
+                         "The activation for candidate hidden state, "
+                         "`tanh` by default.")
+        .SetDefault("tanh")
+        .InEnum({"sigmoid", "tanh", "relu", "identity"});
+    AddAttr<std::string>("proj_activation",
+                         "(string, default: tanh) "
+                         "The activation for projection output, "
+                         "`tanh` by default.")
+        .SetDefault("tanh")
+        .InEnum({"sigmoid", "tanh", "relu", "identity"});
+ AddComment(R"DOC(
+Long-Short Term Memory with recurrent Projection layer (LSTMP) Operator.
+
+LSTMP has a separate projection layer after the LSTM layer, projecting the
+original hidden state to a lower-dimensional one. It is proposed to reduce
+the total number of parameters and the computational complexity of the LSTM,
+especially for the case where the number of output units is relatively
+large (https://research.google.com/pubs/archive/43905.pdf).
+
+The formula is as follows:
+
+$$
+i_t = \sigma(W_{ix}x_{t} + W_{ir}r_{t-1} + W_{ic}c_{t-1} + b_i) \\
+
+f_t = \sigma(W_{fx}x_{t} + W_{fr}r_{t-1} + W_{fc}c_{t-1} + b_f) \\
+
+\tilde{c_t} = act_g(W_{cx}x_t + W_{cr}r_{t-1} + b_c) \\
+
+o_t = \sigma(W_{ox}x_{t} + W_{or}r_{t-1} + W_{oc}c_t + b_o) \\
+
+c_t = f_t \odot c_{t-1} + i_t \odot \tilde{c_t} \\
+
+h_t = o_t \odot act_h(c_t) \\
+
+r_t = \overline{act_h}(W_{rh}h_t)
+$$
+
+where the W terms denote weight matrices (e.g. $W_{xi}$ is the matrix
+of weights from the input to the input gate), $W_{ic}, W_{fc}, W_{oc}$
+are diagonal weight matrices for peephole connections. In our implementation,
+we use vectors to represent these diagonal weight matrices. The b terms
+denote bias vectors ($b_i$ is the input gate bias vector), $\sigma$
+is the activation, such as the logistic sigmoid function, and
+$i, f, o$ and $c$ are the input gate, forget gate, output gate,
+and cell activation vectors, respectively, all of which have the same size as
+the cell output activation vector $h$. Here $h$ is usually called the hidden
+state and $r$ denotes its recurrent projection. And $\tilde{c_t}$ is also
+called the candidate hidden state, whose computation is based on the current
+input and previous hidden state.
+
+The $\odot$ is the element-wise product of the vectors. $act_g$ and $act_h$
+are the cell input and cell output activation functions and `tanh` is usually
+used for them. $\overline{act_h}$ is the activation function for the
+projection output, usually using `identity` or same as $act_h$.
+
+Note that these $W_{xi}x_{t}, W_{xf}x_{t}, W_{xc}x_{t}, W_{xo}x_{t}$
+operations on the input $x_{t}$ are NOT included in this operator.
+Users can choose to use the fully-connected operator before the LSTMP operator.
+
+)DOC");
+ }
+};
+
+class LSTMPGradOp : public framework::OperatorWithKernel {
+ public:
+ using framework::OperatorWithKernel::OperatorWithKernel;
+
+ void InferShape(framework::InferShapeContext* ctx) const override {
+ PADDLE_ENFORCE(ctx->HasInput("Input"),
+ "Input(Input) of LSTMP operator should not be null.");
+ PADDLE_ENFORCE(ctx->HasInput("Projection"),
+ "Input(Projection) of LSTMP operator should not be null.");
+ PADDLE_ENFORCE(ctx->HasInput("Cell"),
+ "Input(Cell) of LSTMP operator should not be null.");
+ PADDLE_ENFORCE(ctx->HasInput("Weight"),
+ "Input(Weight) of LSTMP operator should not be null.");
+ PADDLE_ENFORCE(ctx->HasInput("ProjWeight"),
+ "Input(ProjWeight) of LSTMP operator should not be null.");
+ PADDLE_ENFORCE(ctx->HasInput("Bias"),
+ "Input(Bias) of LSTMP operator should not be null.");
+
+ PADDLE_ENFORCE(ctx->HasInput("BatchGate"),
+ "Input(BatchGate) of LSTMP operator should not be null.");
+    PADDLE_ENFORCE(ctx->HasInput("BatchCellPreAct"),
+                   "Input(BatchCellPreAct) of LSTMP operator should not be "
+                   "null.");
+
+ auto SetOutGradDim = [&ctx](const std::string& name) {
+ auto g_name = framework::GradVarName(name);
+ if (ctx->HasOutput(g_name))
+ ctx->SetOutputDim(g_name, ctx->GetInputDim(name));
+ };
+
+ SetOutGradDim("Input");
+ SetOutGradDim("Weight");
+ SetOutGradDim("ProjWeight");
+ SetOutGradDim("Bias");
+ SetOutGradDim("H0");
+ SetOutGradDim("C0");
+ }
+
+ protected:
+ framework::OpKernelType GetExpectedKernelType(
+ const framework::ExecutionContext& ctx) const override {
+ return framework::OpKernelType(
+        framework::ToDataType(ctx.Input<framework::LoDTensor>("Input")->type()),
+ ctx.device_context());
+ }
+};
+
+} // namespace operators
+} // namespace paddle
+
+namespace ops = paddle::operators;
+REGISTER_OP(lstmp, ops::LSTMPOp, ops::LSTMPOpMaker, lstmp_grad,
+ ops::LSTMPGradOp);
+REGISTER_OP_CPU_KERNEL(
+    lstmp, ops::LSTMPKernel<paddle::platform::CPUDeviceContext, float>,
+    ops::LSTMPKernel<paddle::platform::CPUDeviceContext, double>);
+REGISTER_OP_CPU_KERNEL(
+    lstmp_grad,
+    ops::LSTMPGradKernel<paddle::platform::CPUDeviceContext, float>,
+    ops::LSTMPGradKernel<paddle::platform::CPUDeviceContext, double>);
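
For reference, the shape contract enforced by `LSTMPOp::InferShape` above can be summarized with a toy check; plain ints stand in for `framework::DDim`, and the concrete sizes are made-up examples:

```cpp
// Toy check of the LSTMP shape contract: Input (T x 4D), Weight (P x 4D),
// ProjWeight (D x P), Projection (T x P), Cell (T x D).
#include <cassert>

int main() {
  const int T = 6;  // total time steps in the mini-batch
  const int D = 4;  // hidden size
  const int P = 2;  // recurrent projection size

  const int input_dims[2] = {T, 4 * D};
  const int weight_dims[2] = {P, 4 * D};
  const int proj_weight_dims[2] = {D, P};

  // These mirror the PADDLE_ENFORCE_EQ checks in InferShape.
  assert(weight_dims[0] == proj_weight_dims[1]);      // Weight's first dim is P
  assert(weight_dims[1] == 4 * (input_dims[1] / 4));  // Weight's second dim is 4 * D
  assert(proj_weight_dims[0] == input_dims[1] / 4);   // ProjWeight's first dim is D

  const int projection_dims[2] = {T, P};  // output Projection
  const int cell_dims[2] = {T, D};        // output Cell
  assert(projection_dims[0] == input_dims[0] && cell_dims[1] == D);
  return 0;
}
```
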
diff --git a/paddle/operators/lstmp_op.cu b/paddle/operators/lstmp_op.cu
new file mode 100644
index 0000000000000000000000000000000000000000..7fcbcfecc871976fdfbfffbbb4e0243b91351a29
--- /dev/null
+++ b/paddle/operators/lstmp_op.cu
@@ -0,0 +1,24 @@
+/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#include "paddle/operators/lstmp_op.h"
+
+namespace ops = paddle::operators;
+REGISTER_OP_CUDA_KERNEL(
+    lstmp, ops::LSTMPKernel<paddle::platform::CUDADeviceContext, float>,
+    ops::LSTMPKernel<paddle::platform::CUDADeviceContext, double>);
+REGISTER_OP_CUDA_KERNEL(
+    lstmp_grad,
+    ops::LSTMPGradKernel<paddle::platform::CUDADeviceContext, float>,
+    ops::LSTMPGradKernel<paddle::platform::CUDADeviceContext, double>);
diff --git a/paddle/operators/lstmp_op.h b/paddle/operators/lstmp_op.h
new file mode 100644
index 0000000000000000000000000000000000000000..ee82d5c10a5421b181e525f49a263d4808ede62f
--- /dev/null
+++ b/paddle/operators/lstmp_op.h
@@ -0,0 +1,491 @@
+/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#pragma once
+#include "paddle/operators/activation_op.h"
+#include "paddle/operators/math/detail/activation_functions.h"
+#include "paddle/operators/math/lstm_compute.h"
+#include "paddle/operators/math/math_function.h"
+#include "paddle/operators/math/sequence2batch.h"
+
+#include "paddle/framework/eigen.h"
+#include "paddle/framework/op_registry.h"
+
+namespace paddle {
+namespace operators {
+
+using LoDTensor = framework::LoDTensor;
+using Tensor = framework::Tensor;
+
+template <typename T, int MajorType = Eigen::RowMajor,
+          typename IndexType = Eigen::DenseIndex>
+using EigenMatrix = framework::EigenMatrix<T, MajorType, IndexType>;
+
+template <typename DeviceContext, typename T>
+inline void ReorderInitState(const DeviceContext& ctx,
+                             const framework::Tensor& src, const size_t* index,
+                             framework::Tensor* dst, bool indexed_src) {
+  math::CopyMatrixRowsFunctor<DeviceContext, T> row_shuffle;
+  dst->mutable_data<T>(src.dims(), ctx.GetPlace());
+  row_shuffle(ctx, src, index, *dst, indexed_src);
+}
+
+template <typename DeviceContext, typename T>
+class LSTMPKernel : public framework::OpKernel<T> {
+ public:
+  template <typename Device, typename X, typename Y>
+  void ActCompute(const math::detail::ActivationType act_type, const Device& d,
+                  X x, Y y) const {
+    if (act_type == math::detail::ActivationType::kIdentity)
+      y.device(d) = x;
+    else if (act_type == math::detail::ActivationType::kSigmoid)
+      SigmoidFunctor<T>()(d, x, y);
+    else if (act_type == math::detail::ActivationType::kTanh)
+      TanhFunctor<T>()(d, x, y);
+    else if (act_type == math::detail::ActivationType::kReLU)
+      ReluFunctor<T>()(d, x, y);
+    else
+      PADDLE_THROW("unsupported activation type");
+  }
+
+  void Compute(const framework::ExecutionContext& ctx) const override {
+    auto* input = ctx.Input<LoDTensor>("Input");
+    auto* weight = ctx.Input<Tensor>("Weight");
+    auto* proj_weight = ctx.Input<Tensor>("ProjWeight");
+    auto* bias = ctx.Input<Tensor>("Bias");
+
+    auto* hidden_t0 = ctx.Input<Tensor>("H0");
+    auto* ordered_proj0 = ctx.Output<Tensor>("OrderedP0");
+    auto* cell_t0 = ctx.Input<Tensor>("C0");
+
+    auto* batch_gate = ctx.Output<LoDTensor>("BatchGate");
+    batch_gate->mutable_data<T>(ctx.GetPlace());
+    auto* proj_out = ctx.Output<LoDTensor>("Projection");
+    proj_out->mutable_data<T>(ctx.GetPlace());
+    auto* cell_out = ctx.Output<LoDTensor>("Cell");
+    cell_out->mutable_data<T>(ctx.GetPlace());
+
+    bool is_reverse = ctx.Attr<bool>("is_reverse");
+    math::LoDTensor2BatchFunctor<DeviceContext, T> to_batch;
+    auto& device_ctx = ctx.template device_context<DeviceContext>();
+    to_batch(device_ctx, *input, *batch_gate, true, is_reverse);
+
+    auto in_dims = input->dims();
+    int frame_size = static_cast<int>(in_dims[1] / 4);
+    framework::DDim dims({in_dims[0], frame_size});
+    framework::DDim proj_dims({in_dims[0], proj_weight->dims()[1]});
+
+ if (bias) {
+ Tensor b = *bias;
+ b.Resize({bias->numel(), 1});
+ Tensor gate_bias = b.Slice(0, 4 * frame_size);
+      math::RowwiseAdd<DeviceContext, T> add_bias;
+ add_bias(device_ctx, *batch_gate, gate_bias, batch_gate);
+ }
+
+    math::LstmMetaValue<T> lstmp_value;
+    if (bias && ctx.Attr<bool>("use_peepholes")) {
+      T* bias_data = const_cast<T*>(bias->data<T>());
+ // the code style in LstmpMetaValue will be updated later.
+
+ lstmp_value.check_ig = bias_data + 4 * frame_size;
+ lstmp_value.check_fg = lstmp_value.check_ig + frame_size;
+ lstmp_value.check_og = lstmp_value.check_fg + frame_size;
+ } else {
+ lstmp_value.check_ig = nullptr;
+ lstmp_value.check_fg = nullptr;
+ lstmp_value.check_og = nullptr;
+ }
+ lstmp_value.prev_state_value = nullptr;
+ Tensor ordered_c0;
+ const size_t* order = batch_gate->lod()[2].data();
+ if (cell_t0) {
+      // Since the batch computing for LSTMP reorders the input sequences
+      // according to their length, the initialized cell state also needs
+      // to be reordered.
+      ReorderInitState<DeviceContext, T>(device_ctx, *cell_t0, order,
+                                         &ordered_c0, true);
+      lstmp_value.prev_state_value = ordered_c0.data<T>();
+ }
+
+    // Use local variables here for the batch-organized projection and cell state.
+ LoDTensor batch_proj, batch_cell;
+    auto* batch_cell_pre_act = ctx.Output<LoDTensor>("BatchCellPreAct");
+    batch_cell_pre_act->mutable_data<T>(dims, ctx.GetPlace());
+    auto* batch_hidden = ctx.Output<LoDTensor>("BatchHidden");
+    batch_hidden->mutable_data<T>(dims, ctx.GetPlace());    // T x D
+    batch_proj.mutable_data<T>(proj_dims, ctx.GetPlace());  // T x P
+    batch_cell.mutable_data<T>(dims, ctx.GetPlace());       // T x D
+
+ auto batch_starts = batch_gate->lod()[0];
+ size_t num_batch = batch_starts.size() - 1;
+    auto gate_act = math::detail::GetActivationType(
+        ctx.Attr<std::string>("gate_activation"));
+    auto cell_act = math::detail::GetActivationType(
+        ctx.Attr<std::string>("cell_activation"));
+    auto cand_act = math::detail::GetActivationType(
+        ctx.Attr<std::string>("candidate_activation"));
+    auto proj_act = math::detail::GetActivationType(
+        ctx.Attr<std::string>("proj_activation"));
+    auto& place = *ctx.template device_context<DeviceContext>().eigen_device();
+
+ for (size_t n = 0; n < num_batch; n++) {
+      int bstart = static_cast<int>(batch_starts[n]);
+      int bend = static_cast<int>(batch_starts[n + 1]);
+
+ Tensor gate_t = batch_gate->Slice(bstart, bend);
+ Tensor hidden_t = batch_hidden->Slice(bstart, bend);
+ Tensor proj_t = batch_proj.Slice(bstart, bend);
+ Tensor cell_t = batch_cell.Slice(bstart, bend);
+ Tensor cell_pre_act_t = batch_cell_pre_act->Slice(bstart, bend);
+
+ int cur_batch_size = bend - bstart;
+
+ if (n > 0) {
+ int pre_h_start = static_cast(batch_starts[n - 1]);
+ int pre_h_end = pre_h_start + cur_batch_size;
+ auto pre_proj_t = batch_proj.Slice(pre_h_start, pre_h_end);
+        math::matmul<DeviceContext, T>(device_ctx, pre_proj_t, false, *weight,
+                                       false, static_cast<T>(1.0), &gate_t,
+                                       static_cast<T>(1.0));
+ } else if (hidden_t0) {
+        // If n == 0 and there is no initialized hidden state, i.e. H0 is
+        // all zeros, the calculation of W_h * H0 will be skipped.
+        // If n == 0 and there is an initialized hidden state, calculate W_h * H0.
+
+        // Since the batch computing for LSTMP reorders the input sequences
+        // according to their length, the initialized hidden state also needs
+        // to be reordered.
+
+ Tensor ordered_h0;
+        ordered_proj0->mutable_data<T>(ctx.GetPlace());
+        ReorderInitState<DeviceContext, T>(device_ctx, *hidden_t0, order,
+                                           &ordered_h0, true);
+        math::matmul<DeviceContext, T>(device_ctx, ordered_h0, false,
+                                       *proj_weight, false, static_cast<T>(1.0),
+                                       ordered_proj0, static_cast<T>(0.0));
+        if (proj_act != math::detail::ActivationType::kIdentity) {
+          auto proj0_dev = EigenMatrix<T>::From(*ordered_proj0);
+          ActCompute(cell_act, place, proj0_dev, proj0_dev);
+        }
+        math::matmul<DeviceContext, T>(device_ctx, *ordered_proj0, false,
+                                       *weight, false, static_cast<T>(1.0),
+                                       &gate_t, static_cast<T>(1.0));
+ }
+
+      lstmp_value.gate_value = gate_t.data<T>();
+      lstmp_value.output_value = hidden_t.data<T>();
+      lstmp_value.state_value = cell_t.data<T>();
+      lstmp_value.state_active_value = cell_pre_act_t.data<T>();
+      math::LstmUnitFunctor<DeviceContext, T>::compute(
+          device_ctx, lstmp_value, frame_size, cur_batch_size, gate_act,
+          cell_act, cand_act);
+      lstmp_value.prev_state_value = lstmp_value.state_value;
+      math::matmul<DeviceContext, T>(device_ctx, hidden_t, false, *proj_weight,
+                                     false, static_cast<T>(1.0), &proj_t,
+                                     static_cast<T>(0.0));
+      if (proj_act != math::detail::ActivationType::kIdentity) {
+        auto proj_t_dev = EigenMatrix<T>::From(proj_t);
+ ActCompute(cell_act, place, proj_t_dev, proj_t_dev);
+ }
+ }
+
+    math::Batch2LoDTensorFunctor<DeviceContext, T> to_seq;
+ batch_proj.set_lod(batch_gate->lod());
+ // restore the output hidden in LoDTensor from the batch hidden
+ to_seq(device_ctx, batch_proj, *proj_out);
+
+ batch_cell.set_lod(batch_gate->lod());
+ // restore the output cell state in LoDTensor from the batch cell
+ to_seq(device_ctx, batch_cell, *cell_out);
+ }
+};
+
+template <typename DeviceContext, typename T>
+class LSTMPGradKernel : public framework::OpKernel<T> {
+ public:
+  template <typename Device, typename X, typename Y, typename DX, typename DY>
+  void ActGradCompute(const math::detail::ActivationType act_type,
+                      const Device& d, X x, Y y, DX dx, DY dy) const {
+    // x is a dummy argument and is not used, even for ReLU (y is used instead)
+    if (act_type == math::detail::ActivationType::kIdentity)
+      dx.device(d) = dy;
+    else if (act_type == math::detail::ActivationType::kSigmoid)
+      SigmoidGradFunctor<T>()(d, x, y, dy, dx);
+    else if (act_type == math::detail::ActivationType::kTanh)
+      TanhGradFunctor<T>()(d, x, y, dy, dx);
+    else if (act_type == math::detail::ActivationType::kReLU)
+      ReluGradFunctor<T>()(d, x, y, dy, dx);
+    else
+      PADDLE_THROW("unsupported activation type");
+  }
+
+ void Compute(const framework::ExecutionContext& ctx) const override {
+    auto* input = ctx.Input<LoDTensor>("Input");
+    auto* weight = ctx.Input<Tensor>("Weight");
+    auto* proj_weight = ctx.Input<Tensor>("ProjWeight");
+    auto* bias = ctx.Input<Tensor>("Bias");
+
+    auto* proj_out = ctx.Input<LoDTensor>("Projection");
+    auto* cell_out = ctx.Input<LoDTensor>("Cell");
+
+    auto* batch_gate = ctx.Input<LoDTensor>("BatchGate");
+    auto* batch_cell_pre_act = ctx.Input<LoDTensor>("BatchCellPreAct");
+    auto* batch_hidden = ctx.Input<LoDTensor>("BatchHidden");
+
+ auto* projection_g =
+ ctx.Input