diff --git a/cmake/external/grpc.cmake b/cmake/external/grpc.cmake index 219ea1b90881ccdbaf3fd41510fb4f2a8b6ec0f4..86122aec8c77f34756a37121582b92489d749d7f 100644 --- a/cmake/external/grpc.cmake +++ b/cmake/external/grpc.cmake @@ -24,9 +24,9 @@ SET(GRPC_INSTALL_DIR ${THIRD_PARTY_PATH}/install/grpc) SET(GRPC_INCLUDE_DIR "${GRPC_INSTALL_DIR}/include/" CACHE PATH "grpc include directory." FORCE) SET(GRPC_CPP_PLUGIN "${GRPC_INSTALL_DIR}/bin/grpc_cpp_plugin" CACHE FILEPATH "GRPC_CPP_PLUGIN" FORCE) IF(APPLE) - SET(BUILD_CMD make -n | sed "s/-Werror//g" | sh) + SET(BUILD_CMD make -n HAS_SYSTEM_PROTOBUF=false -s -j8 static grpc_cpp_plugin | sed "s/-Werror//g" | sh) ELSE() - SET(BUILD_CMD make) + SET(BUILD_CMD make HAS_SYSTEM_PROTOBUF=false -s -j8 static grpc_cpp_plugin) ENDIF() ExternalProject_Add( @@ -42,7 +42,7 @@ ExternalProject_Add( # Disable -Werror, otherwise the compile will fail in MacOS. # It seems that we cannot configure that by make command. # Just dry run make command and remove `-Werror`, then use a shell to run make commands - BUILD_COMMAND ${BUILD_CMD} HAS_SYSTEM_PROTOBUF=false -s -j8 static grpc_cpp_plugin + BUILD_COMMAND ${BUILD_CMD} INSTALL_COMMAND make prefix=${GRPC_INSTALL_DIR} install ) diff --git a/cmake/generic.cmake b/cmake/generic.cmake index c917ca0ff4e087b7caae8876da127bec6b39b798..9cf256fb6d5336b92f94198517f700e77a9330b3 100644 --- a/cmake/generic.cmake +++ b/cmake/generic.cmake @@ -227,8 +227,8 @@ function(cc_test TARGET_NAME) set(multiValueArgs SRCS DEPS) cmake_parse_arguments(cc_test "${options}" "${oneValueArgs}" "${multiValueArgs}" ${ARGN}) add_executable(${TARGET_NAME} ${cc_test_SRCS}) - target_link_libraries(${TARGET_NAME} ${cc_test_DEPS} gtest gtest_main) - add_dependencies(${TARGET_NAME} ${cc_test_DEPS} gtest gtest_main) + target_link_libraries(${TARGET_NAME} ${cc_test_DEPS} paddle_gtest_main paddle_memory gtest gflags) + add_dependencies(${TARGET_NAME} ${cc_test_DEPS} paddle_gtest_main paddle_memory gtest gflags) add_test(NAME ${TARGET_NAME} COMMAND ${TARGET_NAME} WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR}) endif() endfunction(cc_test) @@ -288,8 +288,8 @@ function(nv_test TARGET_NAME) set(multiValueArgs SRCS DEPS) cmake_parse_arguments(nv_test "${options}" "${oneValueArgs}" "${multiValueArgs}" ${ARGN}) cuda_add_executable(${TARGET_NAME} ${nv_test_SRCS}) - target_link_libraries(${TARGET_NAME} ${nv_test_DEPS} gtest gtest_main) - add_dependencies(${TARGET_NAME} ${nv_test_DEPS} gtest gtest_main) + target_link_libraries(${TARGET_NAME} ${nv_test_DEPS} paddle_gtest_main paddle_memory gtest gflags) + add_dependencies(${TARGET_NAME} ${nv_test_DEPS} paddle_gtest_main paddle_memory gtest gflags) add_test(${TARGET_NAME} ${TARGET_NAME}) endif() endfunction(nv_test) diff --git a/doc/design/float16.md b/doc/design/float16.md index 078801ba2ed969d26dd31d5ec4ed268686cf7016..1ea95ed6b5d6792171569b6ff76d09be92fcb13e 100644 --- a/doc/design/float16.md +++ b/doc/design/float16.md @@ -28,6 +28,51 @@ The goal of float16 is to serve as a key for the executor to find and run the co - [Eigen](https://github.com/RLovelett/eigen) >= 3.3 supports float16 calculation on both GPU and CPU using the `Eigen::half` class. It is mostly useful for Nvidia GPUs because of the overloaded arithmetic operators using cuda intrinsics. It falls back to using software emulation on CPU for calculation and there is no special treatment to ARM processors. - [ARM compute library](https://github.com/ARM-software/ComputeLibrary) >= 17.02.01 supports NEON FP16 kernels (requires ARMv8.2-A CPU). +### CUDA version issue +There are currently three versions of CUDA that supports `__half` data type, namely, CUDA 7.5, 8.0, and 9.0. +CUDA 7.5 and 8.0 define `__half` as a simple struct that has a `uint16_t` data (see [`cuda_fp16.h`](https://github.com/ptillet/isaac/blob/9212ab5a3ddbe48f30ef373f9c1fb546804c7a8c/include/isaac/external/CUDA/cuda_fp16.h)) as follows: +``` +typedef struct __align__(2) { + unsigned short x; +} __half; + +typedef __half half; +``` +This struct does not define any overloaded arithmetic operators. So you have to directly use `__hadd` instead of `+` to correctly add two half types: +``` +__global__ void Add() { + half a, b, c; + c = __hadd(a, b); // correct + c = a + b; // compiler error: no operator "+" matches these operands +} +``` +CUDA 9.0 provides a major update to the half data type. The related code can be found in the updated [`cuda_fp16.h`](https://github.com/ptillet/isaac/blob/master/include/isaac/external/CUDA/cuda_fp16.h) and the newly added [`cuda_fp16.hpp`](https://github.com/ptillet/isaac/blob/master/include/isaac/external/CUDA/cuda_fp16.hpp). + +Essentially, CUDA 9.0 renames the original `__half` type in 7.5 and 8.0 as `__half_raw`, and defines a new `__half` class type that has constructors, conversion operators, and also provides overloaded arithmetic operators such as follows: +``` +typedef struct __CUDA_ALIGN__(2) { + unsigned short x; +} __half_raw; + + +struct __CUDA_ALIGN__(2) __half { +protected: + unsigned short __x; +public: + // constructors and conversion operators from/to + // __half_raw and other built-in data types +} + +typedef __half half; + +__device__ __forceinline__ +__half operator+(const __half &lh, const __half &rh) { + return __hadd(lh, rh); +} + +// Other overloaded operators +``` +This new design makes `c = a + b` work correctly for CUDA half data type. ## Implementation The float16 class holds a 16-bit `uint16_t` data internally. diff --git a/doc/howto/optimization/cpu_profiling.md b/doc/howto/optimization/cpu_profiling.md index 32d89a7c183d57e0e69039dfb2c78703d9866f7c..e1d91c668e9c6711cc3a529168c8a3f6338de59d 100644 --- a/doc/howto/optimization/cpu_profiling.md +++ b/doc/howto/optimization/cpu_profiling.md @@ -1,42 +1,52 @@ -此教程会介绍如何使用Python的cProfile包,与Python库yep,google perftools来运行性能分析(Profiling)与调优。 +This tutorial introduces techniques we used to profile and tune the +CPU performance of PaddlePaddle. We will use Python packages +`cProfile` and `yep`, and Google `perftools`. -运行性能分析可以让开发人员科学的,有条不紊的对程序进行性能优化。性能分析是性能调优的基础。因为在程序实际运行中,真正的瓶颈可能和程序员开发过程中想象的瓶颈相去甚远。 +Profiling is the process that reveals the performance bottlenecks, +which could be very different from what's in the developers' mind. +Performance tuning is to fix the bottlenecks. Performance optimization +repeats the steps of profiling and tuning alternatively. -性能优化的步骤,通常是循环重复若干次『性能分析 --> 寻找瓶颈 ---> 调优瓶颈 --> 性能分析确认调优效果』。其中性能分析是性能调优的至关重要的量化指标。 +PaddlePaddle users program AI by calling the Python API, which calls +into `libpaddle.so.` written in C++. In this tutorial, we focus on +the profiling and tuning of -Paddle提供了Python语言绑定。用户使用Python进行神经网络编程,训练,测试。Python解释器通过`pybind`和`swig`调用Paddle的动态链接库,进而调用Paddle C++部分的代码。所以Paddle的性能分析与调优分为两个部分: +1. the Python code and +1. the mixture of Python and C++ code. -* Python代码的性能分析 -* Python与C++混合代码的性能分析 +## Profiling the Python Code +### Generate the Performance Profiling File -## Python代码的性能分析 - -### 生成性能分析文件 - -Python标准库中提供了性能分析的工具包,[cProfile](https://docs.python.org/2/library/profile.html)。生成Python性能分析的命令如下: +We can use Python standard +package, [`cProfile`](https://docs.python.org/2/library/profile.html), +to generate Python profiling file. For example: ```bash python -m cProfile -o profile.out main.py ``` -其中`-o`标识了一个输出的文件名,用来存储本次性能分析的结果。如果不指定这个文件,`cProfile`会打印一些统计信息到`stdout`。这不方便我们进行后期处理(进行`sort`, `split`, `cut`等等)。 - -### 查看性能分析文件 +where `main.py` is the program we are going to profile, `-o` specifies +the output file. Without `-o`, `cProfile` would outputs to standard +output. -当main.py运行完毕后,性能分析结果文件`profile.out`就生成出来了。我们可以使用[cprofilev](https://github.com/ymichael/cprofilev)来查看性能分析结果。`cprofilev`是一个Python的第三方库。使用它会开启一个HTTP服务,将性能分析结果以网页的形式展示出来。 +### Look into the Profiling File -使用`pip install cprofilev`安装`cprofilev`工具。安装完成后,使用如下命令开启HTTP服务 +`cProfile` generates `profile.out` after `main.py` completes. We can +use [`cprofilev`](https://github.com/ymichael/cprofilev) to look into +the details: ```bash cprofilev -a 0.0.0.0 -p 3214 -f profile.out main.py ``` -其中`-a`标识HTTP服务绑定的IP。使用`0.0.0.0`允许外网访问这个HTTP服务。`-p`标识HTTP服务的端口。`-f`标识性能分析的结果文件。`main.py`标识被性能分析的源文件。 +where `-a` specifies the HTTP IP, `-p` specifies the port, `-f` +specifies the profiling file, and `main.py` is the source file. -访问对应网址,即可显示性能分析的结果。性能分析结果格式如下: +Open the Web browser and points to the local IP and the specifies +port, we will see the output like the following: -```text +``` ncalls tottime percall cumtime percall filename:lineno(function) 1 0.284 0.284 29.514 29.514 main.py:1() 4696 0.128 0.000 15.748 0.003 /home/yuyang/perf_test/.env/lib/python2.7/site-packages/paddle/v2/fluid/executor.py:20(run) @@ -44,23 +54,23 @@ cprofilev -a 0.0.0.0 -p 3214 -f profile.out main.py 1 0.144 0.144 6.534 6.534 /home/yuyang/perf_test/.env/lib/python2.7/site-packages/paddle/v2/__init__.py:14() ``` -每一列的含义是: +where each line corresponds to Python function, and the meaning of +each column is as follows: -| 列名 | 含义 | +| column | meaning | | --- | --- | -| ncalls | 函数的调用次数 | -| tottime | 函数实际使用的总时间。该时间去除掉本函数调用其他函数的时间 | -| percall | tottime的每次调用平均时间 | -| cumtime | 函数总时间。包含这个函数调用其他函数的时间 | -| percall | cumtime的每次调用平均时间 | -| filename:lineno(function) | 文件名, 行号,函数名 | +| ncalls | the number of calls into a function | +| tottime | the total execution time of the function, not including the + execution time of other functions called by the function | +| percall | tottime divided by ncalls | +| cumtime | the total execution time of the function, including the execution time of other functions being called | +| percall | cumtime divided by ncalls | +| filename:lineno(function) | where the function is defined | +### Identify Performance Bottlenecks -### 寻找性能瓶颈 - -通常`tottime`和`cumtime`是寻找瓶颈的关键指标。这两个指标代表了某一个函数真实的运行时间。 - -将性能分析结果按照tottime排序,效果如下: +Usually, `tottime` and the related `percall` time is what we want to +focus on. We can sort above profiling file by tottime: ```text 4696 12.040 0.003 12.040 0.003 {built-in method run} @@ -68,12 +78,15 @@ cprofilev -a 0.0.0.0 -p 3214 -f profile.out main.py 107991 0.676 0.000 1.519 0.000 /home/yuyang/perf_test/.env/lib/python2.7/site-packages/paddle/v2/fluid/framework.py:219(__init__) 4697 0.626 0.000 2.291 0.000 /home/yuyang/perf_test/.env/lib/python2.7/site-packages/paddle/v2/fluid/framework.py:428(sync_with_cpp) 1 0.618 0.618 0.618 0.618 /home/yuyang/perf_test/.env/lib/python2.7/site-packages/paddle/v2/fluid/__init__.py:1() - ``` -可以看到最耗时的函数是C++端的`run`函数。这需要联合我们第二节`Python与C++混合代码的性能分析`来进行调优。而`sync_with_cpp`函数的总共耗时很长,每次调用的耗时也很长。于是我们可以点击`sync_with_cpp`的详细信息,了解其调用关系。 +We can see that the most time-consuming function is the `built-in +method run`, which is a C++ function in `libpaddle.so`. We will +explain how to profile C++ code in the next section. At the right +moment, let's look into the third function `sync_with_cpp`, which is a +Python function. We can click it to understand more about it: -```text +``` Called By: Ordered by: internal time @@ -92,72 +105,93 @@ Called: List reduced from 4497 to 2 due to restriction <'sync_with_cpp'> ``` -通常观察热点函数间的调用关系,和对应行的代码,就可以了解到问题代码在哪里。当我们做出性能修正后,再次进行性能分析(profiling)即可检查我们调优后的修正是否能够改善程序的性能。 +The lists of the callers of `sync_with_cpp` might help us understand +how to improve the function definition. +## Profiling Python and C++ Code +### Generate the Profiling File -## Python与C++混合代码的性能分析 +To profile a mixture of Python and C++ code, we can use a Python +package, `yep`, that can work with Google's `perftools`, which is a +commonly-used profiler for C/C++ code. -### 生成性能分析文件 - -C++的性能分析工具非常多。常见的包括`gprof`, `valgrind`, `google-perftools`。但是调试Python中使用的动态链接库与直接调试原始二进制相比增加了很多复杂度。幸而Python的一个第三方库`yep`提供了方便的和`google-perftools`交互的方法。于是这里使用`yep`进行Python与C++混合代码的性能分析 - -使用`yep`前需要安装`google-perftools`与`yep`包。ubuntu下安装命令为 +In Ubuntu systems, we can install `yep` and `perftools` by running the +following commands: ```bash +apt update apt install libgoogle-perftools-dev pip install yep ``` -安装完毕后,我们可以通过 +Then we can run the following command ```bash python -m yep -v main.py ``` -生成性能分析文件。生成的性能分析文件为`main.py.prof`。 +to generate the profiling file. The default filename is +`main.py.prof`. + +Please be aware of the `-v` command line option, which prints the +analysis results after generating the profiling file. By taking a +glance at the print result, we'd know that if we stripped debug +information from `libpaddle.so` at build time. The following hints +help make sure that the analysis results are readable: -命令行中的`-v`指定在生成性能分析文件之后,在命令行显示分析结果。我们可以在命令行中简单的看一下生成效果。因为C++与Python不同,编译时可能会去掉调试信息,运行时也可能因为多线程产生混乱不可读的性能分析结果。为了生成更可读的性能分析结果,可以采取下面几点措施: +1. Use GCC command line option `-g` when building `libpaddle.so` so to + include the debug information. The standard building system of + PaddlePaddle is CMake, so you might want to set + `CMAKE_BUILD_TYPE=RelWithDebInfo`. -1. 编译时指定`-g`生成调试信息。使用cmake的话,可以将CMAKE_BUILD_TYPE指定为`RelWithDebInfo`。 -2. 编译时一定要开启优化。单纯的`Debug`编译性能会和`-O2`或者`-O3`有非常大的差别。`Debug`模式下的性能测试是没有意义的。 -3. 运行性能分析的时候,先从单线程开始,再开启多线程,进而多机。毕竟如果单线程调试更容易。可以设置`OMP_NUM_THREADS=1`这个环境变量关闭openmp优化。 +1. Use GCC command line option `-O2` or `-O3` to generate optimized + binary code. It doesn't make sense to profile `libpaddle.so` + without optimization, because it would anyway run slowly. -### 查看性能分析文件 +1. Profiling the single-threaded binary file before the + multi-threading version, because the latter often generates tangled + profiling analysis result. You might want to set environment + variable `OMP_NUM_THREADS=1` to prevents OpenMP from automatically + starting multiple threads. -在运行完性能分析后,会生成性能分析结果文件。我们可以使用[pprof](https://github.com/google/pprof)来显示性能分析结果。注意,这里使用了用`Go`语言重构后的`pprof`,因为这个工具具有web服务界面,且展示效果更好。 +### Look into the Profiling File -安装`pprof`的命令和一般的`Go`程序是一样的,其命令如下: +The tool we used to look into the profiling file generated by +`perftools` is [`pprof`](https://github.com/google/pprof), which +provides a Web-based GUI like `cprofilev`. + +We can rely on the standard Go toolchain to retrieve the source code +of `pprof` and build it: ```bash go get github.com/google/pprof ``` -进而我们可以使用如下命令开启一个HTTP服务: +Then we can use it to profile `main.py.prof` generated in the previous +section: ```bash pprof -http=0.0.0.0:3213 `which python` ./main.py.prof ``` -这行命令中,`-http`指开启HTTP服务。`which python`会产生当前Python二进制的完整路径,进而指定了Python可执行文件的路径。`./main.py.prof`输入了性能分析结果。 - -访问对应的网址,我们可以查看性能分析的结果。结果如下图所示: +Where `-http` specifies the IP and port of the HTTP service. +Directing our Web browser to the service, we would see something like +the following: ![result](./pprof_1.png) +### Identifying the Performance Bottlenecks -### 寻找性能瓶颈 - -与寻找Python代码的性能瓶颈类似,寻找Python与C++混合代码的性能瓶颈也是要看`tottime`和`cumtime`。而`pprof`展示的调用图也可以帮助我们发现性能中的问题。 - -例如下图中, +Similar to how we work with `cprofilev`, we'd focus on `tottime` and +`cumtime`. ![kernel_perf](./pprof_2.png) -在一次训练中,乘法和乘法梯度的计算占用2%-4%左右的计算时间。而`MomentumOp`占用了17%左右的计算时间。显然,`MomentumOp`的性能有问题。 - -在`pprof`中,对于性能的关键路径都做出了红色标记。先检查关键路径的性能问题,再检查其他部分的性能问题,可以更有次序的完成性能的优化。 - -## 总结 +We can see that the execution time of multiplication and the computing +of the gradient of multiplication takes 2% to 4% of the total running +time, and `MomentumOp` takes about 17%. Obviously, we'd want to +optimize `MomentumOp`. -至此,两种性能分析的方式都介绍完毕了。希望通过这两种性能分析的方式,Paddle的开发人员和使用人员可以有次序的,科学的发现和解决性能问题。 +`pprof` would mark performance critical parts of the program in +red. It's a good idea to follow the hint. diff --git a/doc/howto/optimization/cpu_profiling_cn.md b/doc/howto/optimization/cpu_profiling_cn.md new file mode 100644 index 0000000000000000000000000000000000000000..14eba0e2f34b115f5cd24920b5b1af07ec953d00 --- /dev/null +++ b/doc/howto/optimization/cpu_profiling_cn.md @@ -0,0 +1,155 @@ +此教程会介绍如何使用Python的cProfile包、Python库yep、Google perftools来进行性能分析 (profiling) 与调优(performance tuning)。 + +Profling 指发现性能瓶颈。系统中的瓶颈可能和程序员开发过程中想象的瓶颈相去甚远。Tuning 指消除瓶颈。性能优化的过程通常是不断重复地 profiling 和 tuning。 + +PaddlePaddle 用户一般通过调用 Python API 编写深度学习程序。大部分 Python API 调用用 C++ 写的 libpaddle.so。所以 PaddlePaddle 的性能分析与调优分为两个部分: + +* Python 代码的性能分析 +* Python 与 C++ 混合代码的性能分析 + + +## Python代码的性能分析 + +### 生成性能分析文件 + +Python标准库中提供了性能分析的工具包,[cProfile](https://docs.python.org/2/library/profile.html)。生成Python性能分析的命令如下: + +```bash +python -m cProfile -o profile.out main.py +``` + +其中 `main.py` 是我们要分析的程序,`-o`标识了一个输出的文件名,用来存储本次性能分析的结果。如果不指定这个文件,`cProfile`会打印到标准输出。 + +### 查看性能分析文件 + +`cProfile` 在main.py 运行完毕后输出`profile.out`。我们可以使用[`cprofilev`](https://github.com/ymichael/cprofilev)来查看性能分析结果。`cprofilev`是一个Python的第三方库。使用它会开启一个HTTP服务,将性能分析结果以网页的形式展示出来: + +```bash +cprofilev -a 0.0.0.0 -p 3214 -f profile.out main.py +``` + +其中`-a`标识HTTP服务绑定的IP。使用`0.0.0.0`允许外网访问这个HTTP服务。`-p`标识HTTP服务的端口。`-f`标识性能分析的结果文件。`main.py`标识被性能分析的源文件。 + +用Web浏览器访问对应网址,即可显示性能分析的结果: + +``` + ncalls tottime percall cumtime percall filename:lineno(function) + 1 0.284 0.284 29.514 29.514 main.py:1() + 4696 0.128 0.000 15.748 0.003 /home/yuyang/perf_test/.env/lib/python2.7/site-packages/paddle/v2/fluid/executor.py:20(run) + 4696 12.040 0.003 12.040 0.003 {built-in method run} + 1 0.144 0.144 6.534 6.534 /home/yuyang/perf_test/.env/lib/python2.7/site-packages/paddle/v2/__init__.py:14() +``` + +每一列的含义是: + +| 列名 | 含义 | +| --- | --- | +| ncalls | 函数的调用次数 | +| tottime | 函数实际使用的总时间。该时间去除掉本函数调用其他函数的时间 | +| percall | tottime的每次调用平均时间 | +| cumtime | 函数总时间。包含这个函数调用其他函数的时间 | +| percall | cumtime的每次调用平均时间 | +| filename:lineno(function) | 文件名, 行号,函数名 | + + +### 寻找性能瓶颈 + +通常`tottime`和`cumtime`是寻找瓶颈的关键指标。这两个指标代表了某一个函数真实的运行时间。 + +将性能分析结果按照tottime排序,效果如下: + +```text + 4696 12.040 0.003 12.040 0.003 {built-in method run} + 300005 0.874 0.000 1.681 0.000 /home/yuyang/perf_test/.env/lib/python2.7/site-packages/paddle/v2/dataset/mnist.py:38(reader) + 107991 0.676 0.000 1.519 0.000 /home/yuyang/perf_test/.env/lib/python2.7/site-packages/paddle/v2/fluid/framework.py:219(__init__) + 4697 0.626 0.000 2.291 0.000 /home/yuyang/perf_test/.env/lib/python2.7/site-packages/paddle/v2/fluid/framework.py:428(sync_with_cpp) + 1 0.618 0.618 0.618 0.618 /home/yuyang/perf_test/.env/lib/python2.7/site-packages/paddle/v2/fluid/__init__.py:1() +``` + +可以看到最耗时的函数是C++端的`run`函数。这需要联合我们第二节`Python`与`C++`混合代码的性能分析来进行调优。而`sync_with_cpp`函数的总共耗时很长,每次调用的耗时也很长。于是我们可以点击`sync_with_cpp`的详细信息,了解其调用关系。 + +```text +Called By: + + Ordered by: internal time + List reduced from 4497 to 2 due to restriction <'sync_with_cpp'> + +Function was called by... + ncalls tottime cumtime +/home/yuyang/perf_test/.env/lib/python2.7/site-packages/paddle/v2/fluid/framework.py:428(sync_with_cpp) <- 4697 0.626 2.291 /home/yuyang/perf_test/.env/lib/python2.7/site-packages/paddle/v2/fluid/framework.py:562(sync_with_cpp) +/home/yuyang/perf_test/.env/lib/python2.7/site-packages/paddle/v2/fluid/framework.py:562(sync_with_cpp) <- 4696 0.019 2.316 /home/yuyang/perf_test/.env/lib/python2.7/site-packages/paddle/v2/fluid/framework.py:487(clone) + 1 0.000 0.001 /home/yuyang/perf_test/.env/lib/python2.7/site-packages/paddle/v2/fluid/framework.py:534(append_backward) + + +Called: + + Ordered by: internal time + List reduced from 4497 to 2 due to restriction <'sync_with_cpp'> +``` + +通常观察热点函数间的调用关系,和对应行的代码,就可以了解到问题代码在哪里。当我们做出性能修正后,再次进行性能分析(profiling)即可检查我们调优后的修正是否能够改善程序的性能。 + + + +## Python与C++混合代码的性能分析 + +### 生成性能分析文件 + +C++的性能分析工具非常多。常见的包括`gprof`, `valgrind`, `google-perftools`。但是调试Python中使用的动态链接库与直接调试原始二进制相比增加了很多复杂度。幸而Python的一个第三方库`yep`提供了方便的和`google-perftools`交互的方法。于是这里使用`yep`进行Python与C++混合代码的性能分析 + +使用`yep`前需要安装`google-perftools`与`yep`包。ubuntu下安装命令为 + +```bash +apt update +apt install libgoogle-perftools-dev +pip install yep +``` + +安装完毕后,我们可以通过 + +```bash +python -m yep -v main.py +``` + +生成性能分析文件。生成的性能分析文件为`main.py.prof`。 + +命令行中的`-v`指定在生成性能分析文件之后,在命令行显示分析结果。我们可以在命令行中简单的看一下生成效果。因为C++与Python不同,编译时可能会去掉调试信息,运行时也可能因为多线程产生混乱不可读的性能分析结果。为了生成更可读的性能分析结果,可以采取下面几点措施: + +1. 编译时指定`-g`生成调试信息。使用cmake的话,可以将CMAKE_BUILD_TYPE指定为`RelWithDebInfo`。 +2. 编译时一定要开启优化。单纯的`Debug`编译性能会和`-O2`或者`-O3`有非常大的差别。`Debug`模式下的性能测试是没有意义的。 +3. 运行性能分析的时候,先从单线程开始,再开启多线程,进而多机。毕竟单线程调试更容易。可以设置`OMP_NUM_THREADS=1`这个环境变量关闭openmp优化。 + +### 查看性能分析文件 + +在运行完性能分析后,会生成性能分析结果文件。我们可以使用[`pprof`](https://github.com/google/pprof)来显示性能分析结果。注意,这里使用了用`Go`语言重构后的`pprof`,因为这个工具具有web服务界面,且展示效果更好。 + +安装`pprof`的命令和一般的`Go`程序是一样的,其命令如下: + +```bash +go get github.com/google/pprof +``` + +进而我们可以使用如下命令开启一个HTTP服务: + +```bash +pprof -http=0.0.0.0:3213 `which python` ./main.py.prof +``` + +这行命令中,`-http`指开启HTTP服务。`which python`会产生当前Python二进制的完整路径,进而指定了Python可执行文件的路径。`./main.py.prof`输入了性能分析结果。 + +访问对应的网址,我们可以查看性能分析的结果。结果如下图所示: + +![result](./pprof_1.png) + + +### 寻找性能瓶颈 + +与寻找Python代码的性能瓶颈类似,寻找Python与C++混合代码的性能瓶颈也是要看`tottime`和`cumtime`。而`pprof`展示的调用图也可以帮助我们发现性能中的问题。 + +例如下图中, + +![kernel_perf](./pprof_2.png) + +在一次训练中,乘法和乘法梯度的计算占用2%-4%左右的计算时间。而`MomentumOp`占用了17%左右的计算时间。显然,`MomentumOp`的性能有问题。 + +在`pprof`中,对于性能的关键路径都做出了红色标记。先检查关键路径的性能问题,再检查其他部分的性能问题,可以更有次序的完成性能的优化。 diff --git a/paddle/framework/op_desc.cc b/paddle/framework/op_desc.cc index 48cd131550dea5ad3f368b25c31d753efbe0dff9..02a825324328fa5cfd3a4d23a8c64488cc88aeec 100644 --- a/paddle/framework/op_desc.cc +++ b/paddle/framework/op_desc.cc @@ -65,7 +65,7 @@ class CompileTimeInferShapeContext : public InferShapeContext { PADDLE_ENFORCE_EQ(in_var->GetType(), VarDesc::LOD_TENSOR, "The %d-th output of Output(%s) must be LoDTensor.", j, out); - in_var->SetLoDLevel(out_var->GetLodLevel()); + out_var->SetLoDLevel(in_var->GetLodLevel()); } bool IsRuntime() const override; diff --git a/paddle/gserver/tests/CMakeLists.txt b/paddle/gserver/tests/CMakeLists.txt index c295ea19c9ccb3d05c509a41925d2c36efdba8ef..24e6cae8e69557c42ed5d437edce101709ca3983 100644 --- a/paddle/gserver/tests/CMakeLists.txt +++ b/paddle/gserver/tests/CMakeLists.txt @@ -62,11 +62,11 @@ if(NOT WITH_DOUBLE AND NOT MOBILE_INFERENCE) endif() if(NOT MOBILE_INFERENCE) -################## test_Evaluator ####################### + ################## test_Evaluator ####################### add_unittest(test_Evaluator test_Evaluator.cpp) -############### test_RecurrentGradientMachine ############### + ############### test_RecurrentGradientMachine ############### # TODO(yuyang18): There is some bug in test_RecurrentGradientMachine # I will fix it. add_unittest_without_exec(test_RecurrentGradientMachine @@ -77,7 +77,7 @@ if(NOT MOBILE_INFERENCE) ${CMAKE_CURRENT_BINARY_DIR}/test_RecurrentGradientMachine WORKING_DIRECTORY ${PADDLE_SOURCE_DIR}/paddle) -############### test_NetworkCompare ############### + ############### test_NetworkCompare ############### add_unittest_without_exec(test_NetworkCompare test_NetworkCompare.cpp) if(WITH_GPU) @@ -89,34 +89,33 @@ if(NOT MOBILE_INFERENCE) COMMAND .set_python_path.sh -d ${PADDLE_SOURCE_DIR}/python ${CMAKE_CURRENT_BINARY_DIR}/test_NetworkCompare --use_gpu=false WORKING_DIRECTORY ${PADDLE_SOURCE_DIR}/paddle) endif() -endif() + ################# test_CompareSparse ################## + add_unittest_without_exec(test_CompareSparse + test_CompareSparse.cpp) + if(NOT ON_TRAVIS) + add_test(NAME test_CompareSparse + COMMAND ${PADDLE_SOURCE_DIR}/paddle/.set_python_path.sh -d + ${PADDLE_SOURCE_DIR}/python:${PADDLE_SOURCE_DIR}/paddle/gserver/tests + ./.set_port.sh -p port -n 6 + ${CMAKE_CURRENT_BINARY_DIR}/test_CompareSparse + WORKING_DIRECTORY ${PADDLE_SOURCE_DIR}/paddle/) + endif() + + ################ test_CompareTwoNets ###################### + add_unittest_without_exec(test_CompareTwoNets + test_CompareTwoNets.cpp) + add_test(NAME test_CompareTwoNets + COMMAND ${PADDLE_SOURCE_DIR}/paddle/.set_python_path.sh -d + ${PADDLE_SOURCE_DIR}/python:${PADDLE_SOURCE_DIR}/paddle/gserver/tests + ${CMAKE_CURRENT_BINARY_DIR}/test_CompareTwoNets + WORKING_DIRECTORY ${PADDLE_SOURCE_DIR}/paddle/) +endif() +################ test_PyDataProvider2 ###################### add_unittest_without_exec(test_PyDataProvider2 test_PyDataProvider2.cpp) - add_test(NAME test_PyDataProvider2 COMMAND .set_python_path.sh -d ${PADDLE_SOURCE_DIR}/paddle/gserver/tests:${PADDLE_SOURCE_DIR}/python ${CMAKE_CURRENT_BINARY_DIR}/test_PyDataProvider2 WORKING_DIRECTORY ${PADDLE_SOURCE_DIR}/paddle ) - -################# test_CompareSparse ################## -add_unittest_without_exec(test_CompareSparse - test_CompareSparse.cpp) -if(NOT ON_TRAVIS) - add_test(NAME test_CompareSparse - COMMAND ${PADDLE_SOURCE_DIR}/paddle/.set_python_path.sh -d - ${PADDLE_SOURCE_DIR}/python:${PADDLE_SOURCE_DIR}/paddle/gserver/tests - ./.set_port.sh -p port -n 6 - ${CMAKE_CURRENT_BINARY_DIR}/test_CompareSparse - WORKING_DIRECTORY ${PADDLE_SOURCE_DIR}/paddle/) -endif() - -################ test_CompareTwoNets ###################### -add_unittest_without_exec(test_CompareTwoNets - test_CompareTwoNets.cpp) -add_test(NAME test_CompareTwoNets - COMMAND ${PADDLE_SOURCE_DIR}/paddle/.set_python_path.sh -d - ${PADDLE_SOURCE_DIR}/python:${PADDLE_SOURCE_DIR}/paddle/gserver/tests - ${CMAKE_CURRENT_BINARY_DIR}/test_CompareTwoNets - WORKING_DIRECTORY ${PADDLE_SOURCE_DIR}/paddle/) diff --git a/paddle/memory/memory.cc b/paddle/memory/memory.cc index 5eb1c44eb6fc45db31ef44bf79e74b79193e08aa..95cfe2525e3e7c128d8652c5c6a0bb3d80a475b9 100644 --- a/paddle/memory/memory.cc +++ b/paddle/memory/memory.cc @@ -81,18 +81,33 @@ BuddyAllocator* GetGPUBuddyAllocator(int gpu_id) { } template <> -void* Alloc(platform::GPUPlace place, size_t size) { - return GetGPUBuddyAllocator(place.device)->Alloc(size); +size_t Used(platform::GPUPlace place) { + return GetGPUBuddyAllocator(place.device)->Used(); } template <> -void Free(platform::GPUPlace place, void* p) { - GetGPUBuddyAllocator(place.device)->Free(p); +void* Alloc(platform::GPUPlace place, size_t size) { + auto* buddy_allocator = GetGPUBuddyAllocator(place.device); + auto* ptr = buddy_allocator->Alloc(size); + if (ptr == nullptr) { + int cur_dev = platform::GetCurrentDeviceId(); + platform::SetDeviceId(place.device); + size_t avail, total; + platform::GpuMemoryUsage(avail, total); + LOG(WARNING) << "Cannot allocate " << size << " bytes in GPU " + << place.device << ", available " << avail << " bytes"; + LOG(WARNING) << "total " << total; + LOG(WARNING) << "GpuMinChunkSize " << platform::GpuMinChunkSize(); + LOG(WARNING) << "GpuMaxChunkSize " << platform::GpuMaxChunkSize(); + LOG(WARNING) << "GPU memory used: " << Used(place); + platform::SetDeviceId(cur_dev); + } + return ptr; } template <> -size_t Used(platform::GPUPlace place) { - return GetGPUBuddyAllocator(place.device)->Used(); +void Free(platform::GPUPlace place, void* p) { + GetGPUBuddyAllocator(place.device)->Free(p); } #endif diff --git a/paddle/operators/CMakeLists.txt b/paddle/operators/CMakeLists.txt index 7e5d4fd640f4399d1a217d1a0be76b3da457c0cc..937441b318095eadb9022c1d7578ad8aca2dadc8 100644 --- a/paddle/operators/CMakeLists.txt +++ b/paddle/operators/CMakeLists.txt @@ -191,6 +191,7 @@ set(DEPS_OPS sum_op pool_op maxout_op + unpool_op pool_with_index_op conv_op conv_transpose_op @@ -235,6 +236,7 @@ op_library(adagrad_op DEPS selected_rows_functor) op_library(conv_op DEPS vol2col) op_library(pool_op DEPS pooling) op_library(maxout_op DEPS maxouting) +op_library(unpool_op DEPS unpooling) op_library(pool_with_index_op DEPS pooling) op_library(lod_rank_table_op SRCS lod_rank_table_op.cc DEPS lod_rank_table) op_library(lod_tensor_to_array_op SRCS lod_tensor_to_array_op.cc DEPS lod_rank_table_op) diff --git a/paddle/operators/conv_op.cc b/paddle/operators/conv_op.cc index 7a36a9b21aa6a1b415ac5a232e65eda8051c87f8..462e6d9cbcbe61d9911efe8beff4446620e1e932 100644 --- a/paddle/operators/conv_op.cc +++ b/paddle/operators/conv_op.cc @@ -97,7 +97,7 @@ Conv2DOpMaker::Conv2DOpMaker(framework::OpProto* proto, .SetDefault({0, 0}); AddAttr( "groups", - "(int default:1), the group size of convolution operator. " + "(int default:1), the groups number of the convolution operator. " "According to grouped convolution in Alex Krizhevsky's Deep CNN paper: " "when group=2, the first half of the filters is only connected to the " "first half of the input channels, while the second half of the filters " @@ -112,23 +112,29 @@ Conv2DOpMaker::Conv2DOpMaker(framework::OpProto* proto, Convolution Operator. The convolution operation calculates the output based on the input, filter -and strides, paddings, groups, dilations parameters. The size of each dimension of the +and strides, paddings, dilations, groups parameters. The size of each dimension of the parameters is checked in the infer-shape. -Input(Input, Filter) and output(Output) are in NCHW format. Where N is batch +Input(Input) and Output(Output) are in NCHW format. Where N is batch size, C is the number of channels, H is the height of the feature, and W is -the width of the feature. Parameters(ksize, strides, paddings, dilations) are two elements. -These two elements represent height and width, respectively. +the width of the feature. +Filters(Input) is MCHW format. Where M is the number of output image channels, C is +the number of input image channels, H is the height of the filter, and W +is the width of the filter. +Parameters(strides, paddings, dilations) are two elements. These two elements represent +height and width, respectively. The input(X) size and output(Out) size may be different. Example: Input: - Input shape: (N, C_in, H_in, W_in) - Filter shape: (C_out, C_in, H_f, W_f) + Input shape: $(N, C_{in}, H_{in}, W_{in})$ + Filter shape: $(C_{out}, C_{in}, H_f, W_f)$ Output: - Output shape: (N, C_out, H_out, W_out) - where - H_out = (H_in + 2 * paddings[0] - (dilations[0]*(filter_size[0] - 1) + 1)) / strides[0] + 1; - W_out = (W_in + 2 * paddings[1] - (dilations[1]*(filter_size[1] - 1) + 1)) / strides[1] + 1; + Output shape: $(N, C_{out}, H_{out}, W_{out})$ + Where +$$ + H_{out}= \frac{(H_{in} + 2 * paddings[0] - (dilations[0] * (H_f - 1) + 1))}{strides[0]}+ 1 \\ + W_{out}= \frac{(W_{in} + 2 * paddings[1] - (dilations[1] * (W_f - 1) + 1))}{strides[1]}+ 1 +$$ )DOC"); } @@ -165,7 +171,7 @@ Conv3DOpMaker::Conv3DOpMaker(framework::OpProto* proto, .SetDefault({0, 0, 0}); AddAttr( "groups", - "(int default:1), the group size of convolution operator. " + "(int default:1), the groups number of the convolution operator. " "According to grouped convolution in Alex Krizhevsky's Deep CNN paper: " "when group=2, the first half of the filters is only connected to the " "first half of the input channels, while the second half of the filters " @@ -174,32 +180,37 @@ Conv3DOpMaker::Conv3DOpMaker(framework::OpProto* proto, AddAttr>("dilations", "(vector default:{1, 1, 1}), the " "dilations(d_dilation, h_dilation, w_dilation) of " - "convolution operator. Currently, conv3d doesn't " - "support dilation.") + "convolution operator.") .SetDefault({1, 1, 1}); AddComment(R"DOC( Convolution3D Operator. The convolution operation calculates the output based on the input, filter -and strides, paddings, groups parameters. The size of each dimension of the +and strides, paddings, dilations, groups parameters. The size of each dimension of the parameters is checked in the infer-shape. -Input(Input, Filter) and output(Output) are in NCDHW format. Where N is batch +Input(Input) and output(Output) are in NCDHW format, where N is batch size, C is the number of channels,D is the depth of the feature, H is the height of -the feature, and W is the width of the feature. Parameters(ksize, strides, paddings) -are three elements. These three elements represent depth, height and width, respectively. +the feature, and W is the width of the feature. +Filters(Input) is MCDHW format, where M is the number of output image channels, +C is the number of input image channels, D is the depth of the filter, +H is the height of the filter, and W is the width of the filter. +Parameters(strides, paddings, dilations) are three elements. These three elements +represent depth, height and width, respectively. The input(X) size and output(Out) size may be different. Example: Input: - Input shape: (N, C_in, D_in, H_in, W_in) - Filter shape: (C_out, C_in, D_f, H_f, W_f) + Input shape: $(N, C_{in}, D_{in}, H_{in}, W_{in})$ + Filter shape: $(C_{out}, C_{in}, D_f, H_f, W_f)$ Output: - Output shape: (N, C_out, D_out, H_out, W_out) - where - D_out = (D_in - filter_size[0] + 2 * paddings[0]) / strides[0] + 1; - H_out = (H_in - filter_size[1] + 2 * paddings[1]) / strides[1] + 1; - W_out = (W_in - filter_size[2] + 2 * paddings[2]) / strides[2] + 1; + Output shape: $(N, C_{out}, D_{out}, H_{out}, W_{out})$ + Where + $$ + D_{out}= \frac{(D_{in} + 2 * paddings[0] - (dilations[0] * (D_f - 1) + 1))}{ strides[0]}+ 1 \\ + H_{out}= \frac{(H_{in} + 2 * paddings[1] - (dilations[1] * (H_f - 1) + 1))}{ strides[1]}+ 1 \\ + W_{out}= \frac{(W_{in} + 2 * paddings[2] - (dilations[2] * (W_f - 1) + 1))}{ strides[2]}+ 1 + $$ )DOC"); } diff --git a/paddle/operators/conv_transpose_op.cc b/paddle/operators/conv_transpose_op.cc index 314b577d0067773028a4e9b333a41b8e7a74d327..678b192dea78fc6b4a6b54c4bb09a55dfb8f9c38 100644 --- a/paddle/operators/conv_transpose_op.cc +++ b/paddle/operators/conv_transpose_op.cc @@ -39,7 +39,7 @@ void ConvTransposeOp::InferShape(framework::InferShapeContext* ctx) const { "ConvTransposeOp input dimension and strides dimension should " "be consistent."); PADDLE_ENFORCE_EQ(paddings.size(), strides.size(), - "ConvTransposeOp paddings dimension and Conv strides " + "ConvTransposeOp paddings dimension and strides " "dimension should be the same."); PADDLE_ENFORCE_EQ(in_dims[1], filter_dims[0], "In ConvTransposeOp, The input channel should be the same " @@ -62,13 +62,14 @@ Conv2DTransposeOpMaker::Conv2DTransposeOpMaker( "The format of input tensor is NCHW. Where N is batch size, C is the " "number of input channels, H is the height of the feature, and " "W is the width of the feature."); - AddInput("Filter", - "(Tensor) The filter tensor of convolution transpose operator. " - "The format of the filter tensor is CMHW, where C is the number of " - "output image channels, M is the number of input image channels, " - "H is the height of the filter, and W is the width of the filter. " - "We enforce groups number == 1 and padding == 0 in " - "the convolution transpose scenario."); + AddInput( + "Filter", + "(Tensor) The filter tensor of convolution transpose operator. " + "The format of the filter tensor is MCHW, where M is the number of " + "input feature channels, C is the number of " + "output feature channels," + "H is the height of the filter, and W is the width of the filter. " + "We enforce groups number == 1 in the convolution transpose scenario."); AddOutput("Output", "(Tensor) The output tensor of convolution transpose operator. " "The format of output tensor is also NCHW."); @@ -88,21 +89,26 @@ Convolution2D Transpose Operator. The convolution transpose operation calculates the output based on the input, filter and strides, paddings, groups parameters. The size of each dimension of the parameters is checked in the infer-shape. - -Input(Input, Filter) and output(Output) are in NCHW format. Where N is batch -size, C is the number of channels, H is the height of the feature, and -W is the width of the feature. Parameters(ksize, strides, paddings) are two elements. -These two elements represent height and width, respectively. +Input(Input) and output(Output) are in NCHW format. Where N is batchsize, C is the +number of channels, H is the height of the feature, and W is the width of the feature. +Filter(Input) is in MCHW format. Where M is the number of input feature channels, +C is the number of output feature channels, H is the height of the filter, +and W is the width of the filter. +Parameters(strides, paddings) are two elements. These two elements represent height +and width, respectively. The input(X) size and output(Out) size may be different. + Example: Input: - Input shape: (N, C_in, H_in, W_in) - Filter shape: (C_in, C_out, H_f, W_f) + Input shape: $(N, C_{in}, H_{in}, W_{in})$ + Filter shape: $(C_{in}, C_{out}, H_f, W_f)$ Output: - Output shape: (N, C_out, H_out, W_out) - where - H_out = (H_in - 1) * strides[0] - 2 * paddings[0] + H_f; - W_out = (W_in - 1) * strides[1] - 2 * paddings[1] + W_f; + Output shape: $(N, C_{out}, H_{out}, W_{out})$ + Where + $$ + H_{out} = (H_{in} - 1) * strides[0] - 2 * paddings[0] + H_f \\ + W_{out} = (W_{in} - 1) * strides[1] - 2 * paddings[1] + W_f + $$ )DOC"); } @@ -117,8 +123,9 @@ Conv3DTransposeOpMaker::Conv3DTransposeOpMaker( "W is the width of the feature."); AddInput("Filter", "(Tensor) The filter tensor of convolution transpose operator." - "The format of the filter tensor is CMDHW, where C is the number of " - "output image channels, M is the number of input image channels, D " + "The format of the filter tensor is MCDHW, where M is the number of " + "input feature channels, C is the number of " + "output feature channels, D " "is the depth of the filter, H is the height of the filter, and " "W is the width of the filter." "We enforce groups number == 1 and padding == 0 in " @@ -144,23 +151,28 @@ Convolution3D Transpose Operator. The convolution transpose operation calculates the output based on the input, filter and strides, paddings, groups parameters. The size of each dimension of the parameters is checked in the infer-shape. - -Input(Input, Filter) and output(Output) are in NCDHW format. Where N is batch -size, C is the number of channels, D is the depth of the feature, -H is the height of the feature, and W is the width of the feature. -Parameters(ksize, strides, paddings) are three elements. -These three elements represent depth, height and width, respectively. +Input(Input) and output(Output) are in NCDHW format. Where N is batch size, C is the +number of channels, D is the depth of the feature, H is the height of the feature, +and W is the width of the feature. +Filter(Input) is in MCDHW format. Where M is the number of input feature channels, +C is the number of output feature channels, D is the depth of the filter,H is the +height of the filter, and W is the width of the filter. +Parameters(strides, paddings) are three elements. These three elements represent +depth, height and width, respectively. The input(X) size and output(Out) size may be different. -Example: + +Example: Input: - Input shape: (N, C_in, D_in, H_in, W_in) - Filter shape: (C_in, C_out, D_f, H_f, W_f) + Input shape: $(N, C_{in}, D_{in}, H_{in}, W_{in})$ + Filter shape: $(C_{in}, C_{out}, D_f, H_f, W_f)$ Output: - Output shape: (N, C_out, D_out, H_out, W_out) - where - D_out = (D_in - 1) * strides[0] - 2 * paddings[0] + D_f; - H_out = (H_in - 1) * strides[1] - 2 * paddings[1] + H_f; - W_out = (W_in - 1) * strides[2] - 2 * paddings[2] + W_f; + Output shape: $(N, C_{out}, D_{out}, H_{out}, W_{out})$ + Where + $$ + D_{out} = (D_{in} - 1) * strides[0] - 2 * paddings[0] + D_f \\ + H_{out} = (H_{in} - 1) * strides[1] - 2 * paddings[1] + H_f \\ + W_{out} = (W_{in} - 1) * strides[2] - 2 * paddings[2] + W_f + $$ )DOC"); } diff --git a/paddle/operators/conv_transpose_op.h b/paddle/operators/conv_transpose_op.h index 0fc0735788c499c2d520c0cc689e1ce07ba67ce8..1cacb770e6af3ad3c99ab81c5598ffcd228f59b2 100644 --- a/paddle/operators/conv_transpose_op.h +++ b/paddle/operators/conv_transpose_op.h @@ -63,7 +63,6 @@ class GemmConvTransposeKernel : public framework::OpKernel { std::vector strides = context.Attr>("strides"); std::vector paddings = context.Attr>("paddings"); - // TODO(Zhuoyuan): Paddings can be added in future. // groups will alway be disabled in conv2dtranspose. const int batch_size = static_cast(input->dims()[0]); diff --git a/paddle/operators/detail/send_recv.proto b/paddle/operators/detail/send_recv.proto index 962c7d59819dede022474aec4a2d7f538d28c688..07ff9d2c621a2dfb51792821a0d3fc398c315835 100644 --- a/paddle/operators/detail/send_recv.proto +++ b/paddle/operators/detail/send_recv.proto @@ -32,4 +32,4 @@ message VariableMessage { bytes serialized = 2; } -message VoidMessage {} \ No newline at end of file +message VoidMessage {} diff --git a/paddle/operators/math/CMakeLists.txt b/paddle/operators/math/CMakeLists.txt index 3017f133afc5d4dcd484c78b44591a876ab4d667..bf47879f772a3013bd7ce78c6f8a6aefe65298f9 100644 --- a/paddle/operators/math/CMakeLists.txt +++ b/paddle/operators/math/CMakeLists.txt @@ -13,8 +13,9 @@ if(WITH_GPU) nv_library(context_project SRCS context_project.cc context_project.cu DEPS device_context math_function) nv_library(sequence2batch SRCS sequence2batch.cc sequence2batch.cu DEPS device_context) nv_library(lstm_compute SRCS lstm_compute.cc lstm_compute.cu DEPS device_context activation_functions) - nv_library(gru_compute SRCS gru_compute.cc gru_compute.cu DEPS device_context activation_functions math_function) nv_library(maxouting SRCS maxouting.cc maxouting.cu DEPS device_context) + nv_library(unpooling SRCS unpooling.cc unpooling.cu DEPS device_context) + nv_library(gru_compute SRCS gru_compute.cc gru_compute.cu DEPS device_context activation_functions math_function) else() cc_library(math_function SRCS math_function.cc im2col.cc DEPS cblas device_context framework_proto) cc_library(selected_rows_functor SRCS selected_rows_functor.cc DEPS selected_rows math_function) @@ -26,8 +27,9 @@ else() cc_library(context_project SRCS context_project.cc DEPS device_context math_function) cc_library(sequence2batch SRCS sequence2batch.cc DEPS device_context) cc_library(lstm_compute SRCS lstm_compute.cc DEPS device_context activation_functions) - cc_library(gru_compute SRCS gru_compute.cc DEPS device_context activation_functions math_function) cc_library(maxouting SRCS maxouting.cc DEPS device_context) + cc_library(unpooling SRCS unpooling.cc DEPS device_context) + cc_library(gru_compute SRCS gru_compute.cc DEPS device_context activation_functions math_function) endif() cc_test(math_function_test SRCS math_function_test.cc DEPS math_function tensor) diff --git a/paddle/operators/math/unpooling.cc b/paddle/operators/math/unpooling.cc new file mode 100644 index 0000000000000000000000000000000000000000..b57d3dc1414cff492db8d7d503a7fce370a3f151 --- /dev/null +++ b/paddle/operators/math/unpooling.cc @@ -0,0 +1,91 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/operators/math/unpooling.h" +namespace paddle { +namespace operators { +namespace math { +template +class Unpool2dMaxFunctor { + public: + void operator()(const platform::DeviceContext& context, + const framework::Tensor& input, + const framework::Tensor& indices, framework::Tensor* output) { + const int batch_size = input.dims()[0]; + const int input_height = input.dims()[2]; + const int input_width = input.dims()[3]; + const int output_channels = output->dims()[1]; + const int output_height = output->dims()[2]; + const int output_width = output->dims()[3]; + int input_feasize = input_height * input_width; + int output_feasize = output_height * output_width; + const T* input_data = input.data(); + const int* indices_data = indices.data(); + T* output_data = output->mutable_data(context.GetPlace()); + for (int b = 0; b < batch_size; ++b) { + for (int c = 0; c < output_channels; ++c) { + for (int i = 0; i < input_feasize; ++i) { + int index = indices_data[i]; + PADDLE_ENFORCE(index < output_feasize, "err index in unpooling!"); + output_data[index] = input_data[i]; + } + input_data += input_feasize; + indices_data += input_feasize; + output_data += output_feasize; + } + } + } +}; +template +class Unpool2dMaxGradFunctor { + public: + void operator()(const platform::DeviceContext& context, + const framework::Tensor& input, + const framework::Tensor& indices, + const framework::Tensor& output, + const framework::Tensor& output_grad, + framework::Tensor* input_grad) { + const int batch_size = input.dims()[0]; + const int input_height = input.dims()[2]; + const int input_width = input.dims()[3]; + const int output_channels = output.dims()[1]; + const int output_height = output.dims()[2]; + const int output_width = output.dims()[3]; + int input_feasize = input_height * input_width; + int output_feasize = output_height * output_width; + const int* indices_data = indices.data(); + const T* output_grad_data = output_grad.data(); + T* input_grad_data = input_grad->mutable_data(context.GetPlace()); + + for (int b = 0; b < batch_size; ++b) { + for (int c = 0; c < output_channels; ++c) { + for (int i = 0; i < input_feasize; ++i) { + int index = indices_data[i]; + PADDLE_ENFORCE(index < output_feasize, "err index in unpooling!"); + input_grad_data[i] = output_grad_data[index]; + } + input_grad_data += input_feasize; + indices_data += input_feasize; + output_grad_data += output_feasize; + } + } + } +}; +template class Unpool2dMaxGradFunctor; +template class Unpool2dMaxGradFunctor; +template class Unpool2dMaxFunctor; +template class Unpool2dMaxFunctor; +} // namespace math +} // namespace operators +} // namespace paddle diff --git a/paddle/operators/math/unpooling.cu b/paddle/operators/math/unpooling.cu new file mode 100644 index 0000000000000000000000000000000000000000..37c3c8b689f9a69b68ddffd23813fa9ad8ced0e7 --- /dev/null +++ b/paddle/operators/math/unpooling.cu @@ -0,0 +1,134 @@ +/* Copyright (c) 2016 paddlepaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/operators/math/unpooling.h" +#include "paddle/platform/cuda_helper.h" + +namespace paddle { +namespace operators { +namespace math { +template +__global__ void KernelUnpool2dMax(const int nthreads, const T* input_data, + const int* indices_data, + const int input_height, const int input_width, + const int channels, T* output_data, + const int output_height, + const int output_width) { + int in_n_stride = input_height * input_width * channels; + int in_c_stride = input_height * input_width; + int out_n_stride = output_height * output_width * channels; + int out_c_stride = output_height * output_width; + int index = blockIdx.x * blockDim.x + threadIdx.x; + int offset = blockDim.x * gridDim.x; + for (int i = index; i < nthreads; i += offset) { + int bidx = i / in_n_stride; + int boffset = i % in_n_stride; + int cidx = boffset / in_c_stride; + int out_offset = bidx * out_n_stride + cidx * out_c_stride; + int out_index = indices_data[i]; + PADDLE_ASSERT(out_index < out_c_stride); + output_data[out_offset + out_index] = input_data[i]; + } +} +template +__global__ void KernelUnpool2dMaxGrad( + const int nthreads, const T* input_data, const int* indices_data, + const int input_height, const int input_width, const int channels, + const T* output_data, const T* output_grad, const int output_height, + const int output_width, T* input_grad) { + int in_n_stride = input_height * input_width * channels; + int in_c_stride = input_height * input_width; + int out_n_stride = output_height * output_width * channels; + int out_c_stride = output_height * output_width; + int index = blockIdx.x * blockDim.x + threadIdx.x; + int offset = blockDim.x * gridDim.x; + for (int i = index; i < nthreads; i += offset) { + int bidx = i / in_n_stride; + int boffset = i % in_n_stride; + int cidx = boffset / in_c_stride; + int out_offset = bidx * out_n_stride + cidx * out_c_stride; + int out_index = indices_data[i]; + PADDLE_ASSERT(out_index < out_c_stride); + input_grad[i] = output_grad[out_offset + out_index]; + } +} +/* + * All tensors are in NCHW format. + */ +template +class Unpool2dMaxFunctor { + public: + void operator()(const platform::DeviceContext& context, + const framework::Tensor& input, + const framework::Tensor& indices, framework::Tensor* output) { + const int batch_size = input.dims()[0]; + const int input_height = input.dims()[2]; + const int input_width = input.dims()[3]; + const int output_channels = output->dims()[1]; + const int output_height = output->dims()[2]; + const int output_width = output->dims()[3]; + const T* input_data = input.data(); + const int* indices_data = indices.data(); + T* output_data = output->mutable_data(context.GetPlace()); + int threads = 1024; + int grid = (input.numel() + threads - 1) / threads; + KernelUnpool2dMax< + T><<(context) + .stream()>>>(input.numel(), input_data, indices_data, + input_height, input_width, output_channels, + output_data, output_height, output_width); + } +}; +/* + * All tensors are in NCHW format. + */ +template +class Unpool2dMaxGradFunctor { + public: + void operator()(const platform::DeviceContext& context, + const framework::Tensor& input, + const framework::Tensor& indices, + const framework::Tensor& output, + const framework::Tensor& output_grad, + framework::Tensor* input_grad) { + const int batch_size = input.dims()[0]; + const int input_height = input.dims()[2]; + const int input_width = input.dims()[3]; + const int output_channels = output.dims()[1]; + const int output_height = output.dims()[2]; + const int output_width = output.dims()[3]; + const T* input_data = input.data(); + const int* indices_data = indices.data(); + const T* output_data = output.data(); + const T* output_grad_data = output_grad.data(); + T* input_grad_data = input_grad->mutable_data(context.GetPlace()); + int threads = 1024; + int grid = (input.numel() + threads - 1) / threads; + KernelUnpool2dMaxGrad< + T><<(context) + .stream()>>>(input.numel(), input_data, indices_data, + input_height, input_width, output_channels, + output_data, output_grad_data, output_height, + output_width, input_grad_data); + } +}; +template class Unpool2dMaxGradFunctor; +template class Unpool2dMaxGradFunctor; +template class Unpool2dMaxFunctor; +template class Unpool2dMaxFunctor; +} // namespace math +} // namespace operators +} // namespace paddle diff --git a/paddle/operators/math/unpooling.h b/paddle/operators/math/unpooling.h new file mode 100644 index 0000000000000000000000000000000000000000..7077d7c2274fd9e02b69ef343f310f4ffbbcff1a --- /dev/null +++ b/paddle/operators/math/unpooling.h @@ -0,0 +1,40 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once +#include "paddle/framework/tensor.h" + +namespace paddle { +namespace operators { +namespace math { +template +class Unpool2dMaxFunctor { + public: + void operator()(const platform::DeviceContext& context, + const framework::Tensor& input, + const framework::Tensor& indices, framework::Tensor* output); +}; +template +class Unpool2dMaxGradFunctor { + public: + void operator()(const platform::DeviceContext& context, + const framework::Tensor& input, + const framework::Tensor& indices, + const framework::Tensor& output, + const framework::Tensor& output_grad, + framework::Tensor* input_grad); +}; +} // namespace math +} // namespace operators +} // namespace paddle diff --git a/paddle/operators/pool_op.cc b/paddle/operators/pool_op.cc index d8c58618cf703d086d3cabc927ebc5eb038b1aec..e26ffd86e5b5645e361070ca9fd9d8dc49d1ed30 100644 --- a/paddle/operators/pool_op.cc +++ b/paddle/operators/pool_op.cc @@ -105,7 +105,7 @@ Pool2dOpMaker::Pool2dOpMaker(framework::OpProto *proto, // TypedAttrChecker don't support vector type.) AddAttr>( "paddings", - "(vector, defalut {0,0}), paddings(height, width) of pooling " + "(vector, default {0,0}), paddings(height, width) of pooling " "operator." "If global_pooling = true, paddings and ksize will be ignored.") .SetDefault({0, 0}); // TODO(Chengduo): Add checker. (Currently, @@ -122,15 +122,15 @@ Parameters(ksize, strides, paddings) are two elements. These two elements represent height and width, respectively. The input(X) size and output(Out) size may be different. -Example: +Example: Input: X shape: $(N, C, H_{in}, W_{in})$ Output: Out shape: $(N, C, H_{out}, W_{out})$ - where + Where $$ - H_{out} = (H_{in} - ksize[0] + 2 * paddings[0]) / strides[0] + 1 \\ - W_{out} = (W_{in} - ksize[1] + 2 * paddings[1]) / strides[1] + 1 + H_{out} = \frac{(H_{in} - ksize[0] + 2 * paddings[0])}{strides[0]} + 1 \\ + W_{out} = \frac{(W_{in} - ksize[1] + 2 * paddings[1])}{strides[1]} + 1 $$ )DOC"); @@ -177,7 +177,7 @@ Pool3dOpMaker::Pool3dOpMaker(framework::OpProto *proto, // TypedAttrChecker don't support vector type.) AddAttr>( "paddings", - "(vector, defalut {0,0,0}), paddings(depth, height, " + "(vector, default {0,0,0}), paddings(depth, height, " "width) of pooling operator. " "If global_pooling = true, ksize and paddings will be ignored.") .SetDefault({0, 0, 0}); // TODO(Chengduo): Add checker. (Currently, @@ -199,12 +199,12 @@ Example: X shape: $(N, C, D_{in}, H_{in}, W_{in})$ Output: Out shape: $(N, C, D_{out}, H_{out}, W_{out})$ - where - $$ - D_{out} = (D_{in} - ksize[0] + 2 * paddings[0]) / strides[0] + 1 \\ - H_{out} = (H_{in} - ksize[1] + 2 * paddings[1]) / strides[1] + 1 \\ - W_{out} = (W_{in} - ksize[2] + 2 * paddings[2]) / strides[2] + 1 - $$ + Where + $$ + D_{out} = \frac{(D_{in} - ksize[0] + 2 * paddings[0])}{strides[0]} + 1 \\ + H_{out} = \frac{(H_{in} - ksize[1] + 2 * paddings[1])}{strides[1]} + 1 \\ + W_{out} = \frac{(W_{in} - ksize[2] + 2 * paddings[2])}{strides[2]} + 1 + $$ )DOC"); } diff --git a/paddle/operators/pool_with_index_op.cc b/paddle/operators/pool_with_index_op.cc index 4958fa645405db0798f37165030eae95da371477..b9c42a69128a26ff5942748e11fb87c57d3e3f58 100644 --- a/paddle/operators/pool_with_index_op.cc +++ b/paddle/operators/pool_with_index_op.cc @@ -142,7 +142,7 @@ class MaxPool2dWithIndexOpMaker : public framework::OpProtoAndCheckerMaker { // TypedAttrChecker don't support vector type.) AddAttr>( "paddings", - "(vector, defalut:{0, 0}), paddings(height, width) of pooling " + "(vector, default:{0, 0}), paddings(height, width) of pooling " "operator. " "If global_pooling = true, paddings and will be ignored.") .SetDefault({0, 0}); // TODO(Chengduo): Add checker. (Currently, @@ -166,10 +166,10 @@ Example: Output: Out shape: $(N, C, H_{out}, W_{out})$ Mask shape: $(N, C, H_{out}, W_{out})$ - where + Where $$ - H_{out} = (H_{in} - ksize[0] + 2 * paddings[0]) / strides[0] + 1 \\ - W_{out} = (W_{in} - ksize[1] + 2 * paddings[1]) / strides[1] + 1 + H_{out} = \frac{(H_{in} - ksize[0] + 2 * paddings[0])}{strides[0]} + 1 \\ + W_{out} = \frac{(W_{in} - ksize[1] + 2 * paddings[1])}{strides[1]} + 1 $$ )DOC"); @@ -220,7 +220,7 @@ class MaxPool3dWithIndexOpMaker : public framework::OpProtoAndCheckerMaker { // TypedAttrChecker don't support vector type.) AddAttr>( "paddings", - "(vector, defalut {0,0,0}), paddings(depth, " + "(vector, default {0,0,0}), paddings(depth, " "height, width) of pooling operator. " "If global_pooling = true, paddings and ksize will be ignored.") .SetDefault({0, 0, 0}); // TODO(Chengduo): Add checker. (Currently, @@ -244,11 +244,11 @@ Example: Output: Out shape: $(N, C, D_{out}, H_{out}, W_{out})$ Mask shape: $(N, C, D_{out}, H_{out}, W_{out})$ - where + Where $$ - D_{out} = (D_{in} - ksize[0] + 2 * paddings[0]) / strides[0] + 1 \\ - H_{out} = (H_{in} - ksize[1] + 2 * paddings[1]) / strides[1] + 1 \\ - W_{out} = (W_{in} - ksize[2] + 2 * paddings[2]) / strides[2] + 1 + D_{out} = \frac{(D_{in} - ksize[0] + 2 * paddings[0])}{strides[0]} + 1 \\ + H_{out} = \frac{(H_{in} - ksize[1] + 2 * paddings[1])}{strides[1]} + 1 \\ + W_{out} = \frac{(W_{in} - ksize[2] + 2 * paddings[2])}{strides[2]} + 1 $$ )DOC"); diff --git a/paddle/operators/rank_loss_op.cc b/paddle/operators/rank_loss_op.cc index 061e82412ea5f4f17fd26a7094e68b97138cc09c..912f88f455252effbdb12ecfc45e4afefa60e03e 100644 --- a/paddle/operators/rank_loss_op.cc +++ b/paddle/operators/rank_loss_op.cc @@ -4,7 +4,7 @@ you may not use this file except in compliance with the License. You may obtain a copy of the License at - http://www.apache.org/licenses/LICENSE-2.0 + http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, @@ -35,9 +35,10 @@ class RankLossOp : public framework::OperatorWithKernel { auto right_dims = ctx->GetInputDim("Right"); PADDLE_ENFORCE((label_dims == left_dims) && (left_dims == right_dims), - "All inputs must have the same size"); - PADDLE_ENFORCE((label_dims.size() == 2) && (label_dims[1] == 1), - "All inputs must be row vector with size batch_size x 1."); + "All inputs must have the same size."); + PADDLE_ENFORCE( + (label_dims.size() == 2) && (label_dims[1] == 1), + "All inputs must be 2-D tensors with shape [batch_size x 1]."); ctx->SetOutputDim("Out", label_dims); } }; @@ -48,10 +49,17 @@ class RankLossOpMaker : public framework::OpProtoAndCheckerMaker { framework::OpAttrChecker *op_checker) : OpProtoAndCheckerMaker(proto, op_checker) { AddInput("Label", - "The label indicating A ranked higher than B or not, row vector."); - AddInput("Left", "The output of RankNet for doc A, vector."); - AddInput("Right", "The output of RankNet for doc B, vetor."); - AddOutput("Out", "The output loss of RankLoss operator, vector."); + "(2-D Tensor with shape [batch_size x 1]) " + "The label indicating A ranked higher than B or not."); + AddInput("Left", + "(2-D Tensor with shape [batch_size x 1]) " + "The output of RankNet for doc A."); + AddInput("Right", + "(2-D Tensor with shape [batch_size x 1]) " + "The output of RankNet for doc B."); + AddOutput("Out", + "(2-D Tensor with shape [batch_size x 1]) " + "The output loss of RankLoss operator."); AddComment(R"DOC( RankLoss Operator. @@ -65,16 +73,17 @@ P = {0, 1} or {0, 0.5, 1}, where 0.5 means no information about the rank of the input pair. The RankLoss operator takes three inputs: Left (o_i), Right (o_j) and Label -(P_{i,j}), which represent the output of RankNet for the two docs and the label, -respectively, and yields the rank loss C_{i,j} using the following equation: +(P_{i,j}), which represent the output score of RankNet for the two docs and +the label respectively, and yields the rank loss C_{i,j} using the following +equation: -\f$$ - C_{i,j} = -\tilde{P_{ij}} * o_{i,j} + log(1 + e^{o_{i,j}}) \\ +$$ + C_{i,j} = -\tilde{P_{ij}} * o_{i,j} + \log(1 + e^{o_{i,j}}) \\ o_{i,j} = o_i - o_j \\ \tilde{P_{i,j}} = \left \{0, 0.5, 1 \right \} \ or \ \left \{0, 1 \right \} -\f$$ +$$ -The operator can take inputs of one sample or in batch. +The operator can take batch inputs with size batch_size (batch_size >= 1). )DOC"); } diff --git a/paddle/operators/rank_loss_op.cu b/paddle/operators/rank_loss_op.cu index 779588ff36c792b8925a535d60f1cfbbe3c66d86..5382e3a6296acd257211104d8ec6835c11b90bdd 100644 --- a/paddle/operators/rank_loss_op.cu +++ b/paddle/operators/rank_loss_op.cu @@ -4,7 +4,7 @@ you may not use this file except in compliance with the License. You may obtain a copy of the License at - http://www.apache.org/licenses/LICENSE-2.0 + http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, diff --git a/paddle/operators/rank_loss_op.h b/paddle/operators/rank_loss_op.h index f184d6efcb496a1d7f38540712b6c431f816482e..703c77a0b21f2b2f0b0ae6fae86aae819ea824b5 100644 --- a/paddle/operators/rank_loss_op.h +++ b/paddle/operators/rank_loss_op.h @@ -4,7 +4,7 @@ you may not use this file except in compliance with the License. You may obtain a copy of the License at - http://www.apache.org/licenses/LICENSE-2.0 + http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, diff --git a/paddle/operators/reshape_op.cc b/paddle/operators/reshape_op.cc index ba774ec2160c0460867de42f7ad9d5cd65ad8d6a..39bf2118d603881531bf583ae468e8dc9b8bd181 100644 --- a/paddle/operators/reshape_op.cc +++ b/paddle/operators/reshape_op.cc @@ -1,11 +1,10 @@ - /* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at - http://www.apache.org/licenses/LICENSE-2.0 + http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, @@ -38,8 +37,8 @@ class ReshapeOp : public framework::OperatorWithKernel { // TODO(qiao) change batch_size for (size_t i = 1; i < shape.size(); ++i) { PADDLE_ENFORCE(shape[i] > 0, - "Each dimension of shape " - "must be positiv except the first."); + "Each dimension of Attr(shape) " + "must be positive except the first one."); } if (shape[0] < 0) { shape[0] = x_dims[0]; diff --git a/paddle/operators/reshape_op.cu.cc b/paddle/operators/reshape_op.cu similarity index 94% rename from paddle/operators/reshape_op.cu.cc rename to paddle/operators/reshape_op.cu index 23dbe089d3b37aabedf9ef166f7bbfbf67da7e0a..dca6c15007a64808248443af32141b4a677f95d7 100644 --- a/paddle/operators/reshape_op.cu.cc +++ b/paddle/operators/reshape_op.cu @@ -4,7 +4,7 @@ you may not use this file except in compliance with the License. You may obtain a copy of the License at - http://www.apache.org/licenses/LICENSE-2.0 + http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, diff --git a/paddle/operators/reshape_op.h b/paddle/operators/reshape_op.h index 0e98c8b4f443f88ecba044f2f79228227695e182..73fd1da6428f55976a397b7f6f92bb0c796bfe02 100644 --- a/paddle/operators/reshape_op.h +++ b/paddle/operators/reshape_op.h @@ -4,7 +4,7 @@ you may not use this file except in compliance with the License. You may obtain a copy of the License at - http://www.apache.org/licenses/LICENSE-2.0 + http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, diff --git a/paddle/operators/smooth_l1_loss_op.cc b/paddle/operators/smooth_l1_loss_op.cc index ebf7b43700a7498aa18b5f648b0b8c2c4e7b442b..50543fcc148698c42e15259ba20bdacdd50ac1af 100644 --- a/paddle/operators/smooth_l1_loss_op.cc +++ b/paddle/operators/smooth_l1_loss_op.cc @@ -22,22 +22,20 @@ class SmoothL1LossOp : public framework::OperatorWithKernel { using framework::OperatorWithKernel::OperatorWithKernel; void InferShape(framework::InferShapeContext* ctx) const override { - PADDLE_ENFORCE(ctx->HasInput("X"), "X must be initialized."); - PADDLE_ENFORCE(ctx->HasInput("Y"), "Y must be initialized."); + PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) should not be null."); + PADDLE_ENFORCE(ctx->HasInput("Y"), "Input(Y) should not be null."); auto x_dims = ctx->GetInputDim("X"); auto y_dims = ctx->GetInputDim("Y"); - PADDLE_ENFORCE_EQ(x_dims, y_dims, "The shape of X and Y must be the same."); + PADDLE_ENFORCE_EQ(x_dims, y_dims); PADDLE_ENFORCE_GE(x_dims.size(), 2, - "The tensor rank of X must be at least 2."); + "The tensor rank of Input(X) should not be less than 2."); if (ctx->HasInput("InsideWeight")) { PADDLE_ENFORCE(ctx->HasInput("OutsideWeight"), "If weights are provided, must specify both " "inside and outside weights."); - PADDLE_ENFORCE_EQ(ctx->GetInputDim("InsideWeight"), x_dims, - "The shape of InsideWeight must be same as X."); - PADDLE_ENFORCE_EQ(ctx->GetInputDim("OutsideWeight"), x_dims, - "The shape of OutsideWeight must be same as X."); + PADDLE_ENFORCE_EQ(ctx->GetInputDim("InsideWeight"), x_dims); + PADDLE_ENFORCE_EQ(ctx->GetInputDim("OutsideWeight"), x_dims); } ctx->SetOutputDim("Diff", x_dims); @@ -53,25 +51,29 @@ class SmoothL1LossOpMaker : public framework::OpProtoAndCheckerMaker { framework::OpAttrChecker* op_checker) : OpProtoAndCheckerMaker(proto, op_checker) { AddInput("X", - "The input tensor of smooth l1 loss op." - "The rank should be greater or equal to 2 with shape " - "[batch_size, value_dim1, value_dim2, ..., value_dimN]"); + "(Tensor, default Tensor) A tensor with rank at least 2. " + "The input value of smooth l1 loss op with shape " + "[batch_size, dim1, ..., dimN]."); AddInput("Y", - "The target tensor of smooth l1 loss op " - "with the same shape as X."); + "(Tensor, default Tensor) A tensor with rank at least 2. " + "The target value of smooth l1 loss op with same shape as X."); AddInput("InsideWeight", - "Optional input tensor of smooth l1 loss op with the same shape " - "as X. If provided, the result of (X - Y) will be multiplied " + "(Tensor, default Tensor) A tensor with rank at least 2. " + "This input is optional and should have same shape with X. " + "If provided, the result of (X - Y) will be multiplied " "by this tensor element by element.") .AsDispensable(); AddInput("OutsideWeight", - "Optinal input of smooth l1 loss op with the same shape as X." - "If provided, the output smooth l1 loss will be multiplied by " - "this tensor element by element.") + "(Tensor, default Tensor) A tensor with rank at least 2. " + "This input is optional and should have same shape with X. " + "If provided, the out smooth l1 loss will be multiplied by this " + "tensor element by element.") .AsDispensable(); - AddOutput("Diff", "Intermediate variable to cache InsideWeight*(X-Y).") + AddOutput("Diff", "Intermediate variable to cache InsideWeight * (X - Y).") .AsIntermediate(); - AddOutput("Out", "Smooth l1 loss."); + AddOutput("Out", + "(Tensor, default Tensor) A tensor with rank be 2. " + "The output smooth l1 loss with shape [batch_size, 1]."); AddAttr("sigma", "Hyper parameter of smooth l1 loss op." "A float scalar with default value 3.0.") @@ -79,15 +81,23 @@ class SmoothL1LossOpMaker : public framework::OpProtoAndCheckerMaker { AddComment(R"DOC( Smooth L1 Loss Operator. -This operator computes the smooth l1 loss for input and target. -The operator takes the first dimension of input as the batch size. +This operator computes the smooth l1 loss for X and Y. +The operator takes the first dimension of X and Y as batch size. For each instance, it computes the smooth l1 loss element by element first -and then sums all the losses. So the resulting output shape -is [batch_size, 1]. +and then sums all the losses. So the shape of Out is [batch_size, 1]. The equation is: -loss = $$0.5 * (\sigma * (x-y))^2$$ if $$|x - y| < 1 /({\sigma}^2)$$ - $$\frac{|x - y| - 0.5}{{\sigma}^2}$$ otherwise +$$ +Out_{\sigma}(X, Y)_i = \begin{cases} +0.5 * (\sigma * (X_i - Y_i)) ^ 2 +\quad |X_i - Y_i| \lt \frac{1} {{\sigma} ^ 2} \\ +\frac{|X_i - Y_i| - 0.5}{{\sigma}^2}, +\quad otherwise +\end{cases} +$$ + +In the above equation, $Out_{\sigma}(X, Y)_i$, $X_i$ and $Y_i$ represent the ith +element of Out, X and Y. )DOC"); } diff --git a/paddle/operators/unpool_op.cc b/paddle/operators/unpool_op.cc new file mode 100644 index 0000000000000000000000000000000000000000..89c48e071cf351f7d7b9cf26a5d4989af291da57 --- /dev/null +++ b/paddle/operators/unpool_op.cc @@ -0,0 +1,143 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +Indicesou may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/operators/unpool_op.h" +namespace paddle { +namespace operators { + +class Unpool2dOpMaker : public framework::OpProtoAndCheckerMaker { + public: + Unpool2dOpMaker(framework::OpProto* proto, + framework::OpAttrChecker* op_checker) + : OpProtoAndCheckerMaker(proto, op_checker) { + AddInput( + "X", + "(Tensor) The input tensor of unpool operator. " + "The format of input tensor is NCHW. Where N is batch size, C is the " + "number of channels, H and W is the height and width of feature."); + AddInput( + "Indices", + "(Tensor) The input tensor of the indices given out by MaxPool2d. " + "The format of input tensor is NCHW. Where N is batch size, C is the " + "number of channels, H and W is the height and width of feature."); + AddOutput("Out", + "(Tensor) The output tensor of unpool operator." + "The format of output tensor is also NCHW." + "Where N is batch size, C is " + "the number of channels, H and W is the height and " + "width of feature."); + AddAttr>( + "ksize", + "(vector), the unpooling window size(height, width) " + "of unpooling operator."); + AddAttr>("strides", + "(vector, default:{1, 1}), " + "strides (height, width) of unpooling operator.") + .SetDefault({1, 1}); + AddAttr>("paddings", + "(vector defalut:{0,0}), " + "paddings (height, width) of unpooling operator.") + .SetDefault({0, 0}); + AddAttr( + "unpooling_type", + "(string), unpooling type, can be \"max\" for max-unpooling ") + .InEnum({"max"}); + AddComment(R"DOC( + "Input shape: $(N, C_{in}, H_{in}, W_{in})$ + Output shape: $(N, C_{out}, H_{out}, W_{out})$ + Where + $$ + H_{out} = (H_{in}−1) * strides[0] − 2 * paddings[0] + ksize[0] \\ + W_{out} = (W_{in}−1) * strides[1] − 2 * paddings[1] + ksize[1] + $$ + Paper: http://www.matthewzeiler.com/wp-content/uploads/2017 + /07/iccv2011.pdf + )DOC"); + } +}; + +int OutputSize(int input_size, int ksize, int padding, int stride) { + int output_size = (input_size - 1) * stride - 2 * padding + ksize; + return output_size; +} + +class UnpoolOp : public framework::OperatorWithKernel { + protected: + framework::OpKernelType GetKernelType( + const framework::ExecutionContext& ctx) const override { + return framework::OpKernelType( + framework::ToDataType(ctx.Input("X")->type()), + ctx.device_context()); + } + + public: + using framework::OperatorWithKernel::OperatorWithKernel; + void InferShape(framework::InferShapeContext* ctx) const override { + PADDLE_ENFORCE(ctx->HasInput("X"), + "Input(X) of UnpoolOp" + "should not be null."); + PADDLE_ENFORCE(ctx->HasInput("Indices"), + "Input(Indices) of UnpoolOp" + "should not be null."); + PADDLE_ENFORCE(ctx->HasOutput("Out"), + "Output(Out) of UnpoolOp should not be null."); + auto in_x_dims = ctx->GetInputDim("X"); + auto in_y_dims = ctx->GetInputDim("Indices"); + std::string unpooling_type = + ctx->Attrs().Get("unpooling_type"); + std::vector ksize = ctx->Attrs().Get>("ksize"); + std::vector strides = ctx->Attrs().Get>("strides"); + std::vector paddings = ctx->Attrs().Get>("paddings"); + PADDLE_ENFORCE(in_x_dims.size() == 4, + "Unpooling intput must be of 4-dimensional."); + PADDLE_ENFORCE_EQ(in_x_dims, in_y_dims); + std::vector output_shape({in_x_dims[0], in_x_dims[1]}); + for (size_t i = 0; i < ksize.size(); ++i) { + output_shape.push_back( + OutputSize(in_x_dims[i + 2], ksize[i], paddings[i], strides[i])); + } + ctx->SetOutputDim("Out", framework::make_ddim(output_shape)); + } +}; + +class UnpoolOpGrad : public framework::OperatorWithKernel { + protected: + framework::OpKernelType GetKernelType( + const framework::ExecutionContext& ctx) const override { + return framework::OpKernelType( + framework::ToDataType(ctx.Input("X")->type()), + ctx.device_context()); + } + + public: + using framework::OperatorWithKernel::OperatorWithKernel; + void InferShape(framework::InferShapeContext* ctx) const override { + PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) must not be null."); + PADDLE_ENFORCE(ctx->HasOutput(framework::GradVarName("X")), + "Input(X@GRAD) should not be null."); + ctx->SetOutputDim(framework::GradVarName("X"), ctx->GetInputDim("X")); + } +}; +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; +REGISTER_OP(unpool, ops::UnpoolOp, ops::Unpool2dOpMaker, unpool_grad, + ops::UnpoolOpGrad); +REGISTER_OP_CPU_KERNEL(unpool, + ops::UnpoolKernel, + ops::UnpoolKernel); +REGISTER_OP_CPU_KERNEL( + unpool_grad, ops::UnpoolGradKernel, + ops::UnpoolGradKernel); diff --git a/paddle/operators/unpool_op.cu.cc b/paddle/operators/unpool_op.cu.cc new file mode 100644 index 0000000000000000000000000000000000000000..18aafb7dc74ed474ed3ec5e8a388ecdb71b9a8f5 --- /dev/null +++ b/paddle/operators/unpool_op.cu.cc @@ -0,0 +1,23 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +Indicesou may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/operators/unpool_op.h" + +namespace ops = paddle::operators; +REGISTER_OP_GPU_KERNEL(unpool, + ops::UnpoolKernel, + ops::UnpoolKernel); +REGISTER_OP_GPU_KERNEL( + unpool_grad, ops::UnpoolGradKernel, + ops::UnpoolGradKernel); diff --git a/paddle/operators/unpool_op.h b/paddle/operators/unpool_op.h new file mode 100644 index 0000000000000000000000000000000000000000..243eb7e532c5149db4fb1b381fd8664ae4bdd81a --- /dev/null +++ b/paddle/operators/unpool_op.h @@ -0,0 +1,71 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +Indicesou may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once + +#include "paddle/framework/op_registry.h" +#include "paddle/operators/math/math_function.h" +#include "paddle/operators/math/unpooling.h" + +namespace paddle { +namespace operators { +template +class UnpoolKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& context) const override { + const framework::Tensor* in_x = context.Input("X"); + const framework::Tensor* in_y = context.Input("Indices"); + auto* out = context.Output("Out"); + std::string unpooling_type = context.Attr("unpooling_type"); + std::vector ksize = context.Attr>("ksize"); + std::vector strides = context.Attr>("strides"); + std::vector paddings = context.Attr>("paddings"); + T* output_data = out->mutable_data(context.GetPlace()); + if (output_data) { + math::SetConstant set_zero; + set_zero(context.device_context(), out, static_cast(0)); + } + math::Unpool2dMaxFunctor unpool2d_max_forward; + unpool2d_max_forward(context.device_context(), *in_x, *in_y, out); + } +}; +template +class UnpoolGradKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& context) const override { + const framework::Tensor* in_x = context.Input("X"); + const framework::Tensor* in_y = context.Input("Indices"); + const framework::Tensor* out = context.Input("Out"); + const framework::Tensor* out_grad = + context.Input(framework::GradVarName("Out")); + framework::Tensor* in_x_grad = + context.Output(framework::GradVarName("X")); + std::string unpooling_type = context.Attr("unpooling_type"); + std::vector ksize = context.Attr>("ksize"); + std::vector strides = context.Attr>("strides"); + std::vector paddings = context.Attr>("paddings"); + + auto& device_ctx = context.device_context(); + math::SetConstant zero; + if (in_x_grad) { + in_x_grad->mutable_data(context.GetPlace()); + zero(device_ctx, in_x_grad, static_cast(0)); + } + math::Unpool2dMaxGradFunctor unpool2d_max_backward; + unpool2d_max_backward(context.device_context(), *in_x, *in_y, *out, + *out_grad, in_x_grad); + } +}; +} // namespace operators +} // namespace paddle diff --git a/paddle/optimizer/parameter_optimizer_test.cc b/paddle/optimizer/parameter_optimizer_test.cc index f29e5317120642e3790a6f6c1976bdda67093a0c..83757a391784453341f22eca73bc73c14ce4174f 100644 --- a/paddle/optimizer/parameter_optimizer_test.cc +++ b/paddle/optimizer/parameter_optimizer_test.cc @@ -127,8 +127,3 @@ TEST_F(OptimizerTest, TestGetWeight) { TestGetWeight(); } TEST_F(OptimizerTest, TestUpdate) { TestUpdate(); } TEST_F(OptimizerTest, TestCheckPoint) { TestCheckPoint(); } - -int main(int argc, char** argv) { - testing::InitGoogleTest(&argc, argv); - return RUN_ALL_TESTS(); -} diff --git a/paddle/optimizer/serialization_test.cc b/paddle/optimizer/serialization_test.cc index 4c416f55ee0bd70f9ec6e288b08a5399d8b2bf39..940e941e9042d8a37363311867df5bb477b3dac0 100644 --- a/paddle/optimizer/serialization_test.cc +++ b/paddle/optimizer/serialization_test.cc @@ -46,8 +46,3 @@ TEST(TensorToProto, Case2) { EXPECT_EQ(t1[i], t[i]); } } - -int main(int argc, char** argv) { - testing::InitGoogleTest(&argc, argv); - return RUN_ALL_TESTS(); -} diff --git a/paddle/scripts/docker/build.sh b/paddle/scripts/docker/build.sh index a2fdc5ce69bfdf0fadb808e4b51c8eef4ff7dfd6..502637c881208e53dd832a9759b3873ef1988395 100644 --- a/paddle/scripts/docker/build.sh +++ b/paddle/scripts/docker/build.sh @@ -183,6 +183,7 @@ EOF ${DOCKERFILE_GPU_ENV} ADD go/cmd/pserver/pserver /usr/bin/ ADD go/cmd/master/master /usr/bin/ + ADD paddle/pybind/print_operators_doc /usr/bin/ # default command shows the paddle version and exit CMD ["paddle", "version"] EOF diff --git a/paddle/testing/CMakeLists.txt b/paddle/testing/CMakeLists.txt index 4245df5ab72bf0fd67261818b307f0babdb5d685..2275c950ba13d131fa17fef1845a50a22fd55a2f 100644 --- a/paddle/testing/CMakeLists.txt +++ b/paddle/testing/CMakeLists.txt @@ -5,4 +5,6 @@ if(WITH_TESTING) add_dependencies(paddle_test_main paddle_proto ${external_project_dependencies}) add_library(paddle_test_util STATIC TestUtil.cpp) add_dependencies(paddle_test_util paddle_proto ${external_project_dependencies}) + add_library(paddle_gtest_main STATIC paddle_gtest_main.cc) + add_dependencies(paddle_gtest_main paddle_memory gtest gflags) endif() diff --git a/paddle/testing/paddle_gtest_main.cc b/paddle/testing/paddle_gtest_main.cc new file mode 100644 index 0000000000000000000000000000000000000000..a491322b7e533f7a9c263a249494440269391003 --- /dev/null +++ b/paddle/testing/paddle_gtest_main.cc @@ -0,0 +1,39 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include +#include "gflags/gflags.h" +#include "gtest/gtest.h" +#include "paddle/memory/memory.h" + +int main(int argc, char** argv) { + std::vector new_argv; + std::string gflags_env; + new_argv.push_back(argv[0]); +#ifdef PADDLE_WITH_CUDA + new_argv.push_back( + strdup("--tryfromenv=fraction_of_gpu_memory_to_use,use_pinned_memory")); +#else + new_argv.push_back(strdup("--tryfromenv=use_pinned_memory")); +#endif + int new_argc = static_cast(new_argv.size()); + char** new_argv_address = new_argv.data(); + google::ParseCommandLineFlags(&new_argc, &new_argv_address, false); + testing::InitGoogleTest(&argc, argv); + paddle::memory::Used(paddle::platform::CPUPlace()); +#ifdef PADDLE_WITH_CUDA + paddle::memory::Used(paddle::platform::GPUPlace(0)); +#endif + return RUN_ALL_TESTS(); +} diff --git a/python/paddle/v2/fluid/__init__.py b/python/paddle/v2/fluid/__init__.py index c033b27beab52a979c78caeba68990c95b462c56..dd25bc19ec5f4fd6eb3e04f304b1de488e988f41 100644 --- a/python/paddle/v2/fluid/__init__.py +++ b/python/paddle/v2/fluid/__init__.py @@ -36,7 +36,8 @@ def __read_gflags_from_env__(): read_env_flags = ['use_pinned_memory'] if core.is_compile_gpu(): read_env_flags.append('fraction_of_gpu_memory_to_use') - core.init_gflags(sys.argv + ["--tryfromenv=" + ",".join(read_env_flags)]) + core.init_gflags([sys.argv[0]] + + ["--tryfromenv=" + ",".join(read_env_flags)]) __read_gflags_from_env__() diff --git a/python/paddle/v2/fluid/framework.py b/python/paddle/v2/fluid/framework.py index 1c42e4d44f5046e0db171fdaeb8e7af38a2cae07..49c6d8983457fa9c29451b8d020dd0c581481f9c 100644 --- a/python/paddle/v2/fluid/framework.py +++ b/python/paddle/v2/fluid/framework.py @@ -3,10 +3,12 @@ import collections import numpy as np from . import core import proto.framework_pb2 as framework_pb2 +import contextlib __all__ = [ 'Block', 'Variable', 'Program', 'Operator', 'default_startup_program', - 'default_main_program' + 'default_main_program', 'program_guard', 'switch_startup_program', + 'switch_main_program' ] @@ -659,8 +661,83 @@ _startup_program_ = Program() def default_startup_program(): + """ + Get default startup program. In startup program, Paddle will initialize + parameters, initialize nccl handle, etc. + + Returns: + Program: startup program + """ return _startup_program_ def default_main_program(): + """ + Get default main program. The main program is used for training or testing. + + Returns: + Program: main program + """ return _main_program_ + + +def switch_main_program(program): + """ + Switch the main program to a new program. + + Args: + program(Program): The new main program + + Returns: + Program: The previous main program + """ + global _main_program_ + prev_program = _main_program_ + _main_program_ = program + return prev_program + + +def switch_startup_program(program): + """ + Switch the startup program to a new program + Args: + program(Program): The new startup program + + Returns: + Program: The previous startup program + """ + global _startup_program_ + prev_program = _startup_program_ + _startup_program_ = program + return prev_program + + +@contextlib.contextmanager +def program_guard(main_program, startup_program=None): + """ + Switch program with `with` statement + + Examples: + >>> with program_guard(Program()): + >>> data = fluid.layers.data(...) + >>> hidden = fluid.layers.fc(...) + + Args: + main_program(Program): New main program inside `with` statement + startup_program(Program): New startup program inside `with` statement. + None means do not change startup program. + + Returns: + None + """ + if not isinstance(main_program, Program): + raise TypeError("main_program should be Program") + main_program = switch_main_program(main_program) + if startup_program is not None: + if not isinstance(startup_program, Program): + raise TypeError("startup_program should be Program") + startup_program = switch_startup_program(startup_program) + yield + switch_main_program(main_program) + if startup_program is not None: + switch_startup_program(startup_program) diff --git a/python/paddle/v2/fluid/tests/test_layers.py b/python/paddle/v2/fluid/tests/test_layers.py index b6906be60b8ffb7c7afc220ad4f40c6f60a0b112..33b0e54f42afc82beaa24e334023f30a4035f039 100644 --- a/python/paddle/v2/fluid/tests/test_layers.py +++ b/python/paddle/v2/fluid/tests/test_layers.py @@ -1,192 +1,141 @@ +from __future__ import print_function import unittest import paddle.v2.fluid.layers as layers import paddle.v2.fluid.nets as nets -from paddle.v2.fluid.framework import Program +from paddle.v2.fluid.framework import Program, program_guard class TestBook(unittest.TestCase): def test_fit_a_line(self): program = Program() - x = layers.data( - name='x', shape=[13], dtype='float32', main_program=program) - y_predict = layers.fc(input=x, size=1, act=None, main_program=program) + with program_guard(program, startup_program=Program()): + x = layers.data(name='x', shape=[13], dtype='float32') + y_predict = layers.fc(input=x, size=1, act=None) + y = layers.data(name='y', shape=[1], dtype='float32') + cost = layers.square_error_cost(input=y_predict, label=y) + avg_cost = layers.mean(x=cost) + self.assertIsNotNone(avg_cost) + program.append_backward(avg_cost) - y = layers.data( - name='y', shape=[1], dtype='float32', main_program=program) - cost = layers.square_error_cost( - input=y_predict, label=y, main_program=program) - - avg_cost = layers.mean(x=cost, main_program=program) - self.assertIsNotNone(avg_cost) - program.append_backward(avg_cost) - - print str(program) + print(str(program)) def test_recognize_digits_mlp(self): program = Program() - - # Change g_program, so the rest layers use `g_program` - images = layers.data( - name='pixel', shape=[784], dtype='float32', main_program=program) - label = layers.data( - name='label', shape=[1], dtype='int32', main_program=program) - hidden1 = layers.fc(input=images, - size=128, - act='relu', - main_program=program) - hidden2 = layers.fc(input=hidden1, - size=64, - act='relu', - main_program=program) - predict = layers.fc(input=hidden2, - size=10, - act='softmax', - main_program=program) - cost = layers.cross_entropy( - input=predict, label=label, main_program=program) - avg_cost = layers.mean(x=cost, main_program=program) - self.assertIsNotNone(avg_cost) - - print str(program) + with program_guard(program, startup_program=Program()): + # Change g_program, so the rest layers use `g_program` + images = layers.data(name='pixel', shape=[784], dtype='float32') + label = layers.data(name='label', shape=[1], dtype='int32') + hidden1 = layers.fc(input=images, size=128, act='relu') + hidden2 = layers.fc(input=hidden1, size=64, act='relu') + predict = layers.fc(input=hidden2, size=10, act='softmax') + cost = layers.cross_entropy(input=predict, label=label) + avg_cost = layers.mean(x=cost) + self.assertIsNotNone(avg_cost) + + print(str(program)) def test_simple_conv2d(self): program = Program() - images = layers.data( - name='pixel', - shape=[3, 48, 48], - dtype='int32', - main_program=program) - layers.conv2d( - input=images, - num_filters=3, - filter_size=[4, 4], - main_program=program) - - print str(program) + with program_guard(program, startup_program=Program()): + images = layers.data(name='pixel', shape=[3, 48, 48], dtype='int32') + layers.conv2d(input=images, num_filters=3, filter_size=[4, 4]) + + print(str(program)) def test_conv2d_transpose(self): program = Program() - kwargs = {'main_program': program} - img = layers.data( - name='pixel', shape=[3, 2, 2], dtype='float32', **kwargs) - layers.conv2d_transpose( - input=img, num_filters=10, output_size=28, **kwargs) - print str(program) + with program_guard(program): + img = layers.data(name='pixel', shape=[3, 2, 2], dtype='float32') + layers.conv2d_transpose(input=img, num_filters=10, output_size=28) + print(str(program)) def test_recognize_digits_conv(self): program = Program() - - images = layers.data( - name='pixel', - shape=[1, 28, 28], - dtype='float32', - main_program=program) - label = layers.data( - name='label', shape=[1], dtype='int32', main_program=program) - conv_pool_1 = nets.simple_img_conv_pool( - input=images, - filter_size=5, - num_filters=2, - pool_size=2, - pool_stride=2, - act="relu", - main_program=program) - conv_pool_2 = nets.simple_img_conv_pool( - input=conv_pool_1, - filter_size=5, - num_filters=4, - pool_size=2, - pool_stride=2, - act="relu", - main_program=program) - - predict = layers.fc(input=conv_pool_2, - size=10, - act="softmax", - main_program=program) - cost = layers.cross_entropy( - input=predict, label=label, main_program=program) - avg_cost = layers.mean(x=cost, main_program=program) - - program.append_backward(avg_cost) - - print str(program) + with program_guard(program, startup_program=Program()): + images = layers.data( + name='pixel', shape=[1, 28, 28], dtype='float32') + label = layers.data(name='label', shape=[1], dtype='int32') + conv_pool_1 = nets.simple_img_conv_pool( + input=images, + filter_size=5, + num_filters=2, + pool_size=2, + pool_stride=2, + act="relu") + conv_pool_2 = nets.simple_img_conv_pool( + input=conv_pool_1, + filter_size=5, + num_filters=4, + pool_size=2, + pool_stride=2, + act="relu") + + predict = layers.fc(input=conv_pool_2, size=10, act="softmax") + cost = layers.cross_entropy(input=predict, label=label) + avg_cost = layers.mean(x=cost) + + program.append_backward(avg_cost) + + print(str(program)) def test_word_embedding(self): program = Program() - dict_size = 10000 - embed_size = 32 - first_word = layers.data( - name='firstw', shape=[1], dtype='int64', main_program=program) - second_word = layers.data( - name='secondw', shape=[1], dtype='int64', main_program=program) - third_word = layers.data( - name='thirdw', shape=[1], dtype='int64', main_program=program) - forth_word = layers.data( - name='forthw', shape=[1], dtype='int64', main_program=program) - next_word = layers.data( - name='nextw', shape=[1], dtype='int64', main_program=program) - - embed_first = layers.embedding( - input=first_word, - size=[dict_size, embed_size], - dtype='float32', - param_attr='shared_w', - main_program=program) - embed_second = layers.embedding( - input=second_word, - size=[dict_size, embed_size], - dtype='float32', - param_attr='shared_w', - main_program=program) - - embed_third = layers.embedding( - input=third_word, - size=[dict_size, embed_size], - dtype='float32', - param_attr='shared_w', - main_program=program) - embed_forth = layers.embedding( - input=forth_word, - size=[dict_size, embed_size], - dtype='float32', - param_attr='shared_w', - main_program=program) - - concat_embed = layers.concat( - input=[embed_first, embed_second, embed_third, embed_forth], - axis=1, - main_program=program) - - hidden1 = layers.fc(input=concat_embed, - size=256, - act='sigmoid', - main_program=program) - predict_word = layers.fc(input=hidden1, - size=dict_size, - act='softmax', - main_program=program) - cost = layers.cross_entropy( - input=predict_word, label=next_word, main_program=program) - avg_cost = layers.mean(x=cost, main_program=program) - self.assertIsNotNone(avg_cost) - - print str(program) + with program_guard(program, startup_program=Program()): + dict_size = 10000 + embed_size = 32 + first_word = layers.data(name='firstw', shape=[1], dtype='int64') + second_word = layers.data(name='secondw', shape=[1], dtype='int64') + third_word = layers.data(name='thirdw', shape=[1], dtype='int64') + forth_word = layers.data(name='forthw', shape=[1], dtype='int64') + next_word = layers.data(name='nextw', shape=[1], dtype='int64') + + embed_first = layers.embedding( + input=first_word, + size=[dict_size, embed_size], + dtype='float32', + param_attr='shared_w') + embed_second = layers.embedding( + input=second_word, + size=[dict_size, embed_size], + dtype='float32', + param_attr='shared_w') + + embed_third = layers.embedding( + input=third_word, + size=[dict_size, embed_size], + dtype='float32', + param_attr='shared_w') + embed_forth = layers.embedding( + input=forth_word, + size=[dict_size, embed_size], + dtype='float32', + param_attr='shared_w') + + concat_embed = layers.concat( + input=[embed_first, embed_second, embed_third, embed_forth], + axis=1) + + hidden1 = layers.fc(input=concat_embed, size=256, act='sigmoid') + predict_word = layers.fc(input=hidden1, + size=dict_size, + act='softmax') + cost = layers.cross_entropy(input=predict_word, label=next_word) + avg_cost = layers.mean(x=cost) + self.assertIsNotNone(avg_cost) + + print(str(program)) def test_linear_chain_crf(self): program = Program() - - # Change g_program, so the rest layers use `g_program` - images = layers.data( - name='pixel', shape=[784], dtype='float32', main_program=program) - label = layers.data( - name='label', shape=[1], dtype='int32', main_program=program) - hidden = layers.fc(input=images, size=128, main_program=program) - crf = layers.linear_chain_crf( - input=hidden, label=label, main_program=program) - - print str(program) + with program_guard(program, startup_program=Program()): + images = layers.data(name='pixel', shape=[784], dtype='float32') + label = layers.data(name='label', shape=[1], dtype='int32') + hidden = layers.fc(input=images, size=128) + crf = layers.linear_chain_crf(input=hidden, label=label) + self.assertNotEqual(crf, None) + + print(str(program)) if __name__ == '__main__': diff --git a/python/paddle/v2/fluid/tests/test_unpool_op.py b/python/paddle/v2/fluid/tests/test_unpool_op.py new file mode 100644 index 0000000000000000000000000000000000000000..e87f283042c081ed9f232d140ff8c303cd3d1858 --- /dev/null +++ b/python/paddle/v2/fluid/tests/test_unpool_op.py @@ -0,0 +1,83 @@ +import unittest +import numpy as np +from op_test import OpTest + + +def unpool2dmax_forward_naive(input, indices, ksize, strides, paddings): + s0, s1, s2, s3 = input.shape + out_hsize = (s2 - 1) * strides[0] - 2 * paddings[0] + ksize[0] + out_wsize = (s2 - 1) * strides[1] - 2 * paddings[1] + ksize[1] + out = np.zeros((s0, s1, out_hsize, out_wsize)) + for nidx in xrange(s0): + for cidx in xrange(s1): + for h in xrange(s2): + for w in xrange(s3): + index = indices[nidx, cidx, h, w] + hidx = (index - index % out_wsize) / out_wsize + widx = index % out_wsize + out[nidx, cidx, int(hidx), int(widx)] = \ + input[nidx, cidx, h, w] + + return out + + +class TestUnpoolOp(OpTest): + def setUp(self): + self.op_type = "unpool" + self.init_test_case() + pre_input = np.random.random(self.shape).astype("float32") + nsize, csize, hsize, wsize = pre_input.shape + hsize_out = (hsize - self.ksize[0] + 2 * self.paddings[0]) / \ + self.strides[0] + 1 + wsize_out = (wsize - self.ksize[1] + 2 * self.paddings[1]) / \ + self.strides[1] + 1 + input = np.zeros((nsize, csize, hsize_out, wsize_out)) + indices = np.zeros((nsize, csize, hsize_out, wsize_out)) + for i in xrange(hsize_out): + for j in xrange(wsize_out): + r_start = np.max((i * self.strides[0] - self.paddings[0], 0)) + r_end = np.min((i * self.strides[0] + self.ksize[0] - \ + self.paddings[0], hsize)) + c_start = np.max((j * self.strides[1] - self.paddings[1], 0)) + c_end = np.min((j * self.strides[1] + self.ksize[1] - \ + self.paddings[1], wsize)) + for nidx in xrange(nsize): + for cidx in xrange(csize): + x_masked = pre_input[nidx, cidx, r_start:r_end, \ + c_start:c_end] + input[nidx, cidx, i, j] = x_masked.max() + arg = x_masked.argmax() + indices[nidx, cidx, i, j] = \ + (r_start + arg / self.ksize[1]) * wsize + \ + c_start + arg % self.ksize[1] + output = self.unpool2d_forward_naive(input, indices, self.ksize, \ + self.strides, self.paddings).astype("float32") + self.inputs = { + 'X': input.astype('float32'), + 'Indices': indices.astype('int32') + } + self.attrs = { + 'strides': self.strides, + 'paddings': self.paddings, + 'ksize': self.ksize, + 'unpooling_type': self.unpooling_type, + } + self.outputs = {'Out': output.astype('float32')} + + def test_check_output(self): + self.check_output() + + def test_check_grad(self): + self.check_grad(['X'], 'Out') + + def init_test_case(self): + self.unpool2d_forward_naive = unpool2dmax_forward_naive + self.unpooling_type = "max" + self.shape = [6, 4, 5, 5] + self.ksize = [3, 3] + self.strides = [2, 2] + self.paddings = [0, 0] + + +if __name__ == '__main__': + unittest.main()