Commit 01d20c44 authored by Yang Yu

Merge branch 'develop' of github.com:baidu/Paddle into feature/rnn_gradient_check

...@@ -22,6 +22,7 @@ On each machine, we will test and compare the performance of training on single
#### Training
Test on batch size 64, 128, 256 on Intel(R) Xeon(R) Gold 6148 CPU @ 2.40GHz
Note that the speed below includes forward, backward and parameter update time, so it cannot be compared directly with the benchmark of the Caffe `time` [command](https://github.com/PaddlePaddle/Paddle/blob/develop/benchmark/caffe/image/run.sh#L9), which only contains forward and backward. The parameter update time becomes significant when the weights are large, especially for AlexNet.
Input image size - 3 * 224 * 224, Time: images/second
...@@ -55,6 +56,16 @@ Input image size - 3 * 224 * 224, Time: images/second
<img src="figs/googlenet-cpu-train.png" width="500">
- AlexNet
| BatchSize | 64 | 128 | 256 |
|--------------|--------| ------ | -------|
| OpenBLAS | 2.13 | 2.45 | 2.68 |
| MKLML | 66.37 | 105.60 | 144.04 |
| MKL-DNN | 399.00 | 498.94 | 626.53 |
chart TBD
#### Inference
Test on batch size 1, 2, 4, 8, 16 on Intel(R) Xeon(R) Gold 6148 CPU @ 2.40GHz
- VGG-19
......
# Design Doc: Add MKLDNN Kernel in Fluid Operator
## Principles
First of all, we should follow some basic principles:
1. [How to write a new operator](https://github.com/PaddlePaddle/Paddle/blob/develop/doc/howto/dev/new_op_en.md). We are adding a new kind of kernel to operators, so basically we should follow this doc.
2. [Supporting new Device/Library](https://github.com/PaddlePaddle/Paddle/blob/develop/doc/design/support_new_device.md). Since MKLDNN is a new library to Fluid, we should add `MKLDNNDeviceContext` and maybe `mkldnn_helper.h`, just like [cudnn_helper.h](https://github.com/PaddlePaddle/Paddle/blob/develop/paddle/platform/cudnn_helper.h).
3. [Switch Kernel](https://github.com/PaddlePaddle/Paddle/blob/develop/doc/design/switch_kernel.md). Another important point is that we should ensure data synchronization between different kernel types, as discussed in this [topic](https://github.com/PaddlePaddle/Paddle/issues/6549). So basically we should override the `GetExpectedKernelType` and `trans` functions to support switching kernels.
4. [The Keys of Operator Kernel Type](https://github.com/PaddlePaddle/Paddle/blob/develop/doc/design/operator_kernel_type.md). Kernel type is a key concept that records the `Place`, `Library`, `DataType` and `Layout`.
## Solution
In general, there are four stages to run an MKL-DNN primitive; a standalone sketch of these stages is shown below.
- Create a primitive descriptor that describes this operator
- Create the primitive itself from the primitive descriptor and the engine
- Create all memory buffers that the primitive needs
- Launch a stream to execute the created primitive
More details can be found [here](http://01org.github.io/mkl-dnn).
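To make these stages concrete, here is a standalone sketch using the raw MKL-DNN v0.x C++ API; the convolution shapes and buffers are made up purely for illustration, and inside Fluid they would of course come from the operator's tensors.
```c++
#include <mkldnn.hpp>
#include <vector>
using namespace mkldnn;

void run_conv_once() {
  auto cpu_engine = engine(engine::cpu, 0);

  // Dummy data: N=1, C_in=3, H=W=8, C_out=16, 3x3 kernel, padding 1, stride 1.
  std::vector<float> src(1 * 3 * 8 * 8, 1.f);
  std::vector<float> weights(16 * 3 * 3 * 3, 1.f);
  std::vector<float> dst(1 * 16 * 8 * 8, 0.f);

  auto src_md = memory::desc({1, 3, 8, 8}, memory::data_type::f32, memory::format::nchw);
  auto wgt_md = memory::desc({16, 3, 3, 3}, memory::data_type::f32, memory::format::oihw);
  auto dst_md = memory::desc({1, 16, 8, 8}, memory::data_type::f32, memory::format::nchw);

  // 1. Primitive descriptor that describes this operator.
  auto conv_desc = convolution_forward::desc(
      prop_kind::forward_inference, convolution_direct, src_md, wgt_md, dst_md,
      {1, 1} /*strides*/, {1, 1} /*padding_l*/, {1, 1} /*padding_r*/,
      padding_kind::zero);
  auto conv_pd = convolution_forward::primitive_desc(conv_desc, cpu_engine);

  // 3. Memory buffers the primitive needs (wrapping the user buffers).
  auto src_mem = memory({src_md, cpu_engine}, src.data());
  auto wgt_mem = memory({wgt_md, cpu_engine}, weights.data());
  auto dst_mem = memory({dst_md, cpu_engine}, dst.data());

  // 2. The primitive itself, built from the primitive descriptor and memories.
  auto conv = convolution_forward(conv_pd, src_mem, wgt_mem, dst_mem);

  // 4. Launch a stream to execute the created primitive.
  stream(stream::kind::eager).submit({conv}).wait();
}
```
Note that the memory objects of stage 3 have to exist before the convolution primitive of stage 2 can be constructed, which is why the sketch creates them first.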
It is better to avoid re-initializing the primitives and memory handles of the first three stages in every iteration.
So we plan to create a map to record all the `primitive` and `memory` objects, which should not consume too much memory, as discussed [here](https://github.com/PaddlePaddle/Paddle/issues/6822).
It is assumed that the following three conditions are satisfied.
1. There is a unique key for each operator instance, which may be the actual name of the `Output Tensor`.
2. The `Input Tensor` inside the `Compute` function is the one after conversion.
3. We can get the phase (e.g. `is_test`) inside the `Compute` function; otherwise we need to expose this attribute to the user.
### Compute
The algorithm of `Compute` is described as follows, taking conv as an example.
```c++
PADDLE_ENFORCE(platform::is_cpu_place(ctx.GetPlace()), "It must use CPUPlace.");
PADDLE_ENFORCE(platform::is_mkldnn_library(ctx.GetLibrary()), "It must use MKLDNN Library.");
auto& dev_ctx = ctx.template device_context<platform::MKLDNNDeviceContext>();
// find primitive by unique key from mkldnn context
// the op_key should be a unique name of this op instance
auto& p = dev_ctx.findPrimitive(op_key + "_fwd");
// assuming the input tensor inside this compute function is the one after conversion
// this point should be guaranteed by another mechanism
auto& i = dev_ctx.findMemory(op_key + "_input");
if (p == nullptr || i == nullptr || inputSizeChanged(p, i)) {
auto fwd_primitive_desc = createPrimitiveDesc(ctx);
auto* input = ctx.Input<Tensor>("Input");
auto* filter = ctx.Input<Tensor>("Filter");
auto* output = ctx.Output<Tensor>("Output");
shared_ptr<mkldnn::memory> in(new mkldnn::memory(fwd_primitive_desc->src_primitive_desc(), input->data<T>()));
shared_ptr<mkldnn::memory> wgt(new mkldnn::memory(fwd_primitive_desc->weights_primitive_desc(), filter->data<T>()));
shared_ptr<mkldnn::memory> out(new mkldnn::memory(fwd_primitive_desc->dst_primitive_desc(), output->mutable_data<T>(ctx.GetPlace())));
shared_ptr<mkldnn::conv_fwd> fwd_primitive(new mkldnn::conv_fwd(*fwd_primitive_desc, *in, *wgt, *out));
dev_ctx.addMemory(op_key+"_input", in);
dev_ctx.addMemory(op_key+"_output", out);
dev_ctx.addMemory(op_key+"_filter", wgt);
dev_ctx.addPrimitive(op_key+"_fwd", fwd_primitive);
dev_ctx.addPrimitiveDesc(op_key+"_fwd_PD", fwd_primitive_desc);
}
p = dev_ctx.findPrimitive(op_key + "_fwd");
PADDLE_ENFORCE(p, "Should have forward Primitive");
PADDLE_ENFORCE(dev_ctx.findMemory(op_key+"_input"), "Should have input memory");
PADDLE_ENFORCE(dev_ctx.findMemory(op_key+"_output"), "Should have output memory");
PADDLE_ENFORCE(dev_ctx.findMemory(op_key+"_filter"), "Should have filter memory");
PADDLE_ENFORCE(dev_ctx.findPrimitiveDesc(op_key+"_fwd_PD"), "Should have forward PrimitiveDesc");
dev_ctx.submit(p);
dev_ctx.execute(); // the conversion primitive should already have been included.
```
The `createPrimitiveDesc` function returns the primitive descriptor of this operator and would look like this:
```c++
// the enclosing signature is inferred from the call site in Compute above
shared_ptr<mkldnn::conv_fwd::primitive_desc> createPrimitiveDesc(const ExecutionContext& ctx) {
auto* input = ctx.Input<Tensor>("Input");
auto* filter = ctx.Input<Tensor>("Filter");
auto* output = ctx.Output<Tensor>("Output");
std::vector<int> strides = ctx.Attr<std::vector<int>>("strides");
std::vector<int> paddings = ctx.Attr<std::vector<int>>("paddings");
std::vector<int> dilations = ctx.Attr<std::vector<int>>("dilations");
int groups = ctx.Attr<int>("groups");
algorithm algo = static_cast<algorithm>(ctx.Attr<int>("convolution_algorithm_option"));
prop_kind pk = ctx.Attr<bool>("is_test") ? prop_kind::forward_inference : prop_kind::forward_training;
auto fwd_desc = mkldnn::conv_fwd::desc(/* all the setting above*/);
shared_ptr<mkldnn::conv_fwd::primitive_desc> fwd_primitive_desc(new mkldnn::conv_fwd::primitive_desc(fwd_desc, ctx.getEngine()));
return fwd_primitive_desc;
}
```
### MKLDNNDeviceContext
`MKLDNNDeviceContext` is straightforward; it should contain some basic members such as the `stream`, the `engine` and the maps mentioned above.
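A possible shape of this class is sketched below. The method names mirror the `findPrimitive`/`addPrimitive` and `submit`/`execute` calls used in the pseudocode above; everything except the MKL-DNN types and the existing `CPUDeviceContext`/`CPUPlace` classes is an assumption, not existing code.
```c++
#include <mkldnn.hpp>
#include <memory>
#include <string>
#include <unordered_map>
#include <vector>
#include "paddle/platform/device_context.h"

namespace paddle {
namespace platform {

class MKLDNNDeviceContext : public CPUDeviceContext {
 public:
  explicit MKLDNNDeviceContext(CPUPlace place)
      : CPUDeviceContext(place), engine_(mkldnn::engine::cpu, 0) {}

  const mkldnn::engine& getEngine() const { return engine_; }

  void addPrimitive(const std::string& key,
                    std::shared_ptr<mkldnn::primitive> p) {
    primitives_[key] = p;
  }

  std::shared_ptr<mkldnn::primitive> findPrimitive(
      const std::string& key) const {
    auto it = primitives_.find(key);
    return it == primitives_.end() ? nullptr : it->second;
  }

  // addMemory/findMemory and addPrimitiveDesc/findPrimitiveDesc would follow
  // the same map-based pattern and are omitted here.

  // Collect submitted primitives and run them in one MKL-DNN stream.
  void submit(std::shared_ptr<mkldnn::primitive> p) { pipeline_.push_back(*p); }
  void execute() {
    mkldnn::stream(mkldnn::stream::kind::eager).submit(pipeline_).wait();
    pipeline_.clear();
  }

 private:
  mkldnn::engine engine_;
  std::vector<mkldnn::primitive> pipeline_;
  std::unordered_map<std::string, std::shared_ptr<mkldnn::primitive>>
      primitives_;
};

}  // namespace platform
}  // namespace paddle
```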
### mkldnn_helper
Some helper functions would be put in `paddle/platform/mkldnn_helper.h`, for example (a sketch is given after this list):
- create MKLDNN memories
- create MKLDNN primitives
- error check function
- etc
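As an illustration only, a minimal sketch of such a header is given below; the helper names are assumptions, while the MKL-DNN calls and `PADDLE_THROW` are existing API.
```c++
#pragma once
#include <mkldnn.hpp>
#include <string>
#include "paddle/platform/enforce.h"

namespace paddle {
namespace platform {

// Wrap an existing buffer as an MKLDNN memory with the given dims and format.
inline mkldnn::memory CreateMKLDNNMemory(const mkldnn::memory::dims& dims,
                                         mkldnn::memory::format fmt,
                                         const mkldnn::engine& eng,
                                         void* data) {
  auto md = mkldnn::memory::desc(dims, mkldnn::memory::data_type::f32, fmt);
  return mkldnn::memory(mkldnn::memory::primitive_desc(md, eng), data);
}

// Error check helper: MKL-DNN reports failures by throwing mkldnn::error,
// so convert them into a Paddle exception with some context.
template <typename Callable>
void MKLDNNEnforce(Callable&& fn, const std::string& what) {
  try {
    fn();
  } catch (const mkldnn::error& e) {
    PADDLE_THROW("MKLDNN error in %s: %s", what, e.message);
  }
}

}  // namespace platform
}  // namespace paddle
```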
### Kernel Switch
We should `reorder` data whose layout comes from, or goes to, another device or kernel type. The `GetExpectedKernelType` and `trans` functions can help us implement this.
`GetExpectedKernelType` receives the context, so the operator can return the best `KernelType` for it.
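For instance, a conv-like operator could prefer the MKLDNN kernel when running on CPU with an MKLDNN build. The sketch below is only an assumption of how such an override might look; the exact `OpKernelType` constructor arguments follow the keys listed in the kernel-type design doc and may differ in detail.
```c++
framework::OpKernelType GetExpectedKernelType(
    const framework::ExecutionContext& ctx) const override {
  framework::LibraryType library = framework::LibraryType::kPlain;
  framework::DataLayout layout = framework::DataLayout::kAnyLayout;
#ifdef PADDLE_WITH_MKLDNN  // assuming such a build flag exists
  if (platform::is_cpu_place(ctx.GetPlace())) {
    library = framework::LibraryType::kMKLDNN;
    layout = framework::DataLayout::kNCHW;  // or a dedicated MKLDNN layout
  }
#endif
  return framework::OpKernelType(
      framework::ToDataType(ctx.Input<framework::Tensor>("Input")->type()),
      ctx.GetPlace(), layout, library);
}
```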
`trans` would be like this:
```c++
void trans(inputs, ctx) override {
if (NoNeedTrans()) {
return;
}
// find reorder primitive by op_key from context
auto& dev_ctx = ctx.template device_context<platform::MKLDNNDeviceContext>();
auto& p = dev_ctx.findPrimitive(op_key + "_reorder_input");
auto& i = dev_ctx.findMemory(op_key + "_src_input");
if (p == nullptr || i == nullptr || changeSized(i, input)) {
auto prim = createPrimitiveDesc(ctx);
auto src = createMemory(memoryDesc(input->dims(), actual_layout), input->data);
auto newbuffer = paddle::memory::Alloc(ctx.GetPlace(), input->size_in_bytes());
auto dst = createMemory(prim->expected_desc(), newbuffer->data);
auto reorder_primitive(new mkldnn::reorder(src, dst));
dev_ctx.addMemory(op_key+"_src_input", src);
dev_ctx.addMemory(op_key+"_input", dst);
dev_ctx.addPrimitive(op_key+"_reorder_input", reorder_primitive);
}
p = dev_ctx.findPrimitive(op_key + "_reorder_input");
PADDLE_ENFORCE(p, "Should have Reorder Primitive");
dev_ctx.submit(p);
if (! this->isMKLDNNKernel()) {
// execute immediately only if this is not an MKLDNN kernel.
// otherwise, it can be executed with the operator primitive in Compute
dev_ctx.stream();
}
// after submit, the input tensor in ExecutionContext should be replaced with the converted one
// there should be another mechanism to ensure this
}
```
### Unit Test
All the functions above should have corresponding unit tests; a minimal sketch is given below.
TBD
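As a starting point, here is a minimal sketch of a test that only checks that an MKL-DNN reorder primitive can be created and executed on the CPU engine, independent of any Fluid operator; tests for the conv kernel and the kernel-switch path would be added on top of this.
```c++
#include <gtest/gtest.h>
#include <mkldnn.hpp>
#include <vector>

TEST(MKLDNN, ReorderNCHWToNHWC) {
  using namespace mkldnn;
  auto eng = engine(engine::cpu, 0);

  // A 2x3x4x5 tensor filled with its own linear NCHW offsets.
  std::vector<float> src(2 * 3 * 4 * 5);
  for (size_t i = 0; i < src.size(); ++i) src[i] = static_cast<float>(i);
  std::vector<float> dst(src.size(), 0.f);

  auto src_md = memory::desc({2, 3, 4, 5}, memory::data_type::f32, memory::format::nchw);
  auto dst_md = memory::desc({2, 3, 4, 5}, memory::data_type::f32, memory::format::nhwc);
  auto src_mem = memory({src_md, eng}, src.data());
  auto dst_mem = memory({dst_md, eng}, dst.data());

  auto r = reorder(src_mem, dst_mem);
  stream(stream::kind::eager).submit({r}).wait();

  // NCHW element (n=0, c=1, h=0, w=0) lives at NHWC offset 1.
  EXPECT_EQ(dst[1], src[1 * 4 * 5]);
}
```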
...@@ -109,3 +109,31 @@ PaddlePaddle uses AVX SIMD instructions to improve CPU execution efficiency, so using the wrong ...
The solution is:
* Uninstall the PaddlePaddle package with :code:`pip uninstall paddle` to clean up old PaddlePaddle installations so that the unit tests run in a clean environment. If a PaddlePaddle package is already in python's site-packages, the unit tests will reference the python package in site-packages instead of the one under the :code:`/python` directory of the source tree. Even setting :code:`PYTHONPATH` to :code:`/python` does not help, because python searches already-installed packages first.
8. Failed to download the MKLML library
----------------------------------------

.. code-block:: bash

    make[2]: *** [third_party/mklml/src/extern_mklml-stamp/extern_mklml-download] Error 4
    make[1]: *** [CMakeFiles/extern_mklml.dir/all] Error 2
    make[1]: *** Waiting for unfinished jobs....

Cause: the MKLML library fails to download because of network speed or SSL connection problems.

The solution is to download and install it manually, following these steps.

.. code-block:: bash

    # 1. Enter the corresponding directory
    cd build/third_party/mklml/src/extern_mklml

    # 2. Check the size of the package; it should normally be 75M. If it is smaller than 75M, the download failed:
    du -sh mklml_lnx_2018.0.1.20171007.tgz

    # 3. Download and unpack it manually, then manually create the download-success stamp:
    wget --no-check-certificate https://github.com/01org/mkl-dnn/releases/download/v0.11/mklml_lnx_2018.0.1.20171007.tgz -c -O mklml_lnx_2018.0.1.20171007.tgz
    tar zxf mklml_lnx_2018.0.1.20171007.tgz
    touch ../extern_mklml-stamp/extern_mklml-download

    # 4. Then continue the build
...@@ -59,7 +59,8 @@ cc_test(var_type_inference_test SRCS var_type_inference_test.cc DEPS op_registry ...@@ -59,7 +59,8 @@ cc_test(var_type_inference_test SRCS var_type_inference_test.cc DEPS op_registry
cc_library(selected_rows SRCS selected_rows.cc DEPS tensor) cc_library(selected_rows SRCS selected_rows.cc DEPS tensor)
cc_test(selected_rows_test SRCS selected_rows_test.cc DEPS selected_rows) cc_test(selected_rows_test SRCS selected_rows_test.cc DEPS selected_rows)
cc_test(threadpool_test SRCS threadpool_test.cc) cc_library(threadpool SRCS threadpool.cc)
cc_test(threadpool_test SRCS threadpool_test.cc DEPS threadpool)
cc_library(init SRCS init.cc DEPS gflags device_context place stringpiece) cc_library(init SRCS init.cc DEPS gflags device_context place stringpiece)
cc_test(init_test SRCS init_test.cc DEPS init) cc_test(init_test SRCS init_test.cc DEPS init)
......
...@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and ...@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#pragma once #pragma once
#include "paddle/platform/enforce.h"
#include <iostream> #include <iostream>
#include "paddle/platform/enforce.h" #include "paddle/platform/enforce.h"
...@@ -20,7 +21,7 @@ limitations under the License. */ ...@@ -20,7 +21,7 @@ limitations under the License. */
namespace paddle { namespace paddle {
namespace framework { namespace framework {
enum DataLayout { enum class DataLayout {
kNHWC = 0, kNHWC = 0,
kNCHW = 1, kNCHW = 1,
kAnyLayout = 2, kAnyLayout = 2,
...@@ -38,11 +39,11 @@ inline DataLayout StringToDataLayout(const std::string& str) { ...@@ -38,11 +39,11 @@ inline DataLayout StringToDataLayout(const std::string& str) {
inline std::string DataLayoutToString(const DataLayout& data_layout) { inline std::string DataLayoutToString(const DataLayout& data_layout) {
switch (data_layout) { switch (data_layout) {
case kNHWC: case DataLayout::kNHWC:
return "NHWC"; return "NHWC";
case kNCHW: case DataLayout::kNCHW:
return "NCHW"; return "NCHW";
case kAnyLayout: case DataLayout::kAnyLayout:
return "ANY_LAYOUT"; return "ANY_LAYOUT";
default: default:
PADDLE_THROW("unknown DataLayou %d", data_layout); PADDLE_THROW("unknown DataLayou %d", data_layout);
......
...@@ -20,15 +20,15 @@ namespace framework { ...@@ -20,15 +20,15 @@ namespace framework {
// For more details about the design of LibraryType, Please refer to // For more details about the design of LibraryType, Please refer to
// https://github.com/PaddlePaddle/Paddle/blob/develop/doc/design/operator_kernel_type.md#library // https://github.com/PaddlePaddle/Paddle/blob/develop/doc/design/operator_kernel_type.md#library
enum LibraryType { kPlain = 0, kMKLDNN = 1, kCUDNN = 2 }; enum class LibraryType { kPlain = 0, kMKLDNN = 1, kCUDNN = 2 };
inline std::string LibraryTypeToString(const LibraryType& library_type) { inline std::string LibraryTypeToString(const LibraryType& library_type) {
switch (library_type) { switch (library_type) {
case kPlain: case LibraryType::kPlain:
return "PLAIN"; return "PLAIN";
case kMKLDNN: case LibraryType::kMKLDNN:
return "MKLDNN"; return "MKLDNN";
case kCUDNN: case LibraryType::kCUDNN:
return "CUDNN"; return "CUDNN";
default: default:
PADDLE_THROW("unknown LibraryType %d", library_type); PADDLE_THROW("unknown LibraryType %d", library_type);
......
...@@ -40,6 +40,7 @@ struct OpKernelType { ...@@ -40,6 +40,7 @@ struct OpKernelType {
// place, data_type, library_type kinds less than 2^8 // place, data_type, library_type kinds less than 2^8
constexpr static int LEFT_SHIFT = 8; constexpr static int LEFT_SHIFT = 8;
proto::DataType data_type_; proto::DataType data_type_;
DataLayout data_layout_; DataLayout data_layout_;
platform::Place place_; platform::Place place_;
......
...@@ -20,12 +20,12 @@ limitations under the License. */ ...@@ -20,12 +20,12 @@ limitations under the License. */
#include <typeindex> #include <typeindex>
#include <vector> #include <vector>
#include "paddle/framework/data_layout.h"
#include "paddle/framework/ddim.h" #include "paddle/framework/ddim.h"
#include "paddle/memory/memory.h" #include "paddle/memory/memory.h"
#include "paddle/platform/device_context.h" #include "paddle/platform/device_context.h"
#include "paddle/platform/enforce.h" #include "paddle/platform/enforce.h"
#include "paddle/platform/place.h" #include "paddle/platform/place.h"
#include "unsupported/Eigen/CXX11/Tensor"
namespace paddle { namespace paddle {
...@@ -115,6 +115,10 @@ class Tensor { ...@@ -115,6 +115,10 @@ class Tensor {
inline void check_memory_size() const; inline void check_memory_size() const;
inline DataLayout layout() const { return layout_; }
inline void set_layout(const DataLayout layout) { layout_ = layout; }
private: private:
friend class LoDTensor; friend class LoDTensor;
...@@ -173,6 +177,19 @@ class Tensor { ...@@ -173,6 +177,19 @@ class Tensor {
DDim dims_; DDim dims_;
/**
* @brief the layout of memory block, default is NCHW.
*
* @note the memory allocation order, describe how weight/data is stored
* For example, in 4-D Tensor(rank=4), there are three commonly
* used layout. They are
* NCHW, NHWC, CHWN.
* N,C,H,W for respectively the batch size, the number of
* feature maps, the height.
*/
DataLayout layout_ = DataLayout::kNHWC;
/** /**
* @brief A PlaceHolder may be shared by more than one tensor. * @brief A PlaceHolder may be shared by more than one tensor.
* *
......
...@@ -165,6 +165,7 @@ inline Tensor Tensor::Slice(int begin_idx, int end_idx) const { ...@@ -165,6 +165,7 @@ inline Tensor Tensor::Slice(int begin_idx, int end_idx) const {
size_t base = numel() / dims_[0]; size_t base = numel() / dims_[0];
Tensor dst; Tensor dst;
dst.holder_ = holder_; dst.holder_ = holder_;
dst.set_layout(layout_);
DDim dst_dims = dims_; DDim dst_dims = dims_;
dst_dims[0] = end_idx - begin_idx; dst_dims[0] = end_idx - begin_idx;
dst.Resize(dst_dims); dst.Resize(dst_dims);
......
...@@ -200,3 +200,12 @@ TEST(Tensor, ReshapeToMatrix) { ...@@ -200,3 +200,12 @@ TEST(Tensor, ReshapeToMatrix) {
ASSERT_EQ(res.dims()[0], 2 * 3); ASSERT_EQ(res.dims()[0], 2 * 3);
ASSERT_EQ(res.dims()[1], 4 * 9); ASSERT_EQ(res.dims()[1], 4 * 9);
} }
TEST(Tensor, Layout) {
using namespace paddle::framework;
using namespace paddle::platform;
Tensor src;
ASSERT_EQ(src.layout(), DataLayout::kNHWC);
src.set_layout(DataLayout::kAnyLayout);
ASSERT_EQ(src.layout(), DataLayout::kAnyLayout);
}
...@@ -33,6 +33,7 @@ inline void CopyFrom(const Tensor& src, const platform::Place& dst_place, ...@@ -33,6 +33,7 @@ inline void CopyFrom(const Tensor& src, const platform::Place& dst_place,
src.check_memory_size(); src.check_memory_size();
dst->Resize(src.dims()); dst->Resize(src.dims());
dst->set_layout(src.layout());
auto src_place = src.place(); auto src_place = src.place();
auto src_ptr = src.data<void>(); auto src_ptr = src.data<void>();
...@@ -89,6 +90,7 @@ inline void CopyFrom(const Tensor& src, const platform::Place& dst_place, ...@@ -89,6 +90,7 @@ inline void CopyFrom(const Tensor& src, const platform::Place& dst_place,
Tensor* dst) { Tensor* dst) {
src.check_memory_size(); src.check_memory_size();
dst->Resize(src.dims()); dst->Resize(src.dims());
dst->set_layout(src.layout());
auto src_place = src.place(); auto src_place = src.place();
auto src_ptr = src.data<void>(); auto src_ptr = src.data<void>();
......
...@@ -28,6 +28,7 @@ TEST(CopyFrom, Tensor) { ...@@ -28,6 +28,7 @@ TEST(CopyFrom, Tensor) {
int arr[9] = {1, 2, 3, 4, 5, 6, 7, 8, 9}; int arr[9] = {1, 2, 3, 4, 5, 6, 7, 8, 9};
memcpy(src_ptr, arr, 9 * sizeof(int)); memcpy(src_ptr, arr, 9 * sizeof(int));
src_tensor.set_layout(DataLayout::kAnyLayout);
auto cpu_place = new platform::CPUPlace(); auto cpu_place = new platform::CPUPlace();
CopyFrom(src_tensor, *cpu_place, &dst_tensor); CopyFrom(src_tensor, *cpu_place, &dst_tensor);
...@@ -38,14 +39,18 @@ TEST(CopyFrom, Tensor) { ...@@ -38,14 +39,18 @@ TEST(CopyFrom, Tensor) {
EXPECT_EQ(src_ptr[i], dst_ptr[i]); EXPECT_EQ(src_ptr[i], dst_ptr[i]);
} }
EXPECT_TRUE(dst_tensor.layout() == src_tensor.layout());
Tensor slice_tensor = src_tensor.Slice(1, 2); Tensor slice_tensor = src_tensor.Slice(1, 2);
CopyFrom(slice_tensor, *cpu_place, cpu_ctx, &dst_tensor); CopyFrom(slice_tensor, *cpu_place, &dst_tensor);
const int* slice_ptr = slice_tensor.data<int>(); const int* slice_ptr = slice_tensor.data<int>();
dst_ptr = dst_tensor.data<int>(); dst_ptr = dst_tensor.data<int>();
ASSERT_NE(dst_ptr, slice_ptr); ASSERT_NE(dst_ptr, slice_ptr);
for (size_t i = 0; i < 3; ++i) { for (size_t i = 0; i < 3; ++i) {
EXPECT_EQ(dst_ptr[i], slice_ptr[i]); EXPECT_EQ(dst_ptr[i], slice_ptr[i]);
} }
EXPECT_TRUE(dst_tensor.layout() == src_tensor.layout());
#ifdef PADDLE_WITH_CUDA #ifdef PADDLE_WITH_CUDA
{ {
Tensor src_tensor; Tensor src_tensor;
...@@ -91,6 +96,8 @@ TEST(CopyFrom, Tensor) { ...@@ -91,6 +96,8 @@ TEST(CopyFrom, Tensor) {
for (size_t i = 0; i < 3; ++i) { for (size_t i = 0; i < 3; ++i) {
EXPECT_EQ(dst_ptr[i], slice_ptr[i]); EXPECT_EQ(dst_ptr[i], slice_ptr[i]);
} }
EXPECT_TRUE(dst_tensor.layout() == src_tensor.layout());
} }
#endif #endif
} }
......
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/framework/threadpool.h"
namespace paddle {
namespace framework {
std::unique_ptr<ThreadPool> ThreadPool::threadpool(nullptr);
std::once_flag ThreadPool::init_flag;
} // namespace framework
} // namespace paddle
...@@ -13,15 +13,13 @@ See the License for the specific language governing permissions and ...@@ -13,15 +13,13 @@ See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#pragma once #pragma once
#include <condition_variable> #include <condition_variable>
#include <cstdio>
#include <functional> #include <functional>
#include <iostream>
#include <mutex> #include <mutex>
#include <queue> #include <queue>
#include <thread> #include <thread>
#include "paddle/platform/call_once.h"
#include "paddle/platform/enforce.h" #include "paddle/platform/enforce.h"
namespace paddle { namespace paddle {
...@@ -81,10 +79,9 @@ class ThreadPool { ...@@ -81,10 +79,9 @@ class ThreadPool {
} }
private: private:
ThreadPool& operator=(const ThreadPool&) = delete; DISABLE_COPY_AND_ASSIGN(ThreadPool);
ThreadPool(const ThreadPool&) = delete;
ThreadPool(int num_threads) explicit ThreadPool(int num_threads)
: num_threads_(num_threads), available_(num_threads), running_(true) { : num_threads_(num_threads), available_(num_threads), running_(true) {
threads_.resize(num_threads); threads_.resize(num_threads);
for (auto& thread : threads_) { for (auto& thread : threads_) {
...@@ -155,7 +152,5 @@ class ThreadPool { ...@@ -155,7 +152,5 @@ class ThreadPool {
std::condition_variable completed_; std::condition_variable completed_;
}; };
std::unique_ptr<ThreadPool> ThreadPool::threadpool(nullptr);
std::once_flag ThreadPool::init_flag;
} // namespace framework } // namespace framework
} // namespace paddle } // namespace paddle
...@@ -12,12 +12,10 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. ...@@ -12,12 +12,10 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#include "threadpool.h"
#include <gtest/gtest.h> #include <gtest/gtest.h>
#include <atomic> #include <atomic>
#include <chrono>
#include <map> #include "threadpool.h"
#include <thread>
namespace framework = paddle::framework; namespace framework = paddle::framework;
......
...@@ -13,59 +13,113 @@ See the License for the specific language governing permissions and ...@@ -13,59 +13,113 @@ See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#pragma once #pragma once
#include "paddle/framework/eigen.h" #include <math.h> // for sqrt in CPU and CUDA
#include "paddle/framework/op_registry.h" #include "paddle/framework/op_registry.h"
#include "paddle/operators/detail/safe_ref.h"
#include "paddle/platform/for_range.h"
namespace paddle { namespace paddle {
namespace operators { namespace operators {
template <typename T>
struct AdamFunctor {
T beta1_;
T beta2_;
T epsilon_;
const T* beta1_pow_;
const T* beta2_pow_;
const T* moment1_;
T* moment1_out_;
const T* moment2_;
T* moment2_out_;
const T* lr_;
const T* grad_;
const T* param_;
T* param_out_;
AdamFunctor(T beta1, T beta2, T epsilon, const T* beta1_pow,
const T* beta2_pow, const T* mom1, T* mom1_out, const T* mom2,
T* mom2_out, const T* lr, const T* grad, const T* param,
T* param_out)
: beta1_(beta1),
beta2_(beta2),
epsilon_(epsilon),
beta1_pow_(beta1_pow),
beta2_pow_(beta2_pow),
moment1_(mom1),
moment1_out_(mom1_out),
moment2_(mom2),
moment2_out_(mom2_out),
lr_(lr),
grad_(grad),
param_(param),
param_out_(param_out) {}
inline HOSTDEVICE void operator()(size_t i) const {
// Merge all memory access together.
T g = grad_[i];
T mom1 = moment1_[i];
T mom2 = moment2_[i];
T lr = *lr_;
T beta1_pow = *beta1_pow_;
T beta2_pow = *beta2_pow_;
T p = param_[i];
// Calculation
lr *= sqrt(1 - beta2_pow) / (1 - beta1_pow);
mom1 = beta1_ * mom1 + (1 - beta1_) * g;
mom2 = beta2_ * mom2 + (1 - beta2_) * g * g;
p -= lr * (mom1 / (sqrt(mom2) + epsilon_));
// Write back to global memory
moment1_out_[i] = mom1;
moment2_out_[i] = mom2;
param_out_[i] = p;
}
};
template <typename DeviceContext, typename T> template <typename DeviceContext, typename T>
class AdamOpKernel : public framework::OpKernel<T> { class AdamOpKernel : public framework::OpKernel<T> {
public: public:
void Compute(const framework::ExecutionContext& ctx) const override { void Compute(const framework::ExecutionContext& ctx) const override {
auto param_out_tensor = ctx.Output<framework::Tensor>("ParamOut"); using paddle::framework::LoDTensor;
auto moment1_out_tensor = ctx.Output<framework::Tensor>("Moment1Out"); using paddle::operators::detail::Ref;
auto moment2_out_tensor = ctx.Output<framework::Tensor>("Moment2Out");
param_out_tensor->mutable_data<T>(ctx.GetPlace());
moment1_out_tensor->mutable_data<T>(ctx.GetPlace());
moment2_out_tensor->mutable_data<T>(ctx.GetPlace());
T beta1 = static_cast<T>(ctx.Attr<float>("beta1")); T beta1 = static_cast<T>(ctx.Attr<float>("beta1"));
T beta2 = static_cast<T>(ctx.Attr<float>("beta2")); T beta2 = static_cast<T>(ctx.Attr<float>("beta2"));
T epsilon = static_cast<T>(ctx.Attr<float>("epsilon")); T epsilon = static_cast<T>(ctx.Attr<float>("epsilon"));
auto& param = Ref(ctx.Input<LoDTensor>("Param"), "Must set Param");
auto& grad = Ref(ctx.Input<LoDTensor>("Grad"), "Must set Grad");
auto& mom1 = Ref(ctx.Input<LoDTensor>("Moment1"), "Must set Moment1");
auto& mom2 = Ref(ctx.Input<LoDTensor>("Moment2"), "Must set Moment2");
auto& lr =
Ref(ctx.Input<LoDTensor>("LearningRate"), "Must set LearningRate");
auto& beta1_pow =
Ref(ctx.Input<LoDTensor>("Beta1Pow"), "Must set Beta1Pow");
auto& beta2_pow =
Ref(ctx.Input<LoDTensor>("Beta2Pow"), "Must set Beta2Pow");
auto& param_out =
Ref(ctx.Output<LoDTensor>("ParamOut"), "Must set ParamOut");
auto& mom1_out =
Ref(ctx.Output<LoDTensor>("Moment1Out"), "Must set Moment1Out");
auto& mom2_out =
Ref(ctx.Output<LoDTensor>("Moment2Out"), "Must set Moment1Out");
auto param = framework::EigenVector<T>::Flatten( AdamFunctor<T> functor(beta1, beta2, epsilon, beta1_pow.template data<T>(),
*ctx.Input<framework::Tensor>("Param")); beta2_pow.template data<T>(),
auto grad = framework::EigenVector<T>::Flatten( mom1.template data<T>(),
*ctx.Input<framework::Tensor>("Grad")); mom1_out.template mutable_data<T>(ctx.GetPlace()),
auto moment1 = framework::EigenVector<T>::Flatten( mom2.template data<T>(),
*ctx.Input<framework::Tensor>("Moment1")); mom2_out.template mutable_data<T>(ctx.GetPlace()),
auto moment2 = framework::EigenVector<T>::Flatten( lr.template data<T>(), grad.template data<T>(),
*ctx.Input<framework::Tensor>("Moment2")); param.template data<T>(),
auto lr = framework::EigenVector<T>::Flatten( param_out.template mutable_data<T>(ctx.GetPlace()));
*ctx.Input<framework::Tensor>("LearningRate")); platform::ForRange<DeviceContext> for_range(
auto beta1_pow = framework::EigenVector<T>::Flatten( static_cast<const DeviceContext&>(ctx.device_context()), param.numel());
*ctx.Input<framework::Tensor>("Beta1Pow")); for_range(functor);
auto beta2_pow = framework::EigenVector<T>::Flatten(
*ctx.Input<framework::Tensor>("Beta2Pow"));
auto param_out = framework::EigenVector<T>::Flatten(*param_out_tensor);
auto moment1_out = framework::EigenVector<T>::Flatten(*moment1_out_tensor);
auto moment2_out = framework::EigenVector<T>::Flatten(*moment2_out_tensor);
auto* place = ctx.template device_context<DeviceContext>().eigen_device();
moment1_out.device(*place) = beta1 * moment1 + (1 - beta1) * grad;
moment2_out.device(*place) = beta2 * moment2 + (1 - beta2) * grad.square();
// All of these are tensors of 1 element
auto lr_t = lr * (1 - beta2_pow).sqrt() / (1 - beta1_pow);
// Eigen does not support automatic broadcast
// Get dimensions of moment vector to broadcast lr_t
Eigen::DSizes<int, 1> m_dsize(moment1_out_tensor->numel());
param_out.device(*place) =
param -
lr_t.broadcast(m_dsize) *
(moment1_out / (moment2_out.sqrt() + epsilon));
} }
}; };
......
...@@ -14,6 +14,7 @@ limitations under the License. */ ...@@ -14,6 +14,7 @@ limitations under the License. */
#pragma once #pragma once
#include "paddle/framework/op_registry.h" #include "paddle/framework/op_registry.h"
#include "paddle/operators/math/detail/activation_functions.h"
#include "paddle/operators/math/lstm_compute.h" #include "paddle/operators/math/lstm_compute.h"
#include "paddle/operators/math/math_function.h" #include "paddle/operators/math/math_function.h"
#include "paddle/operators/math/sequence2batch.h" #include "paddle/operators/math/sequence2batch.h"
...@@ -102,9 +103,12 @@ class LSTMKernel : public framework::OpKernel<T> { ...@@ -102,9 +103,12 @@ class LSTMKernel : public framework::OpKernel<T> {
auto batch_starts = batch_gate->lod()[0]; auto batch_starts = batch_gate->lod()[0];
size_t num_batch = batch_starts.size() - 1; size_t num_batch = batch_starts.size() - 1;
auto gate_act = ctx.Attr<std::string>("gate_activation"); auto gate_act = math::detail::GetActivationType(
auto cell_act = ctx.Attr<std::string>("cell_activation"); ctx.Attr<std::string>("gate_activation"));
auto cand_act = ctx.Attr<std::string>("candidate_activation"); auto cell_act = math::detail::GetActivationType(
ctx.Attr<std::string>("cell_activation"));
auto cand_act = math::detail::GetActivationType(
ctx.Attr<std::string>("candidate_activation"));
for (size_t n = 0; n < num_batch; n++) { for (size_t n = 0; n < num_batch; n++) {
int bstart = static_cast<int>(batch_starts[n]); int bstart = static_cast<int>(batch_starts[n]);
...@@ -264,9 +268,12 @@ class LSTMGradKernel : public framework::OpKernel<T> { ...@@ -264,9 +268,12 @@ class LSTMGradKernel : public framework::OpKernel<T> {
batch_gate_g.mutable_data<T>(batch_gate->dims(), ctx.GetPlace()); batch_gate_g.mutable_data<T>(batch_gate->dims(), ctx.GetPlace());
batch_gate_g.set_lod(batch_gate->lod()); batch_gate_g.set_lod(batch_gate->lod());
auto gate_act = ctx.Attr<std::string>("gate_activation"); auto gate_act = math::detail::GetActivationType(
auto cell_act = ctx.Attr<std::string>("cell_activation"); ctx.Attr<std::string>("gate_activation"));
auto cand_act = ctx.Attr<std::string>("candidate_activation"); auto cell_act = math::detail::GetActivationType(
ctx.Attr<std::string>("cell_activation"));
auto cand_act = math::detail::GetActivationType(
ctx.Attr<std::string>("candidate_activation"));
auto batch_starts = batch_gate->lod()[0]; auto batch_starts = batch_gate->lod()[0];
size_t num_batch = batch_starts.size() - 1; size_t num_batch = batch_starts.size() - 1;
......
...@@ -14,6 +14,7 @@ limitations under the License. */ ...@@ -14,6 +14,7 @@ limitations under the License. */
#pragma once #pragma once
#include <math.h> #include <math.h>
#include "paddle/platform/enforce.h"
#include "paddle/platform/hostdevice.h" #include "paddle/platform/hostdevice.h"
#ifdef __AVX__ #ifdef __AVX__
...@@ -29,6 +30,26 @@ namespace detail { ...@@ -29,6 +30,26 @@ namespace detail {
#define SIGMOID_THRESHOLD_MAX 13.0 #define SIGMOID_THRESHOLD_MAX 13.0
#define EXP_MAX_INPUT 40.0 #define EXP_MAX_INPUT 40.0
enum ActivationType {
kSigmoid,
kReLU,
kTanh,
kIdentity,
};
inline ActivationType GetActivationType(const std::string &type) {
if (type == "sigmoid") {
return ActivationType::kSigmoid;
} else if (type == "relu") {
return ActivationType::kReLU;
} else if (type == "tanh") {
return ActivationType::kTanh;
} else if (type == "identity" || type == "") {
return ActivationType::kIdentity;
}
PADDLE_THROW("Not support type %s.", type);
}
namespace forward { namespace forward {
template <typename T> template <typename T>
......
...@@ -26,10 +26,9 @@ namespace detail { ...@@ -26,10 +26,9 @@ namespace detail {
template <class T, class Op> template <class T, class Op>
void naive_lstm_forward_one_sequence(Op op, LstmMetaValue<T> value, void naive_lstm_forward_one_sequence(Op op, LstmMetaValue<T> value,
int frame_size, int frame_size, ActivationType active_node,
activation_mode_t active_node, ActivationType active_gate,
activation_mode_t active_gate, ActivationType active_state) {
activation_mode_t active_state) {
T r_value_in; T r_value_in;
T r_value_ig; T r_value_ig;
T r_value_fg; T r_value_fg;
...@@ -77,9 +76,9 @@ void naive_lstm_forward_one_sequence(Op op, LstmMetaValue<T> value, ...@@ -77,9 +76,9 @@ void naive_lstm_forward_one_sequence(Op op, LstmMetaValue<T> value,
template <class T, class Op> template <class T, class Op>
void naive_lstm_backward_one_sequence(Op op, LstmMetaValue<T> value, void naive_lstm_backward_one_sequence(Op op, LstmMetaValue<T> value,
LstmMetaGrad<T> grad, int frame_size, LstmMetaGrad<T> grad, int frame_size,
activation_mode_t active_node, ActivationType active_node,
activation_mode_t active_gate, ActivationType active_gate,
activation_mode_t active_state) { ActivationType active_state) {
T r_value_in; T r_value_in;
T r_value_ig; T r_value_ig;
T r_value_fg; T r_value_fg;
...@@ -149,10 +148,9 @@ void naive_lstm_backward_one_sequence(Op op, LstmMetaValue<T> value, ...@@ -149,10 +148,9 @@ void naive_lstm_backward_one_sequence(Op op, LstmMetaValue<T> value,
template <class T, class Op> template <class T, class Op>
void avx_lstm_forward_one_sequence(Op op, LstmMetaValue<T> value, void avx_lstm_forward_one_sequence(Op op, LstmMetaValue<T> value,
int frame_size, int frame_size, ActivationType active_node,
activation_mode_t active_node, ActivationType active_gate,
activation_mode_t active_gate, ActivationType active_state) {
activation_mode_t active_state) {
#ifdef __AVX__ #ifdef __AVX__
__m256 r_value_in; __m256 r_value_in;
__m256 r_value_ig; __m256 r_value_ig;
...@@ -204,9 +202,9 @@ void avx_lstm_forward_one_sequence(Op op, LstmMetaValue<T> value, ...@@ -204,9 +202,9 @@ void avx_lstm_forward_one_sequence(Op op, LstmMetaValue<T> value,
template <class T, class Op> template <class T, class Op>
void avx_lstm_backward_one_sequence(Op op, LstmMetaValue<T> value, void avx_lstm_backward_one_sequence(Op op, LstmMetaValue<T> value,
LstmMetaGrad<T> grad, int frame_size, LstmMetaGrad<T> grad, int frame_size,
activation_mode_t active_node, ActivationType active_node,
activation_mode_t active_gate, ActivationType active_gate,
activation_mode_t active_state) { ActivationType active_state) {
#ifdef __AVX__ #ifdef __AVX__
__m256 r_value_in; __m256 r_value_in;
__m256 r_value_ig; __m256 r_value_ig;
...@@ -281,9 +279,8 @@ void avx_lstm_backward_one_sequence(Op op, LstmMetaValue<T> value, ...@@ -281,9 +279,8 @@ void avx_lstm_backward_one_sequence(Op op, LstmMetaValue<T> value,
template <class T, class Op> template <class T, class Op>
void cpu_lstm_forward(Op op, LstmMetaValue<T> value, int frame_size, void cpu_lstm_forward(Op op, LstmMetaValue<T> value, int frame_size,
activation_mode_t active_node, ActivationType active_node, ActivationType active_gate,
activation_mode_t active_gate, ActivationType active_state) {
activation_mode_t active_state) {
if (Op::avx && !(frame_size & (8 - 1)) && (std::is_same<T, float>::value)) { if (Op::avx && !(frame_size & (8 - 1)) && (std::is_same<T, float>::value)) {
avx_lstm_forward_one_sequence<T>(op, value, frame_size, active_node, avx_lstm_forward_one_sequence<T>(op, value, frame_size, active_node,
active_gate, active_state); active_gate, active_state);
...@@ -295,9 +292,9 @@ void cpu_lstm_forward(Op op, LstmMetaValue<T> value, int frame_size, ...@@ -295,9 +292,9 @@ void cpu_lstm_forward(Op op, LstmMetaValue<T> value, int frame_size,
template <class T, class Op> template <class T, class Op>
void cpu_lstm_backward(Op op, LstmMetaValue<T> value, LstmMetaGrad<T> grad, void cpu_lstm_backward(Op op, LstmMetaValue<T> value, LstmMetaGrad<T> grad,
int frame_size, activation_mode_t active_node, int frame_size, ActivationType active_node,
activation_mode_t active_gate, ActivationType active_gate,
activation_mode_t active_state) { ActivationType active_state) {
if (Op::avx && !(frame_size & (8 - 1)) && (std::is_same<T, float>::value)) { if (Op::avx && !(frame_size & (8 - 1)) && (std::is_same<T, float>::value)) {
avx_lstm_backward_one_sequence<T>(op, value, grad, frame_size, active_node, avx_lstm_backward_one_sequence<T>(op, value, grad, frame_size, active_node,
active_gate, active_state); active_gate, active_state);
......
...@@ -31,9 +31,9 @@ namespace detail { ...@@ -31,9 +31,9 @@ namespace detail {
*/ */
template <class T, class Op, bool is_batch> template <class T, class Op, bool is_batch>
__global__ void KeLstmForward(Op op, LstmMetaValue<T> value, int frame_size, __global__ void KeLstmForward(Op op, LstmMetaValue<T> value, int frame_size,
int batch_size, activation_mode_t active_node, int batch_size, ActivationType active_node,
activation_mode_t active_gate, ActivationType active_gate,
activation_mode_t active_state) { ActivationType active_state) {
const int frame_idx = blockIdx.x * blockDim.x + threadIdx.x; const int frame_idx = blockIdx.x * blockDim.x + threadIdx.x;
if (frame_idx >= frame_size) return; if (frame_idx >= frame_size) return;
...@@ -91,9 +91,9 @@ __global__ void KeLstmForward(Op op, LstmMetaValue<T> value, int frame_size, ...@@ -91,9 +91,9 @@ __global__ void KeLstmForward(Op op, LstmMetaValue<T> value, int frame_size,
template <class T, class Op, bool is_batch> template <class T, class Op, bool is_batch>
__global__ void KeLstmBackward(Op op, LstmMetaValue<T> value, __global__ void KeLstmBackward(Op op, LstmMetaValue<T> value,
LstmMetaGrad<T> grad, int frame_size, LstmMetaGrad<T> grad, int frame_size,
int batch_size, activation_mode_t active_node, int batch_size, ActivationType active_node,
activation_mode_t active_gate, ActivationType active_gate,
activation_mode_t active_state) { ActivationType active_state) {
const int frame_idx = blockIdx.x * blockDim.x + threadIdx.x; const int frame_idx = blockIdx.x * blockDim.x + threadIdx.x;
if (frame_idx >= frame_size) return; if (frame_idx >= frame_size) return;
...@@ -185,9 +185,8 @@ __global__ void KeLstmBackward(Op op, LstmMetaValue<T> value, ...@@ -185,9 +185,8 @@ __global__ void KeLstmBackward(Op op, LstmMetaValue<T> value,
template <class T, class Op> template <class T, class Op>
void gpu_lstm_forward(const platform::DeviceContext& context, Op op, void gpu_lstm_forward(const platform::DeviceContext& context, Op op,
LstmMetaValue<T> value, int frame_size, int batch_size, LstmMetaValue<T> value, int frame_size, int batch_size,
activation_mode_t active_node, ActivationType active_node, ActivationType active_gate,
activation_mode_t active_gate, ActivationType active_state) {
activation_mode_t active_state) {
dim3 threads; dim3 threads;
dim3 grid; dim3 grid;
if (batch_size == 1) { if (batch_size == 1) {
...@@ -220,9 +219,8 @@ template <class T, class Op> ...@@ -220,9 +219,8 @@ template <class T, class Op>
void gpu_lstm_backward(const platform::DeviceContext& context, Op op, void gpu_lstm_backward(const platform::DeviceContext& context, Op op,
LstmMetaValue<T> value, LstmMetaGrad<T> grad, LstmMetaValue<T> value, LstmMetaGrad<T> grad,
int frame_size, int batch_size, int frame_size, int batch_size,
activation_mode_t active_node, ActivationType active_node, ActivationType active_gate,
activation_mode_t active_gate, ActivationType active_state) {
activation_mode_t active_state) {
dim3 threads; dim3 threads;
dim3 grid; dim3 grid;
if (batch_size == 1) { if (batch_size == 1) {
......
...@@ -30,9 +30,9 @@ class lstm { ...@@ -30,9 +30,9 @@ class lstm {
HOSTDEVICE void operator()(T &value_in, T &value_ig, T &value_fg, T &value_og, HOSTDEVICE void operator()(T &value_in, T &value_ig, T &value_fg, T &value_og,
T &prev_state, T &state, T &state_atv, T &output, T &prev_state, T &state, T &state_atv, T &output,
T &checkI, T &checkF, T &checkO, T &checkI, T &checkF, T &checkO,
activation_mode_t active_node, ActivationType active_node,
activation_mode_t active_gate, ActivationType active_gate,
activation_mode_t active_state) { ActivationType active_state) {
value_in = activation(value_in, active_node); value_in = activation(value_in, active_node);
value_ig = activation(value_ig + prev_state * checkI, active_gate); value_ig = activation(value_ig + prev_state * checkI, active_gate);
value_fg = activation(value_fg + prev_state * checkF, active_gate); value_fg = activation(value_fg + prev_state * checkF, active_gate);
...@@ -53,9 +53,9 @@ class lstm { ...@@ -53,9 +53,9 @@ class lstm {
__m256 &prev_state, __m256 &state, __m256 &prev_state, __m256 &state,
__m256 &state_atv, __m256 &output, __m256 &checkI, __m256 &state_atv, __m256 &output, __m256 &checkI,
__m256 &checkF, __m256 &checkO, __m256 &checkF, __m256 &checkO,
activation_mode_t active_node, ActivationType active_node,
activation_mode_t active_gate, ActivationType active_gate,
activation_mode_t active_state) { ActivationType active_state) {
value_in = activation(value_in, active_node); value_in = activation(value_in, active_node);
value_ig = value_ig =
activation(_mm256_add_ps(value_ig, _mm256_mul_ps(prev_state, checkI)), activation(_mm256_add_ps(value_ig, _mm256_mul_ps(prev_state, checkI)),
...@@ -87,9 +87,9 @@ class lstm { ...@@ -87,9 +87,9 @@ class lstm {
T &state_grad, T &state_atv, T &output_grad, T &state_grad, T &state_atv, T &output_grad,
T &checkI, T &checkF, T &checkO, T &checkIGrad, T &checkI, T &checkF, T &checkO, T &checkIGrad,
T &checkFGrad, T &checkOGrad, T &checkFGrad, T &checkOGrad,
activation_mode_t active_node, ActivationType active_node,
activation_mode_t active_gate, ActivationType active_gate,
activation_mode_t active_state) { ActivationType active_state) {
grad_og = activation(output_grad * state_atv, value_og, active_gate); grad_og = activation(output_grad * state_atv, value_og, active_gate);
state_grad += activation(output_grad * value_og, state_atv, active_state) + state_grad += activation(output_grad * value_og, state_atv, active_state) +
grad_og * checkO; grad_og * checkO;
...@@ -114,8 +114,8 @@ class lstm { ...@@ -114,8 +114,8 @@ class lstm {
__m256 &prev_state, __m256 &prev_state_grad, __m256 &state, __m256 &prev_state, __m256 &prev_state_grad, __m256 &state,
__m256 &state_grad, __m256 &state_atv, __m256 &output_grad, __m256 &state_grad, __m256 &state_atv, __m256 &output_grad,
__m256 &checkI, __m256 &checkF, __m256 &checkO, __m256 &checkIGrad, __m256 &checkI, __m256 &checkF, __m256 &checkO, __m256 &checkIGrad,
__m256 &checkFGrad, __m256 &checkOGrad, activation_mode_t active_node, __m256 &checkFGrad, __m256 &checkOGrad, ActivationType active_node,
activation_mode_t active_gate, activation_mode_t active_state) { ActivationType active_gate, ActivationType active_state) {
grad_og = activation(_mm256_mul_ps(output_grad, state_atv), value_og, grad_og = activation(_mm256_mul_ps(output_grad, state_atv), value_og,
active_gate); active_gate);
state_grad = _mm256_add_ps(activation(_mm256_mul_ps(output_grad, value_og), state_grad = _mm256_add_ps(activation(_mm256_mul_ps(output_grad, value_og),
......
...@@ -24,12 +24,12 @@ template <class T> ...@@ -24,12 +24,12 @@ template <class T>
struct LstmUnitFunctor<platform::CPUDeviceContext, T> { struct LstmUnitFunctor<platform::CPUDeviceContext, T> {
static void compute(const platform::CPUDeviceContext& context, static void compute(const platform::CPUDeviceContext& context,
LstmMetaValue<T> value, int frame_size, int batch_size, LstmMetaValue<T> value, int frame_size, int batch_size,
const std::string& gate_act, const std::string& cell_act, const detail::ActivationType& gate_act,
const std::string& cand_act) { const detail::ActivationType& cell_act,
const detail::ActivationType& cand_act) {
for (int b = 0; b < batch_size; b++) { for (int b = 0; b < batch_size; b++) {
detail::cpu_lstm_forward(detail::forward::lstm<T>(), value, frame_size, detail::cpu_lstm_forward(detail::forward::lstm<T>(), value, frame_size,
ActiveType(cand_act), ActiveType(gate_act), cand_act, gate_act, cell_act);
ActiveType(cell_act));
value.gate_value += frame_size * 4; value.gate_value += frame_size * 4;
value.state_value += frame_size; value.state_value += frame_size;
value.state_active_value += frame_size; value.state_active_value += frame_size;
...@@ -46,12 +46,12 @@ struct LstmUnitGradFunctor<platform::CPUDeviceContext, T> { ...@@ -46,12 +46,12 @@ struct LstmUnitGradFunctor<platform::CPUDeviceContext, T> {
static void compute(const platform::CPUDeviceContext& context, static void compute(const platform::CPUDeviceContext& context,
LstmMetaValue<T> value, LstmMetaGrad<T> grad, LstmMetaValue<T> value, LstmMetaGrad<T> grad,
int frame_size, int batch_size, int frame_size, int batch_size,
const std::string& gate_act, const std::string& cell_act, const detail::ActivationType& gate_act,
const std::string& cand_act) { const detail::ActivationType& cell_act,
const detail::ActivationType& cand_act) {
for (int b = 0; b < batch_size; b++) { for (int b = 0; b < batch_size; b++) {
detail::cpu_lstm_backward(detail::backward::lstm<T>(), value, grad, detail::cpu_lstm_backward(detail::backward::lstm<T>(), value, grad,
frame_size, ActiveType(cand_act), frame_size, cand_act, gate_act, cell_act);
ActiveType(gate_act), ActiveType(cell_act));
value.gate_value += frame_size * 4; value.gate_value += frame_size * 4;
value.state_value += frame_size; value.state_value += frame_size;
......
...@@ -24,11 +24,12 @@ template <class T> ...@@ -24,11 +24,12 @@ template <class T>
struct LstmUnitFunctor<platform::CUDADeviceContext, T> { struct LstmUnitFunctor<platform::CUDADeviceContext, T> {
static void compute(const platform::CUDADeviceContext& context, static void compute(const platform::CUDADeviceContext& context,
LstmMetaValue<T> value, int frame_size, int batch_size, LstmMetaValue<T> value, int frame_size, int batch_size,
const std::string& gate_act, const std::string& cell_act, const detail::ActivationType& gate_act,
const std::string& cand_act) { const detail::ActivationType& cell_act,
const detail::ActivationType& cand_act) {
detail::gpu_lstm_forward<T>(context, detail::forward::lstm<T>(), value, detail::gpu_lstm_forward<T>(context, detail::forward::lstm<T>(), value,
frame_size, batch_size, ActiveType(cand_act), frame_size, batch_size, cand_act, gate_act,
ActiveType(gate_act), ActiveType(cell_act)); cell_act);
} }
}; };
...@@ -37,11 +38,12 @@ struct LstmUnitGradFunctor<platform::CUDADeviceContext, T> { ...@@ -37,11 +38,12 @@ struct LstmUnitGradFunctor<platform::CUDADeviceContext, T> {
static void compute(const platform::CUDADeviceContext& context, static void compute(const platform::CUDADeviceContext& context,
LstmMetaValue<T> value, LstmMetaGrad<T> grad, LstmMetaValue<T> value, LstmMetaGrad<T> grad,
int frame_size, int batch_size, int frame_size, int batch_size,
const std::string& gate_act, const std::string& cell_act, const detail::ActivationType& gate_act,
const std::string& cand_act) { const detail::ActivationType& cell_act,
const detail::ActivationType& cand_act) {
detail::gpu_lstm_backward(context, detail::backward::lstm<T>(), value, grad, detail::gpu_lstm_backward(context, detail::backward::lstm<T>(), value, grad,
frame_size, batch_size, ActiveType(cand_act), frame_size, batch_size, cand_act, gate_act,
ActiveType(gate_act), ActiveType(cell_act)); cell_act);
} }
}; };
......
...@@ -14,6 +14,7 @@ limitations under the License. */ ...@@ -14,6 +14,7 @@ limitations under the License. */
#pragma once #pragma once
#include "paddle/operators/math/detail/activation_functions.h"
#include "paddle/platform/device_context.h" #include "paddle/platform/device_context.h"
#include "paddle/platform/enforce.h" #include "paddle/platform/enforce.h"
...@@ -72,8 +73,9 @@ class LstmUnitFunctor { ...@@ -72,8 +73,9 @@ class LstmUnitFunctor {
public: public:
static void compute(const DeviceContext &context, LstmMetaValue<T> value, static void compute(const DeviceContext &context, LstmMetaValue<T> value,
int frame_size, int batch_size, int frame_size, int batch_size,
const std::string &gate_act, const std::string &cell_act, const detail::ActivationType &gate_act,
const std::string &cand_act); const detail::ActivationType &cell_act,
const detail::ActivationType &cand_act);
}; };
template <typename DeviceContext, typename T> template <typename DeviceContext, typename T>
...@@ -81,8 +83,9 @@ class LstmUnitGradFunctor { ...@@ -81,8 +83,9 @@ class LstmUnitGradFunctor {
public: public:
static void compute(const DeviceContext &context, LstmMetaValue<T> value, static void compute(const DeviceContext &context, LstmMetaValue<T> value,
LstmMetaGrad<T> grad, int frame_size, int batch_size, LstmMetaGrad<T> grad, int frame_size, int batch_size,
const std::string &gate_act, const std::string &cell_act, const detail::ActivationType &gate_act,
const std::string &cand_act); const detail::ActivationType &cell_act,
const detail::ActivationType &cand_act);
}; };
} // namespace math } // namespace math
......
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "paddle/platform/device_context.h"
namespace paddle {
namespace platform {
template <typename DeviceContext>
struct ForRange {
ForRange(const DeviceContext& dev_ctx, size_t limit);
template <typename Function>
void operator()(Function func) const;
};
template <>
struct ForRange<CPUDeviceContext> {
ForRange(const CPUDeviceContext& dev_ctx, size_t limit) : limit_(limit) {}
template <typename Function>
void operator()(Function func) const {
for (size_t i = 0; i < limit_; ++i) {
func(i);
}
}
size_t limit_;
};
#ifdef __NVCC__
template <typename Function>
__global__ static void ForRangeElemwiseOpGridIsOne(Function func) {
size_t idx = static_cast<size_t>(threadIdx.x);
func(idx);
}
template <typename Function>
__global__ static void ForRangeElemwiseOp(Function func, int limit) {
size_t idx = static_cast<size_t>(blockIdx.x * blockDim.x + threadIdx.x);
if (idx < limit) {
func(idx);
}
}
template <>
struct ForRange<CUDADeviceContext> {
ForRange(const CUDADeviceContext& dev_ctx, size_t limit)
: dev_ctx_(dev_ctx), limit_(static_cast<int>(limit)) {}
template <typename Function>
inline void operator()(Function func) const {
constexpr size_t num_threads = 1024;
int block_size = limit_ <= num_threads ? limit_ : num_threads;
int grid_size = (limit_ + num_threads - 1) / num_threads;
if (grid_size == 1) {
ForRangeElemwiseOpGridIsOne<<<1, block_size, 0, dev_ctx_.stream()>>>(
func);
} else {
ForRangeElemwiseOp<<<grid_size, block_size, 0, dev_ctx_.stream()>>>(
func, limit_);
}
}
const CUDADeviceContext& dev_ctx_;
int limit_;
};
#endif
} // namespace platform
} // namespace paddle
...@@ -78,6 +78,10 @@ PYBIND11_PLUGIN(core) { ...@@ -78,6 +78,10 @@ PYBIND11_PLUGIN(core) {
[](Tensor &self, const std::vector<int64_t> &dim) { [](Tensor &self, const std::vector<int64_t> &dim) {
self.Resize(make_ddim(dim)); self.Resize(make_ddim(dim));
}) })
.def("set_layout",
[](Tensor &self, const std::string &layout) {
self.set_layout(StringToDataLayout(layout));
})
.def("alloc_float", .def("alloc_float",
[](Tensor &self, paddle::platform::CUDAPlace &place) { [](Tensor &self, paddle::platform::CUDAPlace &place) {
self.mutable_data<float>(place); self.mutable_data<float>(place);
......
...@@ -14,7 +14,7 @@ ...@@ -14,7 +14,7 @@
__all__ = [ __all__ = [
'map_readers', 'buffered', 'compose', 'chain', 'shuffle', 'map_readers', 'buffered', 'compose', 'chain', 'shuffle',
'ComposeNotAligned', 'firstn', 'xmap_readers', 'pipe_reader' 'ComposeNotAligned', 'firstn', 'xmap_readers', 'PipeReader'
] ]
from threading import Thread from threading import Thread
...@@ -334,93 +334,72 @@ def _buf2lines(buf, line_break="\n"): ...@@ -334,93 +334,72 @@ def _buf2lines(buf, line_break="\n"):
return lines[:-1], lines[-1] return lines[:-1], lines[-1]
def pipe_reader(left_cmd, class PipeReader:
parser,
bufsize=8192,
file_type="plain",
cut_lines=True,
line_break="\n"):
""" """
pipe_reader read data by stream from a command, take it's PipeReader read data by stream from a command, take it's
stdout into a pipe buffer and redirect it to the parser to stdout into a pipe buffer and redirect it to the parser to
parse, then yield data as your desired format. parse, then yield data as your desired format.
You can using standard linux command or call another program You can using standard linux command or call another program
to read data, from HDFS, Ceph, URL, AWS S3 etc: to read data, from HDFS, Ceph, URL, AWS S3 etc:
cmd = "hadoop fs -cat /path/to/some/file" .. code-block:: python
cmd = "cat sample_file.tar.gz" cmd = "hadoop fs -cat /path/to/some/file"
cmd = "curl http://someurl" cmd = "cat sample_file.tar.gz"
cmd = "python print_s3_bucket.py" cmd = "curl http://someurl"
cmd = "python print_s3_bucket.py"
A sample parser: An example:
.. code-block:: python
def sample_parser(lines): def example_reader():
# parse each line as one sample data, for f in myfiles:
# return a list of samples as batches. pr = PipeReader("cat %s"%f)
ret = [] for l in pr.get_line():
for l in lines: sample = l.split(" ")
ret.append(l.split(" ")[1:5]) yield sample
return ret
:param left_cmd: command to excute to get stdout from.
:type left_cmd: string
:param parser: parser function to parse lines of data.
if cut_lines is True, parser will receive list
of lines.
if cut_lines is False, parser will receive a
raw buffer each time.
parser should return a list of parsed values.
:type parser: callable
:param bufsize: the buffer size used for the stdout pipe.
:type bufsize: int
:param file_type: can be plain/gzip, stream buffer data type.
:type file_type: string
:param cut_lines: whether to pass lines instead of raw buffer
to the parser
:type cut_lines: bool
:param line_break: line break of the file, like \n or \r
:type line_break: string
:return: the reader generator.
:rtype: callable
""" """
if not isinstance(left_cmd, str):
raise TypeError("left_cmd must be a string")
if not callable(parser):
raise TypeError("parser must be a callable object")
# TODO(typhoonzero): add a thread to read stderr
# Always init a decompress object is better than
# create in the loop.
dec = zlib.decompressobj(
32 + zlib.MAX_WBITS) # offset 32 to skip the header
def reader(): def __init__(self, command, bufsize=8192, file_type="plain"):
process = subprocess.Popen( if not isinstance(command, str):
left_cmd.split(" "), bufsize=bufsize, stdout=subprocess.PIPE) raise TypeError("left_cmd must be a string")
if file_type == "gzip":
self.dec = zlib.decompressobj(
32 + zlib.MAX_WBITS) # offset 32 to skip the header
self.file_type = file_type
self.bufsize = bufsize
self.process = subprocess.Popen(
command.split(" "), bufsize=bufsize, stdout=subprocess.PIPE)
def get_line(self, cut_lines=True, line_break="\n"):
"""
:param cut_lines: cut buffer to lines
:type cut_lines: bool
:param line_break: line break of the file, like \n or \r
:type line_break: string
:return: one line or a buffer of bytes
:rtype: string
"""
remained = "" remained = ""
while True: while True:
buff = process.stdout.read(bufsize) buff = self.process.stdout.read(self.bufsize)
if buff: if buff:
if file_type == "gzip": if self.file_type == "gzip":
decomp_buff = dec.decompress(buff) decomp_buff = self.dec.decompress(buff)
elif file_type == "plain": elif self.file_type == "plain":
decomp_buff = buff decomp_buff = buff
else: else:
raise TypeError("file_type %s is not allowed" % file_type) raise TypeError("file_type %s is not allowed" %
self.file_type)
if cut_lines: if cut_lines:
lines, remained = _buf2lines(''.join( lines, remained = _buf2lines(''.join(
[remained, decomp_buff]), line_break) [remained, decomp_buff]), line_break)
parsed_list = parser(lines) for line in lines:
for ret in parsed_list: yield line
yield ret
else: else:
for ret in parser(decomp_buff): yield decomp_buff
yield ret
else: else:
break break
return reader
...@@ -147,8 +147,11 @@ class TestXmap(unittest.TestCase): ...@@ -147,8 +147,11 @@ class TestXmap(unittest.TestCase):
class TestPipeReader(unittest.TestCase): class TestPipeReader(unittest.TestCase):
def test_pipe_reader(self): def test_pipe_reader(self):
def simple_parser(lines): def example_reader(myfiles):
return lines for f in myfiles:
pr = paddle.v2.reader.PipeReader("cat %s" % f, bufsize=128)
for l in pr.get_line():
yield l
import tempfile import tempfile
...@@ -159,17 +162,12 @@ class TestPipeReader(unittest.TestCase): ...@@ -159,17 +162,12 @@ class TestPipeReader(unittest.TestCase):
for r in records: for r in records:
f.write('%s\n' % r) f.write('%s\n' % r)
cmd = "cat %s" % temp.name result = []
reader = paddle.v2.reader.pipe_reader( for r in example_reader([temp.name]):
cmd, simple_parser, bufsize=128) result.append(r)
for i in xrange(4):
result = [] for idx, e in enumerate(records):
for r in reader(): self.assertEqual(e, result[idx])
result.append(r)
for idx, e in enumerate(records):
print e, result[idx]
self.assertEqual(e, result[idx])
finally: finally:
# delete the temporary file # delete the temporary file
temp.close() temp.close()
......