diff --git a/doc/getstarted/quickstart_cn.rst b/doc/getstarted/quickstart_cn.rst
index 51dd00f1e806e6423afe3ce53d80d53a187d2ca0..d511cead262dabafd095f68adb5ffc596a7fe596 100644
--- a/doc/getstarted/quickstart_cn.rst
+++ b/doc/getstarted/quickstart_cn.rst
@@ -1,6 +1,9 @@
快速开始
========
+快速安装
+--------
+
PaddlePaddle支持使用pip快速安装,目前支持CentOS 6以上, Ubuntu 14.04以及MacOS 10.12,并安装有Python2.7。
执行下面的命令完成快速安装,版本为cpu_avx_openblas:
@@ -16,6 +19,9 @@ PaddlePaddle支持使用pip快速安装,目前支持CentOS 6以上, Ubuntu 14.
更详细的安装和编译方法参考::ref:`install_steps` 。
+快速使用
+--------
+
创建一个 housing.py 并粘贴此Python代码:
.. code-block:: python
diff --git a/doc/getstarted/quickstart_en.rst b/doc/getstarted/quickstart_en.rst
index d1bcf82ea071e2c53760a5ccf6a5074a3ac0abd5..70f7fe0646068aa79cd72955c6848ac0250c2300 100644
--- a/doc/getstarted/quickstart_en.rst
+++ b/doc/getstarted/quickstart_en.rst
@@ -1,6 +1,9 @@
Quick Start
============
+Quick Install
+-------------
+
You can use pip to install PaddlePaddle with a single command, supports
CentOS 6 above, Ubuntu 14.04 above or MacOS 10.12, with Python 2.7 installed.
Simply run the following command to install, the version is cpu_avx_openblas:
@@ -17,6 +20,9 @@ If you need to install GPU version (cuda7.5_cudnn5_avx_openblas), run:
For more details about installation and build: :ref:`install_steps` .
+Quick Use
+---------
+
Create a new file called housing.py, and paste this Python
code:
diff --git a/doc/howto/cluster/index_cn.rst b/doc/howto/cluster/index_cn.rst
index c68b2655b65b192814b94f0013fa92b0733b9afa..a60521b4a9646bdc6d9f1bf6da482acc989d8bf3 100644
--- a/doc/howto/cluster/index_cn.rst
+++ b/doc/howto/cluster/index_cn.rst
@@ -1,10 +1,22 @@
分布式训练
==========
+本节将介绍如何使用PaddlePaddle在不同的集群框架下完成分布式训练。分布式训练架构如下图所示:
+
+.. image:: src/ps_cn.png
+ :width: 500
+
+- 数据分片(Data shard): 用于训练神经网络的数据,被切分成多个部分,每个部分分别给每个trainer使用。
+- 计算节点(Trainer): 每个trainer启动后读取切分好的一部分数据,开始神经网络的“前馈”和“后馈”计算,并和参数服务器通信。在完成一定量数据的训练后,上传计算得出的梯度(gradients),然后下载优化更新后的神经网络参数(parameters)。
+- 参数服务器(Parameter server):每个参数服务器只保存整个神经网络所有参数的一部分。参数服务器接收从计算节点上传的梯度,并完成参数优化更新,再将更新后的参数下发到每个计算节点。
+
+这样,通过计算节点和参数服务器的分布式协作,可以完成神经网络的SGD方法的训练。PaddlePaddle可以同时支持同步随机梯度下降(SGD)和异步随机梯度下降。
+
+在使用同步SGD训练神经网络时,PaddlePaddle使用同步屏障(barrier),使梯度的提交和参数的更新按照顺序方式执行。在异步SGD中,则并不会等待所有trainer提交梯度才更新参数,这样极大地提高了计算的并行性:参数服务器之间不相互依赖,并行地接收梯度和更新参数,参数服务器也不会等待计算节点全部都提交梯度之后才开始下一步,计算节点之间也不会相互依赖,并行地执行模型的训练。可以看出,虽然异步SGD方式会提高参数更新并行度, 但是并不能保证参数同步更新,在任意时间某一台参数服务器上保存的参数可能比另一台要更新,与同步SGD相比,梯度会有噪声。
+
.. toctree::
:maxdepth: 1
- introduction_cn.md
preparations_cn.md
cmd_argument_cn.md
multi_cluster/index_cn.rst
diff --git a/doc/howto/cluster/index_en.rst b/doc/howto/cluster/index_en.rst
index af957e06cd7930ce63569a1bafdde47a1d34eb69..2640a09dcc904619bc97c9bd3f3d81a9dc307663 100644
--- a/doc/howto/cluster/index_en.rst
+++ b/doc/howto/cluster/index_en.rst
@@ -1,10 +1,22 @@
Distributed Training
====================
+In this section, we'll explain how to run distributed training jobs with PaddlePaddle on different types of clusters. The diagram below shows the main architecture of a distributed training job:
+
+.. image:: src/ps_en.png
+ :width: 500
+
+- Data shard: the training data is split into multiple partitions, and each trainer uses its own partition of the dataset to do the training job.
+- Trainer: each trainer reads its data shard and trains the neural network. It then uploads the calculated "gradients" to the parameter servers and waits for the parameters to be optimized on the parameter server side. When that finishes, the trainer downloads the optimized parameters and continues training.
+- Parameter server: each parameter server stores a part of the whole neural network model. The parameter servers run the optimization step whenever gradients are uploaded by trainers, and then send the updated parameters back to the trainers.
+
+PaddlePaddle supports both synchronous stochastic gradient descent (SGD) and asynchronous SGD.
+
+When training with synchronous SGD, PaddlePaddle uses an internal "synchronization barrier" so that gradient upload and parameter download happen in strict order. Asynchronous SGD, on the other hand, does not wait for all trainers to finish uploading their gradients at each step, which greatly increases the parallelism of distributed training: parameter servers do not depend on each other and optimize parameters concurrently, they do not wait for every trainer to submit gradients before moving on, and trainers do not depend on each other either, so they train concurrently. The trade-off is that asynchronous SGD cannot guarantee that parameters are updated in lockstep; at any point in time the parameters on one parameter server may be newer than on another, so compared with synchronous SGD the gradients are noisier.
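+
+The pseudo-code below is only an illustrative sketch of the difference (it is not PaddlePaddle's actual API): in synchronous SGD the parameter server applies one update per step after all trainers have reported their gradients, while in asynchronous SGD it applies each gradient as soon as it arrives.
+
+.. code-block:: python
+
+    # Illustrative sketch only -- not PaddlePaddle API.
+    # Synchronous SGD: one update per step, after all trainers have reported.
+    def sync_update(params, trainer_grads, lr):
+        avg = [sum(gs) / len(gs) for gs in zip(*trainer_grads)]  # barrier
+        return [p - lr * g for p, g in zip(params, avg)]
+
+    # Asynchronous SGD: apply each trainer's gradient as soon as it arrives.
+    def async_update(params, grad, lr):
+        return [p - lr * g for p, g in zip(params, grad)]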
+
.. toctree::
:maxdepth: 1
- introduction_en.md
preparations_en.md
cmd_argument_en.md
multi_cluster/index_en.rst
diff --git a/doc/howto/cluster/introduction_cn.md b/doc/howto/cluster/introduction_cn.md
deleted file mode 100644
index 562008a898414a6566d74d08cfeb18fb9d57582a..0000000000000000000000000000000000000000
--- a/doc/howto/cluster/introduction_cn.md
+++ /dev/null
@@ -1,13 +0,0 @@
-## 概述
-
-本节将介绍如何使用PaddlePaddle在不同的集群框架下完成分布式训练。分布式训练架构如下图所示:
-
-
-
-- 数据分片(Data shard): 用于训练神经网络的数据,被切分成多个部分,每个部分分别给每个trainer使用。
-- 计算节点(Trainer): 每个trainer启动后读取切分好的一部分数据,开始神经网络的“前馈”和“后馈”计算,并和参数服务器通信。在完成一定量数据的训练后,上传计算得出的梯度(gradients),然后下载优化更新后的神经网络参数(parameters)。
-- 参数服务器(Parameter server):每个参数服务器只保存整个神经网络所有参数的一部分。参数服务器接收从计算节点上传的梯度,并完成参数优化更新,再将更新后的参数下发到每个计算节点。
-
-这样,通过计算节点和参数服务器的分布式协作,可以完成神经网络的SGD方法的训练。PaddlePaddle可以同时支持同步随机梯度下降(SGD)和异步随机梯度下降。
-
-在使用同步SGD训练神经网络时,PaddlePaddle使用同步屏障(barrier),使梯度的提交和参数的更新按照顺序方式执行。在异步SGD中,则并不会等待所有trainer提交梯度才更新参数,这样极大地提高了计算的并行性:参数服务器之间不相互依赖,并行地接收梯度和更新参数,参数服务器也不会等待计算节点全部都提交梯度之后才开始下一步,计算节点之间也不会相互依赖,并行地执行模型的训练。可以看出,虽然异步SGD方式会提高参数更新并行度, 但是并不能保证参数同步更新,在任意时间某一台参数服务器上保存的参数可能比另一台要更新,与同步SGD相比,梯度会有噪声。
diff --git a/doc/howto/cluster/introduction_en.md b/doc/howto/cluster/introduction_en.md
deleted file mode 100644
index eb70d7cf35ab729e0da4c6a3a2e732c26905f584..0000000000000000000000000000000000000000
--- a/doc/howto/cluster/introduction_en.md
+++ /dev/null
@@ -1,13 +0,0 @@
-## Introduction
-
-In this section, we'll explain how to run distributed training jobs with PaddlePaddle on different types of clusters. The diagram below shows the main architecture of a distributed trainning job:
-
-
-
-- Data shard: training data will be split into multiple partitions, trainers use the partitions of the whole dataset to do the training job.
-- Trainer: each trainer reads the data shard, and train the neural network. Then the trainer will upload calculated "gradients" to parameter servers, and wait for parameters to be optimized on the parameter server side. When that finishes, the trainer download optimized parameters and continues its training.
-- Parameter server: every parameter server stores part of the whole neural network model data. They will do optimization calculations when gradients are uploaded from trainers, and then send updated parameters to trainers.
-
-PaddlePaddle can support both synchronize stochastic gradient descent (SGD) and asynchronous SGD.
-
-When training with synchronize SGD, PaddlePaddle uses an internal "synchronize barrier" which makes gradients update and parameter download in strict order. On the other hand, asynchronous SGD won't wait for all trainers to finish upload at a single step, this will increase the parallelism of distributed training: parameter servers do not depend on each other, they'll do parameter optimization concurrently. Parameter servers will not wait for trainers, so trainers will also do their work concurrently. But asynchronous SGD will introduce more randomness and noises in the gradient.
diff --git a/doc/howto/cluster/src/ps_cn.png b/doc/howto/cluster/src/ps_cn.png
new file mode 100644
index 0000000000000000000000000000000000000000..f9525739cc8bc6506adde642aafa0a85ae3ebebc
Binary files /dev/null and b/doc/howto/cluster/src/ps_cn.png differ
diff --git a/doc/howto/cluster/src/ps_en.png b/doc/howto/cluster/src/ps_en.png
new file mode 100644
index 0000000000000000000000000000000000000000..6537d3d56589ca9f19a77a50a970e4b5275e6ce0
Binary files /dev/null and b/doc/howto/cluster/src/ps_en.png differ
diff --git a/doc/howto/rnn/index_cn.rst b/doc/howto/rnn/index_cn.rst
index 9ecab5594cff47cde4700b7ce0f58013a960a16e..bcc8c2f46eb662ec3650e829a77992224dbbb8e7 100644
--- a/doc/howto/rnn/index_cn.rst
+++ b/doc/howto/rnn/index_cn.rst
@@ -1,4 +1,4 @@
-RNN相关模型
+RNN模型
===========
.. toctree::
diff --git a/paddle/framework/block_desc.cc b/paddle/framework/block_desc.cc
index dd2ed87252102aee6d384f37365d19305f19b281..3e344ea3790f57b0f53f36a40263dcdd326e67a9 100644
--- a/paddle/framework/block_desc.cc
+++ b/paddle/framework/block_desc.cc
@@ -162,9 +162,8 @@ BlockDesc::BlockDesc(const BlockDesc &other, proto::BlockDesc *desc,
: prog_(prog), desc_(desc) {
need_update_ = true;
for (auto &op : other.ops_) {
- ops_.emplace_back(new OpDesc(*op, this));
+ ops_.emplace_back(new OpDesc(*op->Proto(), prog, this));
}
-
for (auto &it : other.vars_) {
auto *var = new VarDesc(*it.second);
vars_[it.first].reset(var);
diff --git a/paddle/framework/channel.h b/paddle/framework/channel.h
index b679387b1124e42499df158767b6c7afe1afd0c6..146f0e9e71ea9101a8f6c71e6c023178f131f967 100644
--- a/paddle/framework/channel.h
+++ b/paddle/framework/channel.h
@@ -1,4 +1,4 @@
-/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
+/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
diff --git a/paddle/framework/channel_test.cc b/paddle/framework/channel_test.cc
index df9e15e22b890347a03d6816e8549c99b010bb38..d7140dd10661c7b8582930b47872ab0b330c4d66 100644
--- a/paddle/framework/channel_test.cc
+++ b/paddle/framework/channel_test.cc
@@ -1,4 +1,4 @@
-/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
+/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
@@ -22,6 +22,28 @@ limitations under the License. */
using paddle::framework::Channel;
using paddle::framework::MakeChannel;
using paddle::framework::CloseChannel;
+using paddle::framework::details::Buffered;
+using paddle::framework::details::UnBuffered;
+
+void ReceivingOrderEqualToSendingOrder(Channel<int> *ch) {
+ unsigned sum_send = 0;
+ std::thread t([&]() {
+ for (int i = 0; i < 5; i++) {
+ EXPECT_EQ(ch->Send(&i), true);
+ sum_send += i;
+ }
+ });
+ for (int i = 0; i < 5; i++) {
+ int recv;
+ EXPECT_EQ(ch->Receive(&recv), true);
+ EXPECT_EQ(recv, i);
+ }
+
+ CloseChannel(ch);
+ t.join();
+ EXPECT_EQ(sum_send, 10U);
+ delete ch;
+}
TEST(Channel, MakeAndClose) {
using paddle::framework::details::Buffered;
@@ -60,13 +82,54 @@ TEST(Channel, SufficientBufferSizeDoesntBlock) {
delete ch;
}
-TEST(Channel, SendOnClosedChannelPanics) {
- const size_t buffer_size = 10;
- auto ch = MakeChannel(buffer_size);
- size_t i = 5;
- EXPECT_EQ(ch->Send(&i), true); // should not block or panic
+// This tests that a channel must return false
+// on send and receive performed after closing the channel.
+// Receive will only return false after close when queue is empty.
+// By creating separate threads for sending and receiving, we make this
+// function able to test both buffered and unbuffered channels.
+void SendReceiveWithACloseChannelShouldPanic(Channel<size_t> *ch) {
+ const size_t data = 5;
+ std::thread send_thread{[&]() {
+ size_t i = data;
+ EXPECT_EQ(ch->Send(&i), true); // should not block
+ }};
+
+ std::thread recv_thread{[&]() {
+ size_t i;
+ EXPECT_EQ(ch->Receive(&i), true); // should not block
+ EXPECT_EQ(i, data);
+ }};
+
+ send_thread.join();
+ recv_thread.join();
+
+ // After closing send should return false. Receive should
+ // also return false as there is no data in queue.
CloseChannel(ch);
- EXPECT_EQ(ch->Send(&i), false); // should panic
+ send_thread = std::thread{[&]() {
+ size_t i = data;
+ EXPECT_EQ(ch->Send(&i), false); // should return false
+ }};
+ recv_thread = std::thread{[&]() {
+ size_t i;
+ // should return false because channel is closed and queue is empty
+ EXPECT_EQ(ch->Receive(&i), false);
+ }};
+
+ send_thread.join();
+ recv_thread.join();
+}
+
+TEST(Channel, SendReceiveClosedBufferedChannelPanics) {
+ size_t buffer_size = 10;
+ auto ch = MakeChannel<size_t>(buffer_size);
+ SendReceiveWithACloseChannelShouldPanic(ch);
+ delete ch;
+}
+
+TEST(Channel, SendReceiveClosedUnBufferedChannelPanics) {
+ auto ch = MakeChannel<size_t>(0);
+ SendReceiveWithACloseChannelShouldPanic(ch);
delete ch;
}
@@ -94,9 +157,7 @@ TEST(Channel, ReceiveFromBufferedChannelReturnResidualValuesTest) {
for (size_t i = 0; i < buffer_size; ++i) {
EXPECT_EQ(ch->Receive(&out),
- false); // after receiving residual values, return zeros.
- // Note: we cannot check EXPECT_EQ(out, 0), because C++ doesn't
- // define zero values like Go does.
+ false); // receiving on closed channel should return false
}
delete ch;
}
@@ -115,7 +176,7 @@ TEST(Channel, ConcurrentSendNonConcurrentReceiveWithSufficientBufferSize) {
sum += i;
}
});
- std::this_thread::sleep_for(std::chrono::milliseconds(100)); // wait 0.5 sec
+ std::this_thread::sleep_for(std::chrono::milliseconds(100)); // wait 0.1 sec
EXPECT_EQ(sum, 45U);
CloseChannel(ch);
@@ -123,31 +184,17 @@ TEST(Channel, ConcurrentSendNonConcurrentReceiveWithSufficientBufferSize) {
delete ch;
}
-TEST(Channel, SimpleUnbufferedChannelTest) {
+TEST(Channel, ReceivingOrderEqualToSendingOrderWithUnBufferedChannel) {
auto ch = MakeChannel(0);
- unsigned sum_send = 0;
- std::thread t([&]() {
- for (int i = 0; i < 5; i++) {
- EXPECT_EQ(ch->Send(&i), true);
- sum_send += i;
- }
- });
- for (int i = 0; i < 5; i++) {
- int recv;
- EXPECT_EQ(ch->Receive(&recv), true);
- EXPECT_EQ(recv, i);
- }
+ ReceivingOrderEqualToSendingOrder(ch);
+}
- CloseChannel(ch);
- t.join();
- EXPECT_EQ(sum_send, 10U);
- delete ch;
+TEST(Channel, ReceivingOrderEqualToSendingOrderWithBufferedChannel) {
+ auto ch = MakeChannel<int>(10);
+ ReceivingOrderEqualToSendingOrder(ch);
}
-// This tests that closing a buffered channel also unblocks
-// any receivers waiting on the channel
-TEST(Channel, BufferedChannelCloseUnblocksReceiversTest) {
- auto ch = MakeChannel(1);
+void ChannelCloseUnblocksReceiversTest(Channel<int> *ch) {
size_t num_threads = 5;
std::thread t[num_threads];
bool thread_ended[num_threads];
@@ -158,15 +205,14 @@ TEST(Channel, BufferedChannelCloseUnblocksReceiversTest) {
t[i] = std::thread(
[&](bool *p) {
int data;
- // All reads should return false
EXPECT_EQ(ch->Receive(&data), false);
*p = true;
},
&thread_ended[i]);
}
- std::this_thread::sleep_for(std::chrono::milliseconds(100)); // wait
+ std::this_thread::sleep_for(std::chrono::milliseconds(100)); // wait 0.1 sec
- // Verify that all threads are blocked
+ // Verify that all the threads are blocked
for (size_t i = 0; i < num_threads; i++) {
EXPECT_EQ(thread_ended[i], false);
}
@@ -175,7 +221,7 @@ TEST(Channel, BufferedChannelCloseUnblocksReceiversTest) {
// This should unblock all receivers
CloseChannel(ch);
- std::this_thread::sleep_for(std::chrono::milliseconds(200)); // wait
+ std::this_thread::sleep_for(std::chrono::milliseconds(100)); // wait 0.1 sec
// Verify that all threads got unblocked
for (size_t i = 0; i < num_threads; i++) {
@@ -183,13 +229,12 @@ TEST(Channel, BufferedChannelCloseUnblocksReceiversTest) {
}
for (size_t i = 0; i < num_threads; i++) t[i].join();
- delete ch;
}
-// This tests that closing a buffered channel also unblocks
-// any senders waiting for channel to have write space
-TEST(Channel, BufferedChannelCloseUnblocksSendersTest) {
- auto ch = MakeChannel(1);
+void ChannelCloseUnblocksSendersTest(Channel<int> *ch) {
+ using paddle::framework::details::Buffered;
+ using paddle::framework::details::UnBuffered;
+
size_t num_threads = 5;
std::thread t[num_threads];
bool thread_ended[num_threads];
@@ -209,34 +254,56 @@ TEST(Channel, BufferedChannelCloseUnblocksSendersTest) {
}
std::this_thread::sleep_for(std::chrono::milliseconds(100)); // wait
- // Verify that atleast 4 threads are blocked
- int ct = 0;
- for (size_t i = 0; i < num_threads; i++) {
- if (thread_ended[i] == false) ct++;
+ if (dynamic_cast<Buffered<int> *>(ch)) {
+ // If ch is Buffered, at least 4 threads must be blocked.
+ int ct = 0;
+ for (size_t i = 0; i < num_threads; i++) {
+ if (!thread_ended[i]) ct++;
+ }
+ EXPECT_GE(ct, 4);
+ } else {
+ // If ch is UnBuffered, all the threads should be blocked.
+ for (size_t i = 0; i < num_threads; i++) {
+ EXPECT_EQ(thread_ended[i], false);
+ }
}
- // Atleast 4 threads must be blocked
- EXPECT_GE(ct, 4);
-
// Explicitly close the thread
// This should unblock all senders
CloseChannel(ch);
- std::this_thread::sleep_for(std::chrono::milliseconds(200)); // wait
+ std::this_thread::sleep_for(std::chrono::milliseconds(100)); // wait
// Verify that all threads got unblocked
for (size_t i = 0; i < num_threads; i++) {
EXPECT_EQ(thread_ended[i], true);
}
- // Verify that only 1 send was successful
- ct = 0;
- for (size_t i = 0; i < num_threads; i++) {
- if (send_success[i]) ct++;
+ if (dynamic_cast<Buffered<int> *>(ch)) {
+ // Verify that only 1 send was successful
+ int ct = 0;
+ for (size_t i = 0; i < num_threads; i++) {
+ if (send_success[i]) ct++;
+ }
+ // Only 1 send must be successful
+ EXPECT_EQ(ct, 1);
}
- // Only 1 send must be successful
- EXPECT_EQ(ct, 1);
for (size_t i = 0; i < num_threads; i++) t[i].join();
+}
+
+// This tests that closing a buffered channel also unblocks
+// any receivers waiting on the channel
+TEST(Channel, BufferedChannelCloseUnblocksReceiversTest) {
+ auto ch = MakeChannel<int>(1);
+ ChannelCloseUnblocksReceiversTest(ch);
+ delete ch;
+}
+
+// This tests that closing a buffered channel also unblocks
+// any senders waiting for channel to have write space
+TEST(Channel, BufferedChannelCloseUnblocksSendersTest) {
+ auto ch = MakeChannel<int>(1);
+ ChannelCloseUnblocksSendersTest(ch);
delete ch;
}
@@ -244,40 +311,7 @@ TEST(Channel, BufferedChannelCloseUnblocksSendersTest) {
// unblocks any receivers waiting for senders
TEST(Channel, UnbufferedChannelCloseUnblocksReceiversTest) {
auto ch = MakeChannel(0);
- size_t num_threads = 5;
- std::thread t[num_threads];
- bool thread_ended[num_threads];
-
- // Launches threads that try to read and are blocked becausew of no writers
- for (size_t i = 0; i < num_threads; i++) {
- thread_ended[i] = false;
- t[i] = std::thread(
- [&](bool *p) {
- int data;
- EXPECT_EQ(ch->Receive(&data), false);
- *p = true;
- },
- &thread_ended[i]);
- }
- std::this_thread::sleep_for(std::chrono::milliseconds(500)); // wait 0.5 sec
-
- // Verify that all the threads are blocked
- for (size_t i = 0; i < num_threads; i++) {
- EXPECT_EQ(thread_ended[i], false);
- }
-
- // Explicitly close the thread
- // This should unblock all receivers
- CloseChannel(ch);
-
- std::this_thread::sleep_for(std::chrono::milliseconds(500)); // wait 0.5 sec
-
- // Verify that all threads got unblocked
- for (size_t i = 0; i < num_threads; i++) {
- EXPECT_EQ(thread_ended[i], true);
- }
-
- for (size_t i = 0; i < num_threads; i++) t[i].join();
+ ChannelCloseUnblocksReceiversTest(ch);
delete ch;
}
@@ -285,40 +319,7 @@ TEST(Channel, UnbufferedChannelCloseUnblocksReceiversTest) {
// unblocks any senders waiting for senders
TEST(Channel, UnbufferedChannelCloseUnblocksSendersTest) {
auto ch = MakeChannel(0);
- size_t num_threads = 5;
- std::thread t[num_threads];
- bool thread_ended[num_threads];
-
- // Launches threads that try to read and are blocked becausew of no writers
- for (size_t i = 0; i < num_threads; i++) {
- thread_ended[i] = false;
- t[i] = std::thread(
- [&](bool *p) {
- int data = 10;
- EXPECT_EQ(ch->Send(&data), false);
- *p = true;
- },
- &thread_ended[i]);
- }
- std::this_thread::sleep_for(std::chrono::milliseconds(500)); // wait 0.5 sec
-
- // Verify that all the threads are blocked
- for (size_t i = 0; i < num_threads; i++) {
- EXPECT_EQ(thread_ended[i], false);
- }
-
- // Explicitly close the thread
- // This should unblock all receivers
- CloseChannel(ch);
-
- std::this_thread::sleep_for(std::chrono::milliseconds(500)); // wait 0.5 sec
-
- // Verify that all threads got unblocked
- for (size_t i = 0; i < num_threads; i++) {
- EXPECT_EQ(thread_ended[i], true);
- }
-
- for (size_t i = 0; i < num_threads; i++) t[i].join();
+ ChannelCloseUnblocksSendersTest(ch);
delete ch;
}
@@ -381,3 +382,129 @@ TEST(Channel, UnbufferedMoreReceiveLessSendTest) {
EXPECT_EQ(sum_receive, 28U);
delete ch;
}
+
+// This tests that destroying a channel unblocks
+// any senders waiting for channel to have write space
+void ChannelDestroyUnblockSenders(Channel<int> *ch) {
+ size_t num_threads = 5;
+ std::thread t[num_threads];
+ bool thread_ended[num_threads];
+ bool send_success[num_threads];
+
+ // Launches threads that try to write and are blocked because of no readers
+ for (size_t i = 0; i < num_threads; i++) {
+ thread_ended[i] = false;
+ send_success[i] = false;
+ t[i] = std::thread(
+ [&](bool *ended, bool *success) {
+ int data = 10;
+ *success = ch->Send(&data);
+ *ended = true;
+ },
+ &thread_ended[i], &send_success[i]);
+ }
+
+ std::this_thread::sleep_for(std::chrono::milliseconds(500)); // wait 0.5 sec
+ bool is_buffered_channel = false;
+ if (dynamic_cast<Buffered<int> *>(ch)) is_buffered_channel = true;
+
+ if (is_buffered_channel) {
+ // If channel is buffered, verify that at least 4 threads are blocked
+ int ct = 0;
+ for (size_t i = 0; i < num_threads; i++) {
+ if (thread_ended[i] == false) ct++;
+ }
+ // At least 4 threads must be blocked
+ EXPECT_GE(ct, 4);
+ } else {
+ // Verify that all the threads are blocked
+ for (size_t i = 0; i < num_threads; i++) {
+ EXPECT_EQ(thread_ended[i], false);
+ }
+ }
+ // Explicitly destroy the channel
+ delete ch;
+ std::this_thread::sleep_for(std::chrono::milliseconds(200)); // wait
+
+ // Verify that all threads got unblocked
+ for (size_t i = 0; i < num_threads; i++) {
+ EXPECT_EQ(thread_ended[i], true);
+ }
+
+ // Count number of successful sends
+ int ct = 0;
+ for (size_t i = 0; i < num_threads; i++) {
+ if (send_success[i]) ct++;
+ }
+
+ if (is_buffered_channel) {
+ // Only 1 send must be successful
+ EXPECT_EQ(ct, 1);
+ } else {
+ // In unbuffered channel, no send should be successful
+ EXPECT_EQ(ct, 0);
+ }
+
+ // Join all threads
+ for (size_t i = 0; i < num_threads; i++) t[i].join();
+}
+
+// This tests that destroying a channel also unblocks
+// any receivers waiting on the channel
+void ChannelDestroyUnblockReceivers(Channel<int> *ch) {
+ size_t num_threads = 5;
+ std::thread t[num_threads];
+ bool thread_ended[num_threads];
+
+ // Launches threads that try to read and are blocked because of no writers
+ for (size_t i = 0; i < num_threads; i++) {
+ thread_ended[i] = false;
+ t[i] = std::thread(
+ [&](bool *p) {
+ int data;
+ // All reads should return false
+ EXPECT_EQ(ch->Receive(&data), false);
+ *p = true;
+ },
+ &thread_ended[i]);
+ }
+ std::this_thread::sleep_for(std::chrono::milliseconds(100)); // wait
+
+ // Verify that all threads are blocked
+ for (size_t i = 0; i < num_threads; i++) {
+ EXPECT_EQ(thread_ended[i], false);
+ }
+ // delete the channel
+ delete ch;
+ std::this_thread::sleep_for(std::chrono::milliseconds(200)); // wait
+ // Verify that all threads got unblocked
+ for (size_t i = 0; i < num_threads; i++) {
+ EXPECT_EQ(thread_ended[i], true);
+ }
+
+ for (size_t i = 0; i < num_threads; i++) t[i].join();
+}
+
+TEST(Channel, BufferedChannelDestroyUnblocksReceiversTest) {
+ size_t buffer_size = 1;
+ auto ch = MakeChannel<int>(buffer_size);
+ ChannelDestroyUnblockReceivers(ch);
+}
+
+TEST(Channel, BufferedChannelDestroyUnblocksSendersTest) {
+ size_t buffer_size = 1;
+ auto ch = MakeChannel<int>(buffer_size);
+ ChannelDestroyUnblockSenders(ch);
+}
+
+// This tests that destroying an unbuffered channel also unblocks
+// any receivers waiting for senders
+TEST(Channel, UnbufferedChannelDestroyUnblocksReceiversTest) {
+ auto ch = MakeChannel<int>(0);
+ ChannelDestroyUnblockReceivers(ch);
+}
+
+TEST(Channel, UnbufferedChannelDestroyUnblocksSendersTest) {
+ auto ch = MakeChannel<int>(0);
+ ChannelDestroyUnblockSenders(ch);
+}
diff --git a/paddle/framework/details/buffered_channel.h b/paddle/framework/details/buffered_channel.h
index 00b63da4da7844b41168c03f55e2faa84ff44154..227a4e4811f95441158150396b5b882815fd7844 100644
--- a/paddle/framework/details/buffered_channel.h
+++ b/paddle/framework/details/buffered_channel.h
@@ -1,4 +1,4 @@
-/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
+/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
@@ -25,6 +25,14 @@ namespace paddle {
namespace framework {
namespace details {
+// Four of the properties of Buffered Channel:
+// - A send to a full channel blocks temporarily until a receive from the
+// channel or the channel is closed.
+// - A receive from an empty channel blocks temporarily until a send to the
+// channel or the channel is closed.
+// - A send to a closed channel returns false immediately.
+// - A receive from a closed channel returns false immediately.
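+//
+// An illustrative usage sketch (assumed, based on the accompanying
+// channel_test.cc; see that file for real examples):
+//   Channel<int>* ch = MakeChannel<int>(10);  // buffered, capacity 10
+//   int v = 1;
+//   ch->Send(&v);        // blocks only while the buffer is full
+//   int out;
+//   ch->Receive(&out);   // blocks only while the buffer is empty
+//   CloseChannel(ch);    // Send() now returns false; Receive() returns
+//                        // false once the remaining values are drained
+//   delete ch;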
+
template
class Buffered : public paddle::framework::Channel {
friend Channel* paddle::framework::MakeChannel(size_t);
@@ -42,8 +50,11 @@ class Buffered : public paddle::framework::Channel {
std::mutex mu_;
std::condition_variable empty_cond_var_;
std::condition_variable full_cond_var_;
+ std::condition_variable destructor_cond_var_;
std::deque channel_;
std::atomic closed_{false};
+ std::atomic<unsigned> send_ctr{0};
+ std::atomic<unsigned> recv_ctr{0};
Buffered(size_t cap) : cap_(cap), closed_(false) {
PADDLE_ENFORCE_GT(cap, 0);
@@ -58,6 +69,7 @@ bool Buffered::Send(T* item) {
if (closed_) {
return ret;
}
+ send_ctr++;
std::unique_lock lock(mu_);
full_cond_var_.wait(lock,
[this]() { return channel_.size() < cap_ || closed_; });
@@ -67,20 +79,30 @@ bool Buffered::Send(T* item) {
empty_cond_var_.notify_one();
ret = true;
}
+ send_ctr--;
+ destructor_cond_var_.notify_one();
return ret;
}
template
bool Buffered::Receive(T* item) {
+ bool ret = false;
+ // Once the channel has been closed and all data has been consumed,
+ // just return false. Don't even try acquiring the mutex.
+ if (closed_ && channel_.empty()) {
+ return false;
+ }
+ recv_ctr++;
std::unique_lock lock(mu_);
empty_cond_var_.wait(lock, [this]() { return !channel_.empty() || closed_; });
- bool ret = false;
if (!channel_.empty()) {
*item = std::move(channel_.front());
channel_.pop_front();
full_cond_var_.notify_one();
ret = true;
}
+ recv_ctr--;
+ destructor_cond_var_.notify_one();
return ret;
}
@@ -100,6 +122,12 @@ Buffered::~Buffered() {
closed_ = true;
channel_.clear();
NotifyAllParticipants(&lock);
+
+ // The destructor must wait for all readers and writers to complete their task
+ // The channel has been closed, so we will not accept new readers and writers
+ lock.lock();
+ destructor_cond_var_.wait(
+ lock, [this]() { return send_ctr == 0 && recv_ctr == 0; });
}
template
diff --git a/paddle/framework/details/cow_ptr.h b/paddle/framework/details/cow_ptr.h
index 7e308ffb5a49876aa2c1833b3b7e2a2c7eb137aa..69bcea625288eba897e761a1d634f19c41dc0f79 100644
--- a/paddle/framework/details/cow_ptr.h
+++ b/paddle/framework/details/cow_ptr.h
@@ -1,4 +1,4 @@
-/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
+/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
diff --git a/paddle/framework/details/cow_ptr_test.cc b/paddle/framework/details/cow_ptr_test.cc
index 936954a2333e7e5d2a932abad641279db9ef7b9f..1f4a12bca0dcab2d146cc62cd7ce1c2d7abcddf9 100644
--- a/paddle/framework/details/cow_ptr_test.cc
+++ b/paddle/framework/details/cow_ptr_test.cc
@@ -1,4 +1,4 @@
-/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
+/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
diff --git a/paddle/framework/details/op_registry.h b/paddle/framework/details/op_registry.h
index 6d50e820b2b625f932768d2ca671d999071f1ca6..31a40bcbcb3905f01aebefe89526f3cfba8cb8c7 100644
--- a/paddle/framework/details/op_registry.h
+++ b/paddle/framework/details/op_registry.h
@@ -1,4 +1,4 @@
-/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
+/* Copyright (c) 2017 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
diff --git a/paddle/framework/details/unbuffered_channel.h b/paddle/framework/details/unbuffered_channel.h
index 815cebad2d8c08aa31bb566bc6c51250870383d8..6b5c2196cb2991051c48f7da8397d2f479ca4c58 100644
--- a/paddle/framework/details/unbuffered_channel.h
+++ b/paddle/framework/details/unbuffered_channel.h
@@ -23,6 +23,13 @@ namespace paddle {
namespace framework {
namespace details {
+// Four of the properties of UnBuffered Channel:
+// - A send to a channel blocks temporarily until a receive from the
+// channel or the channel is closed.
+// - A receive from a channel blocks temporarily until a send to the
+// channel or the channel is closed.
+// - A send to a closed channel returns false immediately.
+// - A receive from a closed channel returns false immediately.
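+//
+// An illustrative usage sketch (assumed, based on the accompanying
+// channel_test.cc): with capacity 0, every Send() must rendezvous with a
+// Receive() from another thread:
+//   Channel<int>* ch = MakeChannel<int>(0);
+//   std::thread t([&] { int v = 1; ch->Send(&v); });  // blocks until received
+//   int out;
+//   ch->Receive(&out);
+//   t.join();
+//   CloseChannel(ch);
+//   delete ch;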
template
class UnBuffered : public paddle::framework::Channel {
friend Channel* paddle::framework::MakeChannel(size_t);
@@ -45,9 +52,11 @@ class UnBuffered : public paddle::framework::Channel {
// A transaction occurs only when both are true
std::atomic reader_found_{false}, writer_found_{false};
std::condition_variable cv_channel_;
- std::condition_variable_any cv_reader_, cv_writer_;
+ std::condition_variable_any cv_reader_, cv_writer_, cv_destructor_;
T* item{nullptr};
std::atomic closed_{false};
+ std::atomic<unsigned> send_ctr{0};
+ std::atomic<unsigned> recv_ctr{0};
UnBuffered() : closed_(false) {}
@@ -62,6 +71,7 @@ bool UnBuffered::Send(T* data) {
if (closed_) {
return ret;
}
+ send_ctr++;
// Prevent other writers from entering
std::unique_lock writer_lock(mu_write_);
writer_found_ = true;
@@ -81,6 +91,8 @@ bool UnBuffered::Send(T* data) {
ret = true;
}
writer_found_ = false;
+ send_ctr--;
+ cv_destructor_.notify_one();
return ret;
}
@@ -88,6 +100,12 @@ bool UnBuffered::Send(T* data) {
// data that was sent by a writer is read from a reader.
template
bool UnBuffered::Receive(T* data) {
+ bool ret = false;
+ // If channel is closed, we don't even want any reader to enter.
+ // Unlike a buffered channel, an unbuffered channel does not allow
+ // readers to read after closing because there is no buffer to be consumed.
+ if (closed_) return ret;
+ recv_ctr++;
// Prevent other readers from entering
std::unique_lock read_lock{mu_read_};
reader_found_ = true;
@@ -96,7 +114,6 @@ bool UnBuffered::Receive(T* data) {
cv_reader_.wait(cv_lock,
[this]() { return writer_found_ == true || closed_; });
cv_writer_.notify_one();
- bool ret = false;
if (!closed_) {
std::unique_lock lock_ch{mu_ch_};
// Reader should wait for the writer to first write its data
@@ -110,6 +127,8 @@ bool UnBuffered::Receive(T* data) {
cv_channel_.notify_one();
}
reader_found_ = false;
+ recv_ctr--;
+ cv_destructor_.notify_one();
return ret;
}
@@ -135,6 +154,9 @@ UnBuffered::~UnBuffered() {
item = nullptr;
closed_ = true;
NotifyAllParticipants(&lock);
+ lock.lock();
+ cv_destructor_.wait(lock,
+ [this]() { return send_ctr == 0 && recv_ctr == 0; });
}
// This function notifies all the readers, writers and
diff --git a/paddle/framework/op_desc.cc b/paddle/framework/op_desc.cc
index ea4028750248ec47f5094a67f736fb217216af6d..b51afe499bbc0e6b727aeeb4334f56e400ea81a5 100644
--- a/paddle/framework/op_desc.cc
+++ b/paddle/framework/op_desc.cc
@@ -125,11 +125,10 @@ OpDesc::OpDesc(const proto::OpDesc &desc, ProgramDesc *prog, BlockDesc *block)
// restore attrs_
for (const proto::OpDesc::Attr &attr : desc_.attrs()) {
std::string attr_name = attr.name();
+ // The sub_block referred to by the BLOCK attr hasn't been added
+ // to ProgramDesc class yet, we skip setting BLOCK attr here.
if (attr.type() != proto::AttrType::BLOCK) {
attrs_[attr_name] = GetAttrValue(attr);
- } else {
- auto bid = attr.block_idx();
- attrs_[attr_name] = prog->MutableBlock(bid);
}
}
this->block_ = block;
diff --git a/paddle/framework/program_desc.cc b/paddle/framework/program_desc.cc
index 15ea4035c6e6193105b621210a900e74d1466941..0e937dda4e185590648962a6d4f827eea21eb620 100644
--- a/paddle/framework/program_desc.cc
+++ b/paddle/framework/program_desc.cc
@@ -43,11 +43,20 @@ ProgramDesc::ProgramDesc() {
ProgramDesc::ProgramDesc(const ProgramDesc &o) {
desc_ = o.desc_;
-
for (int i = 0; i < desc_.blocks_size(); ++i) {
auto *block = desc_.mutable_blocks(i);
blocks_.emplace_back(new BlockDesc(*o.blocks_[i], block, this));
}
+ for (auto &block : blocks_) {
+ for (auto *op : block->AllOps()) {
+ for (const auto &attr : op->Proto()->attrs()) {
+ if (attr.type() == proto::AttrType::BLOCK) {
+ size_t blk_idx = attr.block_idx();
+ op->SetBlockAttr(attr.name(), *this->MutableBlock(blk_idx));
+ }
+ }
+ }
+ }
}
ProgramDesc::ProgramDesc(const proto::ProgramDesc &desc) {
@@ -55,6 +64,16 @@ ProgramDesc::ProgramDesc(const proto::ProgramDesc &desc) {
for (auto &block_desc : *desc_.mutable_blocks()) {
blocks_.emplace_back(new BlockDesc(this, &block_desc));
}
+ for (auto &block : blocks_) {
+ for (auto *op : block->AllOps()) {
+ for (const auto &attr : op->Proto()->attrs()) {
+ if (attr.type() == proto::AttrType::BLOCK) {
+ size_t blk_idx = attr.block_idx();
+ op->SetBlockAttr(attr.name(), *this->MutableBlock(blk_idx));
+ }
+ }
+ }
+ }
}
ProgramDesc::ProgramDesc(const std::string &binary_str) {
diff --git a/paddle/framework/prune.cc b/paddle/framework/prune.cc
index bff8e0bceaca9749101b2c45edddba526d565624..ddd6b993d40f72cba919fad95318f70409c98bca 100644
--- a/paddle/framework/prune.cc
+++ b/paddle/framework/prune.cc
@@ -49,11 +49,28 @@ bool IsTarget(const proto::OpDesc& op_desc) {
return false;
}
-void prune_impl(const proto::ProgramDesc& input, proto::ProgramDesc* output,
- int block_id) {
- // TODO(tonyyang-svail):
- // - will change to use multiple blocks for RNN op and Cond Op
+int GetSubBlockIndex(const proto::OpDesc& op_desc) {
+ for (auto& attr : op_desc.attrs()) {
+ if (attr.type() == proto::AttrType::BLOCK) {
+ PADDLE_ENFORCE(attr.has_block_idx());
+ return attr.block_idx();
+ }
+ }
+ return -1;
+}
+
+bool HasSubBlock(const proto::OpDesc& op_desc) {
+ return GetSubBlockIndex(op_desc) > 0;
+}
+// block_id is the idx of the current block in the input desc
+// parent_block_id is the idx of the parent of the current block
+// in the output desc, -1 means the current block is global block
+// dependent_vars is passed recursively from the parent block to
+// the child block to help pruning
+void prune_impl(const proto::ProgramDesc& input, proto::ProgramDesc* output,
+ int block_id, int parent_block_id,
+ std::set<std::string>& dependent_vars) {
auto& block = input.blocks(block_id);
auto& ops = block.ops();
@@ -72,11 +89,9 @@ void prune_impl(const proto::ProgramDesc& input, proto::ProgramDesc* output,
expect_fetch = (op_desc.type() == kFetchOpType);
}
- std::set dependent_vars;
std::vector should_run;
for (auto op_iter = ops.rbegin(); op_iter != ops.rend(); ++op_iter) {
auto& op_desc = *op_iter;
-
if (IsTarget(op_desc) || HasDependentVar(op_desc, dependent_vars)) {
// insert its input to the dependency graph
for (auto& var : op_desc.inputs()) {
@@ -84,7 +99,6 @@ void prune_impl(const proto::ProgramDesc& input, proto::ProgramDesc* output,
dependent_vars.insert(argu);
}
}
-
should_run.push_back(true);
} else {
should_run.push_back(false);
@@ -95,45 +109,81 @@ void prune_impl(const proto::ProgramDesc& input, proto::ProgramDesc* output,
// we reverse the should_run vector
std::reverse(should_run.begin(), should_run.end());
- *output = input;
- auto* op_field = output->mutable_blocks(block_id)->mutable_ops();
+ // copy the current block from input to output
+ auto* block_field = output->mutable_blocks();
+ *block_field->Add() = input.blocks(block_id);
+
+ int output_block_id = output->blocks_size() - 1;
+ auto* output_block = output->mutable_blocks(output_block_id);
+ output_block->set_idx(output_block_id);
+ output_block->set_parent_idx(parent_block_id);
+
+ auto* op_field = output_block->mutable_ops();
op_field->Clear();
for (size_t i = 0; i < should_run.size(); ++i) {
if (should_run[i]) {
- *op_field->Add() = input.blocks(block_id).ops(i);
+ auto* op = op_field->Add();
+ *op = input.blocks(block_id).ops(i);
+ if (HasSubBlock(*op)) {
+ // create sub_block_dependent_vars here to help prune the sub block
+ std::set<std::string> sub_block_dependent_vars;
+ for (auto& var : op->inputs()) {
+ for (auto& argu : var.arguments()) {
+ sub_block_dependent_vars.insert(argu);
+ }
+ }
+ for (auto& var : op->outputs()) {
+ for (auto& argu : var.arguments()) {
+ sub_block_dependent_vars.insert(argu);
+ }
+ }
+ // GetSubBlockIndex(*op) is the idx of the sub_block in the input desc
+ // output_block_id is the idx of the current block in the output desc
+ prune_impl(input, output, GetSubBlockIndex(*op), output_block_id,
+ sub_block_dependent_vars);
+ }
}
}
// remove the VarDescs in BlockDesc that are not referenced in
// the pruned OpDescs
std::unordered_map var_map;
- auto* var_field = output->mutable_blocks(block_id)->mutable_vars();
+ auto* var_field = output->mutable_blocks(output_block_id)->mutable_vars();
for (const auto& var : *var_field) {
var_map[var.name()] = var;
}
- var_field->Clear();
+ std::set<std::string> var_names;
for (const auto& op : *op_field) {
- // add VarDescs of all input arguments for each OpDesc
auto& input_field = op.inputs();
for (auto& input_var : input_field) {
for (auto& arg : input_var.arguments()) {
- *var_field->Add() = var_map[arg];
+ if (var_map.count(arg) != 0) {
+ var_names.insert(arg);
+ }
}
}
- // add VarDescs of all output arguments for each OpDesc
auto& output_field = op.outputs();
for (auto& output_var : output_field) {
for (auto& arg : output_var.arguments()) {
- *var_field->Add() = var_map[arg];
+ if (var_map.count(arg) != 0) {
+ var_names.insert(arg);
+ }
}
}
}
+
+ var_field->Clear();
+ for (const auto& name : var_names) {
+ *var_field->Add() = var_map[name];
+ }
}
// TODO(fengjiayi): Prune() could be inplaced to avoid unnecessary copies
void Prune(const proto::ProgramDesc& input, proto::ProgramDesc* output) {
- prune_impl(input, output, 0);
+ std::set<std::string> dependent_vars;
+ output->clear_blocks();
+ prune_impl(input, output, 0, -1, dependent_vars);
}
void inference_optimize_impl(const proto::ProgramDesc& input,
diff --git a/paddle/inference/io.cc b/paddle/inference/io.cc
index 1ed14b69c83a7a0fb5a55db9c179df133407440c..784e87970f77857e7f3182df904dc0133c44d6c9 100644
--- a/paddle/inference/io.cc
+++ b/paddle/inference/io.cc
@@ -21,6 +21,17 @@ limitations under the License. */
namespace paddle {
namespace inference {
+void ReadBinaryFile(const std::string& filename, std::string& contents) {
+ VLOG(3) << "loading model from " << filename;
+ std::ifstream inputfs(filename, std::ios::in | std::ios::binary);
+ inputfs.seekg(0, std::ios::end);
+ contents.clear();
+ contents.resize(inputfs.tellg());
+ inputfs.seekg(0, std::ios::beg);
+ inputfs.read(&contents[0], contents.size());
+ inputfs.close();
+}
+
bool IsParameter(const framework::VarDesc* var,
const framework::ProgramDesc& main_program) {
if (var->Persistable()) {
@@ -44,12 +55,15 @@ bool IsParameter(const framework::VarDesc* var,
void LoadPersistables(framework::Executor& executor,
framework::Scope& scope,
+ const framework::ProgramDesc& main_program,
const std::string& dirname,
- const framework::ProgramDesc& main_program) {
+ const std::string& param_filename) {
const framework::BlockDesc& global_block = main_program.Block(0);
framework::ProgramDesc* load_program = new framework::ProgramDesc();
framework::BlockDesc* load_block = load_program->MutableBlock(0);
+ std::vector<std::string> paramlist;
+
for (auto* var : global_block.AllVars()) {
if (IsParameter(var, main_program)) {
VLOG(3) << "parameter's name: " << var->Name();
@@ -61,15 +75,33 @@ void LoadPersistables(framework::Executor& executor,
new_var->SetLoDLevel(var->GetLoDLevel());
new_var->SetPersistable(true);
- // append_op
- framework::OpDesc* op = load_block->AppendOp();
- op->SetType("load");
- op->SetOutput("Out", {new_var->Name()});
- op->SetAttr("file_path", {dirname + "/" + new_var->Name()});
- op->CheckAttrs();
+ if (!param_filename.empty()) {
+ paramlist.push_back(new_var->Name());
+ } else {
+ // append_op
+ framework::OpDesc* op = load_block->AppendOp();
+ op->SetType("load");
+ op->SetOutput("Out", {new_var->Name()});
+ op->SetAttr("file_path", {dirname + "/" + new_var->Name()});
+ op->CheckAttrs();
+ }
}
}
+
+ if (!param_filename.empty()) {
+ // sort paramlist to have consistent ordering
+ std::sort(paramlist.begin(), paramlist.end());
+ // append just the load_combine op
+ framework::OpDesc* op = load_block->AppendOp();
+ op->SetType("load_combine");
+ op->SetOutput("Out", paramlist);
+ op->SetAttr("file_path", {param_filename});
+ op->CheckAttrs();
+ }
+
executor.Run(*load_program, &scope, 0, true, true);
+
+ VLOG(3) << "Ran loading successfully";
delete load_program;
}
@@ -77,20 +109,29 @@ std::unique_ptr Load(framework::Executor& executor,
framework::Scope& scope,
const std::string& dirname) {
std::string model_filename = dirname + "/__model__";
- LOG(INFO) << "loading model from " << model_filename;
- std::ifstream inputfs(model_filename, std::ios::in | std::ios::binary);
std::string program_desc_str;
- inputfs.seekg(0, std::ios::end);
- program_desc_str.resize(inputfs.tellg());
- inputfs.seekg(0, std::ios::beg);
- LOG(INFO) << "program_desc_str's size: " << program_desc_str.size();
- inputfs.read(&program_desc_str[0], program_desc_str.size());
- inputfs.close();
+ ReadBinaryFile(model_filename, program_desc_str);
+
+ std::unique_ptr<framework::ProgramDesc> main_program(
+ new framework::ProgramDesc(program_desc_str));
+
+ LoadPersistables(executor, scope, *main_program, dirname, "");
+ return main_program;
+}
+
+std::unique_ptr<framework::ProgramDesc> Load(
+ framework::Executor& executor,
+ framework::Scope& scope,
+ const std::string& prog_filename,
+ const std::string& param_filename) {
+ std::string model_filename = prog_filename;
+ std::string program_desc_str;
+ ReadBinaryFile(model_filename, program_desc_str);
std::unique_ptr main_program(
new framework::ProgramDesc(program_desc_str));
- LoadPersistables(executor, scope, dirname, *main_program);
+ LoadPersistables(executor, scope, *main_program, "", param_filename);
return main_program;
}
diff --git a/paddle/inference/io.h b/paddle/inference/io.h
index 962b6c4e20d30de3cc28eae1c8c5c33b3ab5f6ac..a7d7c499690620740d8627e7f5085d728d67f7c3 100644
--- a/paddle/inference/io.h
+++ b/paddle/inference/io.h
@@ -26,12 +26,18 @@ namespace inference {
void LoadPersistables(framework::Executor& executor,
framework::Scope& scope,
+ const framework::ProgramDesc& main_program,
const std::string& dirname,
- const framework::ProgramDesc& main_program);
+ const std::string& param_filename);
std::unique_ptr Load(framework::Executor& executor,
framework::Scope& scope,
const std::string& dirname);
+std::unique_ptr<framework::ProgramDesc> Load(framework::Executor& executor,
+ framework::Scope& scope,
+ const std::string& prog_filename,
+ const std::string& param_filename);
+
} // namespace inference
} // namespace paddle
diff --git a/paddle/inference/tests/book/CMakeLists.txt b/paddle/inference/tests/book/CMakeLists.txt
index 63afeb18aebdf446c01cd4fdac13d238467801e4..0a96829bdd20f5dcb0c3fed501d27c27f2f73b17 100644
--- a/paddle/inference/tests/book/CMakeLists.txt
+++ b/paddle/inference/tests/book/CMakeLists.txt
@@ -27,3 +27,4 @@ endfunction(inference_test)
inference_test(recognize_digits ARGS mlp)
inference_test(image_classification ARGS vgg resnet)
inference_test(label_semantic_roles)
+inference_test(rnn_encoder_decoder)
diff --git a/paddle/inference/tests/book/test_helper.h b/paddle/inference/tests/book/test_helper.h
index 32db643fca2b026b674ea0b1ecd9aad5224e9e68..3e66ced94fe6360f0be948a6838cc37ff2f65eed 100644
--- a/paddle/inference/tests/book/test_helper.h
+++ b/paddle/inference/tests/book/test_helper.h
@@ -67,17 +67,28 @@ void CheckError(paddle::framework::LoDTensor& output1,
EXPECT_EQ(count, 0) << "There are " << count << " different elements.";
}
-template
+template <typename Place, typename T, bool IsCombined = false>
void TestInference(const std::string& dirname,
const std::vector& cpu_feeds,
std::vector& cpu_fetchs) {
- // 1. Define place, executor and scope
+ // 1. Define place, executor, scope and inference_program
auto place = Place();
auto executor = paddle::framework::Executor(place);
auto* scope = new paddle::framework::Scope();
+ std::unique_ptr<paddle::framework::ProgramDesc> inference_program;
// 2. Initialize the inference_program and load all parameters from file
- auto inference_program = paddle::inference::Load(executor, *scope, dirname);
+ if (IsCombined) {
+ // Hard-coding the names for combined params case
+ std::string prog_filename = "__model_combined__";
+ std::string param_filename = "__params_combined__";
+ inference_program = paddle::inference::Load(executor,
+ *scope,
+ dirname + "/" + prog_filename,
+ dirname + "/" + param_filename);
+ } else {
+ inference_program = paddle::inference::Load(executor, *scope, dirname);
+ }
// 3. Get the feed_target_names and fetch_target_names
const std::vector& feed_target_names =
diff --git a/paddle/inference/tests/book/test_inference_recognize_digits.cc b/paddle/inference/tests/book/test_inference_recognize_digits.cc
index 48f887e6bc680087af4cce74b5c5422a4eba3726..3a48db7fe08205a3e078592651c739f77d5bf415 100644
--- a/paddle/inference/tests/book/test_inference_recognize_digits.cc
+++ b/paddle/inference/tests/book/test_inference_recognize_digits.cc
@@ -59,3 +59,45 @@ TEST(inference, recognize_digits) {
CheckError(output1, output2);
#endif
}
+
+TEST(inference, recognize_digits_combine) {
+ if (FLAGS_dirname.empty()) {
+ LOG(FATAL) << "Usage: ./example --dirname=path/to/your/model";
+ }
+
+ LOG(INFO) << "FLAGS_dirname: " << FLAGS_dirname << std::endl;
+ std::string dirname = FLAGS_dirname;
+
+ // 0. Call `paddle::framework::InitDevices()` initialize all the devices
+ // In unittests, this is done in paddle/testing/paddle_gtest_main.cc
+
+ paddle::framework::LoDTensor input;
+ // Use normalized image pixels as input data,
+ // which should be in the range [-1.0, 1.0].
+ SetupTensor<float>(
+ input, {1, 28, 28}, static_cast<float>(-1), static_cast<float>(1));
+ std::vector<paddle::framework::LoDTensor*> cpu_feeds;
+ cpu_feeds.push_back(&input);
+
+ paddle::framework::LoDTensor output1;
+ std::vector<paddle::framework::LoDTensor*> cpu_fetchs1;
+ cpu_fetchs1.push_back(&output1);
+
+ // Run inference on CPU
+ TestInference<paddle::platform::CPUPlace, float, true>(
+ dirname, cpu_feeds, cpu_fetchs1);
+ LOG(INFO) << output1.dims();
+
+#ifdef PADDLE_WITH_CUDA
+ paddle::framework::LoDTensor output2;
+ std::vector<paddle::framework::LoDTensor*> cpu_fetchs2;
+ cpu_fetchs2.push_back(&output2);
+
+ // Run inference on CUDA GPU
+ TestInference<paddle::platform::CUDAPlace, float, true>(
+ dirname, cpu_feeds, cpu_fetchs2);
+ LOG(INFO) << output2.dims();
+
+ CheckError(output1, output2);
+#endif
+}
diff --git a/paddle/inference/tests/book/test_inference_rnn_encoder_decoder.cc b/paddle/inference/tests/book/test_inference_rnn_encoder_decoder.cc
new file mode 100644
index 0000000000000000000000000000000000000000..9bfc0407b7f2732a14e7ac0f319a3d39b9e641bc
--- /dev/null
+++ b/paddle/inference/tests/book/test_inference_rnn_encoder_decoder.cc
@@ -0,0 +1,67 @@
+/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#include <gtest/gtest.h>
+#include "gflags/gflags.h"
+#include "test_helper.h"
+
+DEFINE_string(dirname, "", "Directory of the inference model.");
+
+TEST(inference, rnn_encoder_decoder) {
+ if (FLAGS_dirname.empty()) {
+ LOG(FATAL) << "Usage: ./example --dirname=path/to/your/model";
+ }
+
+ LOG(INFO) << "FLAGS_dirname: " << FLAGS_dirname << std::endl;
+ std::string dirname = FLAGS_dirname;
+
+ // 0. Call `paddle::framework::InitDevices()` initialize all the devices
+ // In unittests, this is done in paddle/testing/paddle_gtest_main.cc
+
+ paddle::framework::LoDTensor word_data, trg_word;
+ paddle::framework::LoD lod{{0, 4, 10}};
+
+ SetupLoDTensor(
+ word_data, lod, static_cast<int64_t>(0), static_cast<int64_t>(1));
+ SetupLoDTensor(
+ trg_word, lod, static_cast<int64_t>(0), static_cast<int64_t>(1));
+
+ std::vector<paddle::framework::LoDTensor*> cpu_feeds;
+ cpu_feeds.push_back(&word_data);
+ cpu_feeds.push_back(&trg_word);
+
+ paddle::framework::LoDTensor output1;
+ std::vector<paddle::framework::LoDTensor*> cpu_fetchs1;
+ cpu_fetchs1.push_back(&output1);
+
+ // Run inference on CPU
+ TestInference<paddle::platform::CPUPlace, float>(
+ dirname, cpu_feeds, cpu_fetchs1);
+ LOG(INFO) << output1.lod();
+ LOG(INFO) << output1.dims();
+
+#ifdef PADDLE_WITH_CUDA
+ paddle::framework::LoDTensor output2;
+ std::vector<paddle::framework::LoDTensor*> cpu_fetchs2;
+ cpu_fetchs2.push_back(&output2);
+
+ // Run inference on CUDA GPU
+ TestInference<paddle::platform::CUDAPlace, float>(
+ dirname, cpu_feeds, cpu_fetchs2);
+ LOG(INFO) << output2.lod();
+ LOG(INFO) << output2.dims();
+
+ CheckError(output1, output2);
+#endif
+}
diff --git a/paddle/operators/compare_op.h b/paddle/operators/compare_op.h
index b275fd75b3512343825170fc38565dd27f7f1c75..79b8c6f59c7ad3d77aa969f6b4f36f8050cfe823 100644
--- a/paddle/operators/compare_op.h
+++ b/paddle/operators/compare_op.h
@@ -62,7 +62,7 @@ class CompareOpKernel
z->mutable_data(context.GetPlace());
int axis = context.Attr("axis");
ElementwiseComputeEx(context, x, y, axis,
- z);
+ Functor(), z);
}
};
diff --git a/paddle/operators/elementwise_add_op.h b/paddle/operators/elementwise_add_op.h
index c32288d6984f126f2374a13973541f4f663b25a4..c24f97a85092ff14e8211ca8bc4bb9b155510a2c 100644
--- a/paddle/operators/elementwise_add_op.h
+++ b/paddle/operators/elementwise_add_op.h
@@ -35,7 +35,8 @@ class ElementwiseAddKernel : public framework::OpKernel {
auto* z = ctx.Output("Out");
z->mutable_data(ctx.GetPlace());
int axis = ctx.Attr("axis");
- ElementwiseComputeEx, DeviceContext, T>(ctx, x, y, axis, z);
+ ElementwiseComputeEx<AddFunctor<T>, DeviceContext, T>(ctx, x, y, axis,
+ AddFunctor<T>(), z);
}
};
diff --git a/paddle/operators/elementwise_div_op.h b/paddle/operators/elementwise_div_op.h
index 07ebade31ff5b3d5c89156e28ff5fa0670a9a842..dc863cc598ec6015067f166b1544a5d20223662a 100644
--- a/paddle/operators/elementwise_div_op.h
+++ b/paddle/operators/elementwise_div_op.h
@@ -35,7 +35,8 @@ class ElementwiseDivKernel : public framework::OpKernel {
auto* z = ctx.Output("Out");
z->mutable_data(ctx.GetPlace());
int axis = ctx.Attr("axis");
- ElementwiseComputeEx, DeviceContext, T>(ctx, x, y, axis, z);
+ ElementwiseComputeEx<DivFunctor<T>, DeviceContext, T>(ctx, x, y, axis,
+ DivFunctor<T>(), z);
}
};
diff --git a/paddle/operators/elementwise_max_op.h b/paddle/operators/elementwise_max_op.h
index 717e45ab31db9b9a6629fb33e17654dbf986d8c5..67efe4e1511e054d54f91b5aa22ce28f222ed20a 100644
--- a/paddle/operators/elementwise_max_op.h
+++ b/paddle/operators/elementwise_max_op.h
@@ -35,7 +35,8 @@ class ElementwiseMaxKernel : public framework::OpKernel {
auto* z = ctx.Output("Out");
z->mutable_data(ctx.GetPlace());
int axis = ctx.Attr("axis");
- ElementwiseComputeEx, DeviceContext, T>(ctx, x, y, axis, z);
+ ElementwiseComputeEx<MaxFunctor<T>, DeviceContext, T>(ctx, x, y, axis,
+ MaxFunctor<T>(), z);
}
};
diff --git a/paddle/operators/elementwise_min_op.h b/paddle/operators/elementwise_min_op.h
index 0de9a91c52b0ab82cd62604de318ce68e56b767d..cf11759404d3342b8a1c0080fa09f6cd57e735db 100644
--- a/paddle/operators/elementwise_min_op.h
+++ b/paddle/operators/elementwise_min_op.h
@@ -35,7 +35,8 @@ class ElementwiseMinKernel : public framework::OpKernel {
auto* z = ctx.Output("Out");
z->mutable_data(ctx.GetPlace());
int axis = ctx.Attr("axis");
- ElementwiseComputeEx, DeviceContext, T>(ctx, x, y, axis, z);
+ ElementwiseComputeEx<MinFunctor<T>, DeviceContext, T>(ctx, x, y, axis,
+ MinFunctor<T>(), z);
}
};
diff --git a/paddle/operators/elementwise_mul_op.h b/paddle/operators/elementwise_mul_op.h
index ae7a71e0244dfb8ad3e55683ac081f92bc36bea5..773125f5ca54e7b529df47a2823d56a5ad71e50d 100644
--- a/paddle/operators/elementwise_mul_op.h
+++ b/paddle/operators/elementwise_mul_op.h
@@ -34,7 +34,8 @@ class ElementwiseMulKernel : public framework::OpKernel {
auto* z = ctx.Output("Out");
z->mutable_data(ctx.GetPlace());
int axis = ctx.Attr("axis");
- ElementwiseComputeEx, DeviceContext, T>(ctx, x, y, axis, z);
+ ElementwiseComputeEx<MulFunctor<T>, DeviceContext, T>(ctx, x, y, axis,
+ MulFunctor<T>(), z);
}
};
diff --git a/paddle/operators/elementwise_op_function.h b/paddle/operators/elementwise_op_function.h
index 213fe1f5a818873e8b666464cb112637261c598c..74abf7c4a58788eb0e53025886f10f5a43021a9e 100644
--- a/paddle/operators/elementwise_op_function.h
+++ b/paddle/operators/elementwise_op_function.h
@@ -365,10 +365,10 @@ template
void ElementwiseComputeEx(const framework::ExecutionContext& ctx,
const framework::Tensor* x,
- const framework::Tensor* y, int axis,
+ const framework::Tensor* y, int axis, Functor func,
framework::Tensor* z) {
TransformFunctor functor(
- x, y, z, ctx.template device_context(), Functor());
+ x, y, z, ctx.template device_context<DeviceContext>(), func);
auto x_dims = x->dims();
auto y_dims = y->dims();
diff --git a/paddle/operators/elementwise_pow_op.h b/paddle/operators/elementwise_pow_op.h
index 874fd3f09f2afaccfbfca75799cc3448f7393b03..0c5dd031ec46ebecaabb701839c0f69c02678eb0 100644
--- a/paddle/operators/elementwise_pow_op.h
+++ b/paddle/operators/elementwise_pow_op.h
@@ -36,7 +36,8 @@ class ElementwisePowKernel : public framework::OpKernel {
auto* z = ctx.Output("Out");
z->mutable_data(ctx.GetPlace());
int axis = ctx.Attr("axis");
- ElementwiseComputeEx, DeviceContext, T>(ctx, x, y, axis, z);
+ ElementwiseComputeEx<PowFunctor<T>, DeviceContext, T>(ctx, x, y, axis,
+ PowFunctor<T>(), z);
}
};
diff --git a/paddle/operators/elementwise_sub_op.h b/paddle/operators/elementwise_sub_op.h
index c2749a8e6ba689233dab4f3c72de10bf01f39fab..6a88c5f6b4c869f8ab5b4fa3b112ffc264be7145 100644
--- a/paddle/operators/elementwise_sub_op.h
+++ b/paddle/operators/elementwise_sub_op.h
@@ -34,7 +34,8 @@ class ElementwiseSubKernel : public framework::OpKernel {
auto* z = ctx.Output("Out");
z->mutable_data(ctx.GetPlace());
int axis = ctx.Attr("axis");
- ElementwiseComputeEx, DeviceContext, T>(ctx, x, y, axis, z);
+ ElementwiseComputeEx<SubFunctor<T>, DeviceContext, T>(ctx, x, y, axis,
+ SubFunctor<T>(), z);
}
};
diff --git a/paddle/operators/layer_norm_op.cc b/paddle/operators/layer_norm_op.cc
index 1c6d2ae4d05becaeed34d66cad398cc90f9d3ece..76d5d571c31c0cdec207cd171291da1f58d29b61 100644
--- a/paddle/operators/layer_norm_op.cc
+++ b/paddle/operators/layer_norm_op.cc
@@ -21,13 +21,6 @@ using Tensor = framework::Tensor;
using LoDTensor = framework::LoDTensor;
using DataLayout = framework::DataLayout;
-template
-using EigenMatrixMapRowMajor = Eigen::Map<
- Eigen::Matrix>;
-template
-using ConstEigenMatrixMapRowMajor = Eigen::Map<
- const Eigen::Matrix>;
-
class LayerNormOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
@@ -108,7 +101,6 @@ class LayerNormOpMaker : public framework::OpProtoAndCheckerMaker {
AddComment(R"DOC(
Layer Normalization.
-
Layer Norm has been implemented as discussed in the paper:
https://arxiv.org/abs/1607.06450
...
@@ -116,75 +108,6 @@ https://arxiv.org/abs/1607.06450
}
};
-template
-class LayerNormKernel
- : public framework::OpKernel {
- public:
- void Compute(const framework::ExecutionContext &ctx) const override {
- const float epsilon = ctx.Attr("epsilon");
- const auto *scale = ctx.Input("Scale");
- const auto *bias = ctx.Input("Bias");
- const auto *x = ctx.Input("X");
- const auto &x_dims = x->dims();
- const auto begin_norm_axis = ctx.Attr("begin_norm_axis");
-
- auto *output = ctx.Output("Y");
- auto *mean = ctx.Output("Mean");
- auto *var = ctx.Output("Variance");
- output->mutable_data(ctx.GetPlace());
- mean->mutable_data(ctx.GetPlace());
- var->mutable_data(ctx.GetPlace());
-
- auto matrix_dim = framework::flatten_to_2d(x_dims, begin_norm_axis);
- int left = static_cast(matrix_dim[0]);
- int right = static_cast(matrix_dim[1]);
-
- auto input_map = ConstEigenMatrixMapRowMajor(x->data(), left, right);
-
- auto mean_map = EigenMatrixMapRowMajor(mean->data(), left, 1);
- auto var_map = EigenMatrixMapRowMajor(var->data(), left, 1);
- auto output_map = EigenMatrixMapRowMajor(output->data(), left, right);
-
- auto squre = [](T ele) { return ele * ele; };
- auto add_epslion = [epsilon](T ele) { return ele + epsilon; };
-
- mean_map = input_map.rowwise().mean();
- var_map = (input_map - mean_map.replicate(1, right))
- .unaryExpr(squre)
- .rowwise()
- .mean()
- .unaryExpr(add_epslion);
-
- auto inv_std_func = [](T ele) { return std::sqrt(1 / ele); };
- // TODO(zcd): Some thinking about output_map, is it appropriate that
- // `output_map` and `input_map` point to the same memory.
- auto inv_std = var_map.unaryExpr(inv_std_func);
- if (scale && bias) {
-      auto scale_map =
-          ConstEigenMatrixMapRowMajor<T>(scale->data<T>(), 1, right);
-      auto bias_map = ConstEigenMatrixMapRowMajor<T>(bias->data<T>(), 1, right);
- output_map = (input_map - mean_map.replicate(1, right))
- .cwiseProduct(inv_std.replicate(1, right))
- .cwiseProduct(scale_map.replicate(left, 1)) +
- bias_map.replicate(left, 1);
- } else if (scale) {
-      auto scale_map =
-          ConstEigenMatrixMapRowMajor<T>(scale->data<T>(), 1, right);
- output_map = (input_map - mean_map.replicate(1, right))
- .cwiseProduct(inv_std.replicate(1, right))
- .cwiseProduct(scale_map.replicate(left, 1));
- } else if (bias) {
-      auto bias_map = ConstEigenMatrixMapRowMajor<T>(bias->data<T>(), 1, right);
- output_map = (input_map - mean_map.replicate(1, right))
- .cwiseProduct(inv_std.replicate(1, right)) +
- bias_map.replicate(left, 1);
- } else {
- output_map = (input_map - mean_map.replicate(1, right))
- .cwiseProduct(inv_std.replicate(1, right));
- }
- }
-};
-
class LayerNormGradOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
@@ -237,125 +160,6 @@ class LayerNormGradOp : public framework::OperatorWithKernel {
}
};
-template <typename T>
-class LayerNormGradKernel<platform::CPUDeviceContext, T>
-    : public framework::OpKernel<T> {
- public:
- void Compute(const framework::ExecutionContext &ctx) const override {
-    const auto *x = ctx.Input<Tensor>("X");
-    const auto *mean = ctx.Input<Tensor>("Mean");
-    const auto *var = ctx.Input<Tensor>("Variance");
-    const auto *scale = ctx.Input<Tensor>("Scale");
-    const auto *d_y = ctx.Input<Tensor>(framework::GradVarName("Y"));
-
-    const auto &x_dims = x->dims();
-
-    const auto begin_norm_axis = ctx.Attr<int>("begin_norm_axis");
-    auto matrix_dim = framework::flatten_to_2d(x_dims, begin_norm_axis);
-    int left = static_cast<int>(matrix_dim[0]);
-    int right = static_cast<int>(matrix_dim[1]);
-
- // init output
-    auto *d_x = ctx.Output<Tensor>(framework::GradVarName("X"));
-    auto *d_scale = ctx.Output<Tensor>(framework::GradVarName("Scale"));
-    auto *d_bias = ctx.Output<Tensor>(framework::GradVarName("Bias"));
-
-    auto x_map = ConstEigenMatrixMapRowMajor<T>(x->data<T>(), left, right);
-    auto d_y_map = ConstEigenMatrixMapRowMajor<T>(d_y->data<T>(), left, right);
-    auto mean_map = ConstEigenMatrixMapRowMajor<T>(mean->data<T>(), left, 1);
-    auto var_map = ConstEigenMatrixMapRowMajor<T>(var->data<T>(), left, 1);
-
- if (d_bias) {
-      d_bias->mutable_data<T>(ctx.GetPlace());
-      auto d_bias_map = EigenMatrixMapRowMajor<T>(d_bias->data<T>(), 1, right);
- d_bias_map = d_y_map.colwise().sum();
- }
- if (d_scale) {
-      d_scale->mutable_data<T>(ctx.GetPlace());
-      auto d_scale_map =
-          EigenMatrixMapRowMajor<T>(d_scale->data<T>(), 1, right);
- auto inv_std_func = [](T ele) { return std::sqrt(1 / ele); };
- // There are two equation to compute d_scale. One uses "Y" and the other
- // does not use "Y"
- d_scale_map =
- ((x_map - mean_map.replicate(1, right))
- .cwiseProduct(
- var_map.unaryExpr(inv_std_func).replicate(1, right))
- .cwiseProduct(d_y_map))
- .colwise()
- .sum();
- }
-
- if (d_x) {
-      d_x->mutable_data<T>(ctx.GetPlace());
-      auto d_x_map = EigenMatrixMapRowMajor<T>(d_x->data<T>(), left, right);
- auto triple_product_func = [](T ele) { return ele * ele * ele; };
- auto inv_std_func = [](T ele) { return std::sqrt(1 / ele); };
- // TODO(zcd): these code can be refined
- if (d_scale) {
-        auto scale_map =
-            ConstEigenMatrixMapRowMajor<T>(scale->data<T>(), 1, right);
- // dy_dx
- auto dx_end = var_map.unaryExpr(inv_std_func)
- .replicate(1, right)
- .cwiseProduct(d_y_map)
- .cwiseProduct(scale_map.replicate(left, 1));
- // dy_dmean_dx
- auto dx_mean = (T(-1.0) / right) *
- var_map.unaryExpr(inv_std_func)
- .replicate(1, right)
- .cwiseProduct(d_y_map)
- .cwiseProduct(scale_map.replicate(left, 1))
- .rowwise()
- .sum()
- .replicate(1, right);
- // dy_var_dx
- auto dvar_end_part = (x_map - mean_map.replicate(1, right))
- .cwiseProduct(scale_map.replicate(left, 1))
- .cwiseProduct(d_y_map)
- .rowwise()
- .sum();
- auto dvar_end = var_map.unaryExpr(inv_std_func)
- .unaryExpr(triple_product_func)
- .cwiseProduct(dvar_end_part)
- .replicate(1, right);
- auto dx_var =
- (T(-1.0) / right) *
- (x_map - mean_map.replicate(1, right)).cwiseProduct(dvar_end);
-
- d_x_map = dx_end + dx_mean + dx_var;
- } else {
- // dy_dx
- auto dx_end = var_map.unaryExpr(inv_std_func)
- .replicate(1, right)
- .cwiseProduct(d_y_map);
- // dy_dmean_dx
- auto dx_mean = (T(-1.0) / right) *
- var_map.unaryExpr(inv_std_func)
- .replicate(1, right)
- .cwiseProduct(d_y_map)
- .rowwise()
- .sum()
- .replicate(1, right);
- // dy_var_dx
- auto dvar_end_part = (x_map - mean_map.replicate(1, right))
- .cwiseProduct(d_y_map)
- .rowwise()
- .sum();
- auto dvar_end = var_map.unaryExpr(inv_std_func)
- .unaryExpr(triple_product_func)
- .cwiseProduct(dvar_end_part)
- .replicate(1, right);
- auto dx_var =
- (T(-1.0) / right) *
- (x_map - mean_map.replicate(1, right)).cwiseProduct(dvar_end);
-
- d_x_map = dx_end + dx_mean + dx_var;
- }
- }
- }
-};
-
} // namespace operators
} // namespace paddle
@@ -363,8 +167,9 @@ namespace ops = paddle::operators;
REGISTER_OP(layer_norm, ops::LayerNormOp, ops::LayerNormOpMaker,
layer_norm_grad, ops::LayerNormGradOp);
REGISTER_OP_CPU_KERNEL(
-    layer_norm,
-    ops::LayerNormKernel<paddle::platform::CPUDeviceContext, float>);
+    layer_norm, ops::LayerNormKernel<paddle::platform::CPUDeviceContext, float>,
+    ops::LayerNormKernel<paddle::platform::CPUDeviceContext, double>);
 REGISTER_OP_CPU_KERNEL(
     layer_norm_grad,
-    ops::LayerNormGradKernel<paddle::platform::CPUDeviceContext, float>);
+    ops::LayerNormGradKernel<paddle::platform::CPUDeviceContext, float>,
+    ops::LayerNormGradKernel<paddle::platform::CPUDeviceContext, double>);
diff --git a/paddle/operators/layer_norm_op.cu b/paddle/operators/layer_norm_op.cu
new file mode 100644
index 0000000000000000000000000000000000000000..77d13b216f0e8d6d4434742908437f1eb74818c9
--- /dev/null
+++ b/paddle/operators/layer_norm_op.cu
@@ -0,0 +1,25 @@
+/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#include "paddle/operators/layer_norm_op.h"
+
+namespace ops = paddle::operators;
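+// The CUDA kernels reuse the device-templated implementation from
+// layer_norm_op.h; float and double variants are registered below.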
+REGISTER_OP_CUDA_KERNEL(
+ layer_norm,
+    ops::LayerNormKernel<paddle::platform::CUDADeviceContext, float>,
+    ops::LayerNormKernel<paddle::platform::CUDADeviceContext, double>);
+REGISTER_OP_CUDA_KERNEL(
+ layer_norm_grad,
+    ops::LayerNormGradKernel<paddle::platform::CUDADeviceContext, float>,
+    ops::LayerNormGradKernel<paddle::platform::CUDADeviceContext, double>);
diff --git a/paddle/operators/layer_norm_op.h b/paddle/operators/layer_norm_op.h
index bca35b91e6f52d35dee14aac9d080b52914942e3..3c436b89263758bbc0abcd1bb71cef3e1370d2a5 100644
--- a/paddle/operators/layer_norm_op.h
+++ b/paddle/operators/layer_norm_op.h
@@ -16,19 +16,222 @@ limitations under the License. */
#include "paddle/framework/eigen.h"
#include "paddle/framework/op_registry.h"
+#include "paddle/operators/elementwise_op_function.h"
+#include "paddle/operators/math/math_function.h"
+
namespace paddle {
namespace operators {
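+// Element-wise binary functors used as building blocks by the LayerNorm
+// kernels defined later in this header.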
+template <typename T>
+struct SubAndSquareFunctor {
+ inline HOSTDEVICE T operator()(T a, T b) const { return (a - b) * (a - b); }
+};
+
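+// DivAndSqrtFunctor(epsilon)(a, b) = a / sqrt(b + epsilon): divides the
+// centered input a by the standard deviation derived from the variance b.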
+template <typename T>
+struct DivAndSqrtFunctor {
+ explicit DivAndSqrtFunctor(T epsilon) { epsilon_ = epsilon; }
+ inline HOSTDEVICE T operator()(T a, T b) const {
+ return a / (sqrt(b + epsilon_));
+ }
+
+ private:
+ T epsilon_;
+};
+
+template <typename T>
+struct MulFunctor {
+ inline HOSTDEVICE T operator()(T a, T b) const { return a * b; }
+};
+
+template <typename T>
+struct AddFunctor {
+ inline HOSTDEVICE T operator()(T a, T b) const { return a + b; }
+};
+
+template <typename T>
+struct SubFunctor {
+ inline HOSTDEVICE T operator()(T a, T b) const { return a - b; }
+};
+
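+// MulInvVarFunctor(a, b) = a * sqrt(1 / b): scales a by the inverse standard
+// deviation obtained from the variance b.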
+template <typename T>
+struct MulInvVarFunctor {
+ inline HOSTDEVICE T operator()(T a, T b) const {
+ return a * std::sqrt(1.0 / b);
+ }
+};
+
+using Tensor = framework::Tensor;
+using LoDTensor = framework::LoDTensor;
+using DataLayout = framework::DataLayout;
+
template <typename DeviceContext, typename T>
class LayerNormKernel : public framework::OpKernel<T> {
public:
- void Compute(const framework::ExecutionContext& ctx) const override;
+ void Compute(const framework::ExecutionContext &ctx) const override {
+    const float epsilon = ctx.Attr<float>("epsilon");
+    auto *scale = ctx.Input<Tensor>("Scale");
+    auto *bias = ctx.Input<Tensor>("Bias");
+    auto x = *ctx.Input<Tensor>("X");
+
+    auto *y = ctx.Output<Tensor>("Y");
+    auto *mean = ctx.Output<Tensor>("Mean");
+    auto *var = ctx.Output<Tensor>("Variance");
+    const auto begin_norm_axis = ctx.Attr<int>("begin_norm_axis");
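+    // begin_norm_axis marks where the input is flattened into a [left, right]
+    // matrix; mean and variance are computed per row of that matrix.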
+
+ const auto x_dims = x.dims();
+
+    y->mutable_data<T>(ctx.GetPlace());
+ mean->mutable_data