diff --git a/CMakeLists.txt b/CMakeLists.txt
index de47086dbd6a440cd413c7843c83b1c69d9841b2..23bbe829ac16180088bfa37df66e23f19b021ea3 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -39,7 +39,6 @@ option(WITH_GPU "Compile PaddlePaddle with NVIDIA GPU" ${CUDA_F
option(WITH_AMD_GPU "Compile PaddlePaddle with AMD GPU" OFF)
option(WITH_AVX "Compile PaddlePaddle with AVX intrinsics" ${AVX_FOUND})
option(WITH_MKL "Compile PaddlePaddle with MKL support." ${AVX_FOUND})
-option(WITH_TENSORRT "Compile PaddlePaddle with TensorRT support." OFF)
option(WITH_DSO "Compile PaddlePaddle with dynamic linked CUDA" ON)
option(WITH_TESTING "Compile PaddlePaddle with unit testing" OFF)
option(WITH_SWIG_PY "Compile PaddlePaddle with inference api" ON)
@@ -180,13 +179,9 @@ set(EXTERNAL_LIBS
if(WITH_GPU)
include(cuda)
+ include(tensorrt)
endif(WITH_GPU)
-# TensorRT depends on GPU.
-if (NOT WITH_GPU)
- set(WITH_TENSORRT OFF)
-endif()
-
if(WITH_AMD_GPU)
find_package(HIP)
include(hip)
diff --git a/Dockerfile b/Dockerfile
index 9097bb657d2366997112ec7662762a93358aa647..870304a6acc99e715dffbfabd8058be000b6872c 100644
--- a/Dockerfile
+++ b/Dockerfile
@@ -46,7 +46,7 @@ ENV PATH=${PATH}:${GOROOT}/bin:${GOPATH}/bin
RUN curl -s -q https://glide.sh/get | sh
# Install TensorRT
-# The unnecessary files has been removed to make the library small.
+# The unnecessary files have been removed to keep the library small. It now contains only include and lib.
RUN wget -qO- http://paddlepaddledeps.bj.bcebos.com/TensorRT-4.0.0.3.Ubuntu-16.04.4.x86_64-gnu.cuda-8.0.cudnn7.0.tar.gz | \
tar -xz -C /usr/local && \
cp -rf /usr/local/TensorRT/include /usr && \
diff --git a/Dockerfile.android b/Dockerfile.android
index cc022d596b4b74dd1e4f4d0901dd81c91a7decd1..848a7eba6f1421432addae8acff407b611adb4ae 100644
--- a/Dockerfile.android
+++ b/Dockerfile.android
@@ -27,7 +27,7 @@ RUN git config --global credential.helper store
# Fix locales to en_US.UTF-8
RUN localedef -i en_US -f UTF-8 en_US.UTF-8
-RUN pip install --upgrade pip && \
+RUN pip install --upgrade pip==9.0.3 && \
pip install -U 'protobuf==3.1.0' && \
pip install -U wheel sphinx && \
pip install pre-commit
diff --git a/paddle/scripts/check_env.sh b/benchmark/paddle/image/check_env.sh
similarity index 100%
rename from paddle/scripts/check_env.sh
rename to benchmark/paddle/image/check_env.sh
diff --git a/cmake/configure.cmake b/cmake/configure.cmake
index f726405c4773994f6ca6509e5218750805b03995..e490397cc0624c310949a4b571bd00cac6e8953b 100644
--- a/cmake/configure.cmake
+++ b/cmake/configure.cmake
@@ -80,6 +80,16 @@ if(WITH_GPU)
# Include cuda and cudnn
include_directories(${CUDNN_INCLUDE_DIR})
include_directories(${CUDA_TOOLKIT_INCLUDE})
+
+ if(TENSORRT_FOUND)
+ if(${CUDA_VERSION_MAJOR} VERSION_LESS 8)
+ message(FATAL_ERROR "TensorRT needs CUDA >= 8.0 to compile")
+ endif()
+ if(${CUDNN_MAJOR_VERSION} VERSION_LESS 7)
+ message(FATAL_ERROR "TensorRT needs CUDNN >= 7.0 to compile")
+ endif()
+ include_directories(${TENSORRT_INCLUDE_DIR})
+ endif()
elseif(WITH_AMD_GPU)
add_definitions(-DPADDLE_WITH_HIP)
set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -D__HIP_PLATFORM_HCC__")
diff --git a/cmake/tensorrt.cmake b/cmake/tensorrt.cmake
new file mode 100644
index 0000000000000000000000000000000000000000..0c07d36bed65400164853b99f18ec0335341cd94
--- /dev/null
+++ b/cmake/tensorrt.cmake
@@ -0,0 +1,33 @@
+if(NOT WITH_GPU)
+ return()
+endif()
+
+set(TENSORRT_ROOT "/usr" CACHE PATH "TENSORRT ROOT")
+find_path(TENSORRT_INCLUDE_DIR NvInfer.h
+ PATHS ${TENSORRT_ROOT} ${TENSORRT_ROOT}/include
+ $ENV{TENSORRT_ROOT} $ENV{TENSORRT_ROOT}/include
+ NO_DEFAULT_PATH
+)
+
+find_library(TENSORRT_LIBRARY NAMES libnvinfer.so libnvinfer.a
+ PATHS ${TENSORRT_ROOT} ${TENSORRT_ROOT}/lib
+ $ENV{TENSORRT_ROOT} $ENV{TENSORRT_ROOT}/lib
+ NO_DEFAULT_PATH
+ DOC "Path to TensorRT library.")
+
+if(TENSORRT_INCLUDE_DIR AND TENSORRT_LIBRARY)
+ set(TENSORRT_FOUND ON)
+else()
+ set(TENSORRT_FOUND OFF)
+endif()
+
+if(TENSORRT_FOUND)
+ file(READ ${TENSORRT_INCLUDE_DIR}/NvInfer.h TENSORRT_VERSION_FILE_CONTENTS)
+ string(REGEX MATCH "define NV_TENSORRT_MAJOR +([0-9]+)" TENSORRT_MAJOR_VERSION
+ "${TENSORRT_VERSION_FILE_CONTENTS}")
+ string(REGEX REPLACE "define NV_TENSORRT_MAJOR +([0-9]+)" "\\1"
+ TENSORRT_MAJOR_VERSION "${TENSORRT_MAJOR_VERSION}")
+
+ message(STATUS "Current TensorRT header is ${TENSORRT_INCLUDE_DIR}/NvInfer.h. "
+ "Current TensorRT version is v${TENSORRT_MAJOR_VERSION}. ")
+endif()
diff --git a/doc/fluid/api/initializer.rst b/doc/fluid/api/initializer.rst
index ee69925fda6b3fc850cfb632e8edd359e7fcff9c..f186c9c85a640da49d95a1a62c721b09b3007d83 100644
--- a/doc/fluid/api/initializer.rst
+++ b/doc/fluid/api/initializer.rst
@@ -33,3 +33,45 @@ Xavier
:members:
:noindex:
+MSRA
+------
+
+.. autoclass:: paddle.fluid.initializer.MSRA
+ :members:
+ :noindex:
+
+ConstantInitializer
+-------------------
+
+.. autoclass:: paddle.fluid.initializer.ConstantInitializer
+ :members:
+ :noindex:
+
+UniformInitializer
+------------------
+
+.. autoclass:: paddle.fluid.initializer.UniformInitializer
+ :members:
+ :noindex:
+
+NormalInitializer
+-----------------
+
+.. autoclass:: paddle.fluid.initializer.NormalInitializer
+ :members:
+ :noindex:
+
+XavierInitializer
+-----------------
+
+.. autoclass:: paddle.fluid.initializer.XavierInitializer
+ :members:
+ :noindex:
+
+MSRAInitializer
+---------------
+
+.. autoclass:: paddle.fluid.initializer.MSRAInitializer
+ :members:
+ :noindex:
diff --git a/doc/fluid/api/layers.rst b/doc/fluid/api/layers.rst
index 5c02886efd7d11e9520910526fb90ec01e123bae..3790f09c84563fe541bd8d0bc08e23b19d4287ca 100644
--- a/doc/fluid/api/layers.rst
+++ b/doc/fluid/api/layers.rst
@@ -815,3 +815,8 @@ zeros
.. autofunction:: paddle.fluid.layers.zeros
:noindex:
+topk
+----
+
+.. autofunction:: paddle.fluid.layers.topk
+ :noindex:
diff --git a/doc/fluid/design/concepts/parallel_executor.md b/doc/fluid/design/concepts/parallel_executor.md
index 9aed3b059a1595ba3971d7d5acfc0d16a731584b..4f88e27bed722e9f2f535e368926fe49b4e72e56 100644
--- a/doc/fluid/design/concepts/parallel_executor.md
+++ b/doc/fluid/design/concepts/parallel_executor.md
@@ -84,7 +84,7 @@ Running an operator can be asynchronized. There is a thread pool to execute an `
## Synchronize GPU Kernels
-The GPU is a non-blocking device. The different streams need be synchronized when switing streams. In current implementation, the synchronization based on the following algorithm:
+The GPU is a non-blocking device. Different streams need to be synchronized when switching streams. In the current implementation, the synchronization is based on the following algorithm:
1. `OpHandle` will record `DeviceContext` that it is used.
2. In `OpHandle::Run`, if the `DeviceContext` of current operator is different from `DeviceContext` of any input variable, just wait the generate operator of this input variable.
diff --git a/doc/fluid/design/dist_train/README.md b/doc/fluid/design/dist_train/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..2dd652d8bdcb8f3b6e759347bd55b217be909386
--- /dev/null
+++ b/doc/fluid/design/dist_train/README.md
@@ -0,0 +1,57 @@
+## Distributed training overview doc
+
+Currently, Paddle Fluid uses the parameter server architecture to support distributed training.
+
+For synchronous and asynchronous training, the differences are mostly in the logic of the parameter server. Synchronous training is already supported.
+
+### Synchronous training
+
+The training process of synchronous training is:
+
+![synchronous distributed training](./src/sync_distributed_training.png)
+
+1. Pserver
+   1. Set `barrier_condition_` to 0 and wait for trainers to send gradients.
+1. Trainer
+   1. The trainer reads a minibatch of data, runs forward-backward with its local parameter copy, and gets the gradients for the parameters.
+   1. The trainer uses the split op to split all the gradients into blocks. The split method is determined at compile time.
+   1. The trainer uses send_op to send all the split gradients to the corresponding parameter servers.
+   1. After the trainer has sent all the gradients, it sends a `BATCH_BARRIER_MESSAGE` to all pservers.
+   1. The trainer calls GetVariable on the pserver and waits for `barrier_condition_` on the pserver to become 1.
+1. Pserver
+   1. The pserver counts the received `BATCH_BARRIER_MESSAGE`s.
+   1. When the count of `BATCH_BARRIER_MESSAGE`s equals the number of trainers, the pserver knows it has received all gradients from all trainers.
+   1. The pserver runs the optimization block to optimize the parameters.
+   1. After optimization, the pserver sets `barrier_condition_` to 1.
+   1. The pserver waits for `FETCH_BARRIER_MESSAGE`.
+1. Trainer
+   1. The trainer uses GetVariable to get all the parameters from the pserver.
+   1. The trainer sends a `FETCH_BARRIER_MESSAGE` to each pserver.
+1. Pserver
+   1. When the number of `FETCH_BARRIER_MESSAGE`s reaches the number of trainers, the pserver knows all the parameters have been fetched; it then goes back to the first step and sets `barrier_condition_` to 0 (a sketch of this barrier bookkeeping follows the list).
+
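+A minimal sketch of the pserver-side barrier bookkeeping described in the list above, written with a mutex and condition variable. `BarrierMonitor` and its method names are illustrative assumptions for this doc, not the actual gRPC server implementation.
+
+```cpp
+#include <condition_variable>
+#include <mutex>
+
+// Illustrative sketch only: counts BATCH/FETCH barrier messages and gates
+// GetVariable on barrier_condition_, as described in the steps above.
+class BarrierMonitor {
+ public:
+  explicit BarrierMonitor(int num_trainers) : num_trainers_(num_trainers) {}
+
+  // A BATCH_BARRIER_MESSAGE arrived from one trainer.
+  void NotifyBatchBarrier() {
+    std::lock_guard<std::mutex> lock(mu_);
+    if (++batch_barrier_count_ == num_trainers_) {
+      // All gradients have been received; in the real flow the optimization
+      // block runs here, after which waiting GetVariable calls are released.
+      barrier_condition_ = 1;
+      cv_.notify_all();
+    }
+  }
+
+  // GetVariable handlers block here until optimization is done.
+  void WaitOptimizeDone() {
+    std::unique_lock<std::mutex> lock(mu_);
+    cv_.wait(lock, [this] { return barrier_condition_ == 1; });
+  }
+
+  // A FETCH_BARRIER_MESSAGE arrived; once all trainers have fetched the
+  // parameters, reset the barrier for the next batch.
+  void NotifyFetchBarrier() {
+    std::lock_guard<std::mutex> lock(mu_);
+    if (++fetch_barrier_count_ == num_trainers_) {
+      batch_barrier_count_ = 0;
+      fetch_barrier_count_ = 0;
+      barrier_condition_ = 0;
+    }
+  }
+
+ private:
+  const int num_trainers_;
+  int batch_barrier_count_{0};
+  int fetch_barrier_count_{0};
+  int barrier_condition_{0};
+  std::mutex mu_;
+  std::condition_variable cv_;
+};
+```
+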
+### Asynchronous training
+In the above process, there are two barriers for all trainers to synchronize with each other. In asynchronous training, these two barriers are not needed; the trainer can just send gradients to the pserver and then get the parameters back.
+
+The training process of asynchronous training can be:
+
+![asynchronous distributed training](./src/async_distributed_training.png)
+
+1. Pserver:
+   1. Each parameter has a queue to receive its gradient from trainers.
+   1. Each parameter has a thread to read data from the queue and run the optimize block, using the gradient to optimize the parameter.
+   1. An independent thread handles the RPC call `GetVariable` for trainers to get parameters back. (Maybe we should use a thread pool here to speed up fetching the parameters.)
+
+1. Trainer:
+   1. The trainer reads a batch of data, runs forward and backward with its local parameter copy, and gets the gradients for the parameters.
+   1. The trainer splits all gradients into blocks and then sends these gradient blocks to the pservers (the pserver puts them into the queue).
+   1. The trainer gets all parameters back from the pserver.
+
+### Note:
+There are also some conditions that need to be considered. For example:
+
+1. Whether the trainer needs to wait for the pserver to apply its gradients before getting the parameters back.
+1. Whether we need a lock between parameter update and parameter fetch.
+1. Whether one parameter must live on one server, or whether it can be split and sent to multiple parameter servers.
+
+The above asynchronous training architecture can support different modes; we can test these questions in detail in the future.
diff --git a/doc/fluid/design/dist_train/async_update.md b/doc/fluid/design/dist_train/async_update.md
new file mode 100644
index 0000000000000000000000000000000000000000..6a0835b761b69030ba30697e6e8863928efbf57f
--- /dev/null
+++ b/doc/fluid/design/dist_train/async_update.md
@@ -0,0 +1,58 @@
+# Design Doc: Asynchronous Update With Distributed Training
+
+## Background
+
+For the typical synchronous distributed training, some significant steps are as follows:
+
+1. A Trainer will compute the gradients and SEND them to the Parameter Server (PServer) nodes.
+1. After a PServer node has received the gradients from all the Trainers, it aggregates the
+gradient variables for the same parameter into one gradient variable, applies the aggregated
+gradient to the respective parameter, and finally uses an optimization algorithm (SGD, Momentum, ...)
+to update the parameters.
+1. The Trainer waits for the PServers to finish the optimization stage and then GETs the parameters from the PServers,
+so all the Trainers end up with the same parameters.
+
+In synchronous distributed training, there has to be a `Barrier` to synchronize the
+parameters after the optimization stage. The performance of a distributed training job
+depends on the slowest node. If there are hundreds or thousands of training nodes in a
+job, the performance of synchronous distributed training can be very poor because of
+the slow nodes. So this design doc introduces an approach to implement
+*asynchronous* distributed training in PaddlePaddle Fluid.
+
+## Design
+
+
+
+As shown in the figure above, we describe a global view of the asynchronous update process and use
+the parameter `w1` as an example to introduce the steps:
+1. Each gradient variable may be distributed on a different GPU card; aggregate
+them once they have all been calculated.
+1. Split the gradient variable into multiple blocks according to the number of PServer
+instances and then send them.
+1. The PServer runs an `Optimize Block` using a specified optimization algorithm to update
+the specified parameter.
+1. The trainer fetches the latest parameter from the PServer before running the forward Op that depends
+on the specified parameter.
+1. Broadcast the received variable to multiple GPU cards and continue to run the next
+mini-batch.
+
+### Trainer
+
+- For multi-device distributed training, we first need to aggregate the gradient
+variables placed on different devices and then schedule a `SendVars` Operator to
+send the gradient variables to the multiple PServer instances.
+- Schedule a `FetchVars` operator to fetch the latest parameters from the PServer before running
+the forward ops.
+- There could be a large number of gradient variables to send, so we need to use another
+thread pool (IO Threadpool) whose number of schedulable threads is larger than that of the
+computing thread pool, to avoid competing with computation for thread resources.
+
+### Parameter Server
+
+
+
+- There may be multiple trainer instances that want to optimize the same parameter at
+the same time; to avoid races, we need one `BlockingQueue` for each gradient
+variable so that they are processed one by one.
+- We need a `Map` structure to map a gradient variable name to the `OptimizeBlock` which
+can optimize the respective parameter (a sketch of this mapping follows below).
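+
+A minimal sketch of this bookkeeping, assuming a simple blocking queue; `Tensor`, `OptimizeBlock`, `grad_queues`, and `OptimizeLoop` are illustrative stand-ins rather than Fluid's actual types.
+
+```cpp
+#include <condition_variable>
+#include <mutex>
+#include <queue>
+#include <string>
+#include <unordered_map>
+
+struct Tensor {};       // stand-in for a received gradient tensor
+struct OptimizeBlock {  // stand-in for the block that updates one parameter
+  void Run(const Tensor& /*grad*/) { /* apply the gradient to the parameter */ }
+};
+
+// A minimal blocking queue: gradients for one parameter are consumed one by
+// one, which avoids races between trainers.
+template <typename T>
+class BlockingQueue {
+ public:
+  void Push(T v) {
+    {
+      std::lock_guard<std::mutex> lock(mu_);
+      q_.push(std::move(v));
+    }
+    cv_.notify_one();
+  }
+  T Pop() {
+    std::unique_lock<std::mutex> lock(mu_);
+    cv_.wait(lock, [this] { return !q_.empty(); });
+    T v = std::move(q_.front());
+    q_.pop();
+    return v;
+  }
+
+ private:
+  std::queue<T> q_;
+  std::mutex mu_;
+  std::condition_variable cv_;
+};
+
+// Both maps are populated once, before the per-parameter threads start.
+std::unordered_map<std::string, BlockingQueue<Tensor>> grad_queues;
+std::unordered_map<std::string, OptimizeBlock> optimize_blocks;
+
+// One such loop runs per gradient variable, on its own thread.
+void OptimizeLoop(const std::string& grad_name) {
+  auto& queue = grad_queues[grad_name];
+  auto& block = optimize_blocks[grad_name];
+  for (;;) {
+    Tensor grad = queue.Pop();  // blocks until a trainer pushes a gradient
+    block.Run(grad);
+  }
+}
+```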
diff --git a/doc/fluid/design/dist_train/mpi_enabled_design.md b/doc/fluid/design/dist_train/mpi_enabled_design.md
new file mode 100644
index 0000000000000000000000000000000000000000..4ad3afc7b7522c60460c6f1f387f9415d3738778
--- /dev/null
+++ b/doc/fluid/design/dist_train/mpi_enabled_design.md
@@ -0,0 +1,46 @@
+# MPI-enabled PaddlePaddle Design doc
+
+# Background
+When we do distributed multi-GPU training, the communication overhead between servers becomes the major bottleneck, for the following reasons:
+1. Data must be copied at least once from GPU to CPU memory so that it is ready to transfer, and on the pserver side, copying data from CPU to GPU introduces more overhead.
+2. GPU->CPU data transfer is 10 times slower than data transfer between GPUs or between PCIe devices.
+3. TCP connections cannot make full use of RDMA 100Gb devices.
+
+We will bring the OpenMPI API to PaddlePaddle, which offers two benefits:
+1. Enable RDMA with PaddlePaddle, which brings high-performance, low-latency networking.
+2. Enable GPUDirect with PaddlePaddle, which brings the highest throughput and lowest latency for GPU reads and writes.
+
+# Change list
+* Compile args: add compile args to enable MPI support.
+* Execute args: add execute args to specify when and how to use MPI operations.
+* New ops: add new ops ```mpi_send_op``` and ```mpi_listenandserve_op``` to support MPI send and receive.
+* Transpiler optimized: adds ```mpi_send_op``` and ```mpi_listenandserve_op``` to the running graph.
+* MPI utils package: add an MPI utils package as the supporting low-level API.
+
+## Compile args
+Because MPI and CUDA need hardware support, we add compile args to enable MPI support and control compilation. The ```WITH_MPI``` compile arg controls whether MPI is used. If ```WITH_MPI``` is ```ON```, the build system looks for OpenMPI during configuration, so we should prepare the OpenMPI environment before compiling.
+
+## Execute args
+Launch the script using the ```mpirun``` launcher, for example: ```mpirun -np 3 -hosts node1,node2,node3 python train.py```. By doing this, we can number the actors (trainer/pserver/master) with 0 .. (n-1). A node's number is the rank of the calling process within its communicator (an integer), and the MPI processes identify each other using this rank ID. We have to create a mapping between PaddlePaddle's nodes and their rank IDs so that we can communicate with the correct destinations when using MPI operations.
+
+## New ops
+We won't replace all the gRPC requests with MPI requests; the standard gRPC library is used for all administrative operations, and the MPI API will be used to transfer tensors or SelectedRows to the PServers. Based on this idea, we create two new operators to handle sends and receives: ```mpi_send_op``` and ```mpi_listenandserve_op```. They are similar to [send_op](https://github.com/PaddlePaddle/Paddle/blob/develop/paddle/fluid/operators/send_op.cc) and [listen_and_serv_op](https://github.com/PaddlePaddle/Paddle/blob/develop/paddle/fluid/operators/listen_and_serv_op.cc); we will also build a new module to wrap the MPI send and receive process.
+
+### mpi_send_op
+Very similar to ```send_op```: we will replace the gRPC code that sends gradients with ```mpi_module```, and at the same time wrap it with ```framework::Async```.
+
+### mpi_listenandserve_op
+Very similar to ```listen_and_serv_op```: we will replace the gRPC code that receives gradients with ```mpi_module```, and at the same time wrap it with ```framework::Async```.
+
+## Transpiler optimized
+**We can read the environment variables ```OMPI_COMM_WORLD_SIZE``` and ```OMPI_COMM_WORLD_RANK``` to determine whether to use MPI; if we launch with OpenMPI, these variables must exist in the environment.**
+If MPI is confirmed, we will replace ```send_op``` with ```mpi_send_op``` in distribute_transpiler, and likewise replace ```listenandserve_op``` with ```mpi_listenandserve_op```.
+
+## MPI utils package
+In this package, we wrap the low-level OpenMPI API for use inside PaddlePaddle.
+The APIs included in this package are:
+* MPI send and receive module. We will build a new module to wrap the MPI send and receive process. MPI send and receive differ from gRPC: the MPI [receive](https://www.open-mpi.org/doc/v1.8/man3/MPI_Irecv.3.php) must know the receive buffer size and element type in advance. For this reason, we have to communicate twice: the first communication sends metadata about the gradient through gRPC, and the second is the real communication through MPI, which sends the gradient data to mpi_listenandserve_op (a sketch of this two-phase receive is given below).
+The detailed flow is below:
+![](https://github.com/seiriosPlus/Paddle/blob/mpi_enabled/doc/fluid/design/dist_train/src/mpi_module.png)
+* MPI global configurations, which store the rank ID and the mapping in global variables, for example:
+gRPC client : MPI nodes :``` 127.0.0.1:32004 : 3 ```
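+
+Below is a minimal sketch of the two-phase receive described above, assuming a hypothetical `GetGradientSizeViaGrpc` helper in place of the real gRPC metadata exchange; only the MPI calls themselves are actual OpenMPI API.
+
+```cpp
+#include <mpi.h>
+#include <cstddef>
+#include <vector>
+
+// Hypothetical stand-in: in the real design the buffer size arrives in a
+// gRPC metadata message before the MPI transfer starts.
+size_t GetGradientSizeViaGrpc() { return 1024; }
+
+std::vector<char> ReceiveGradient(int source_rank, int tag) {
+  // Phase 1: learn the receive buffer size from the gRPC metadata exchange.
+  size_t num_bytes = GetGradientSizeViaGrpc();
+  std::vector<char> buffer(num_bytes);
+
+  // Phase 2: post the MPI receive now that the buffer size is known.
+  MPI_Request request;
+  MPI_Irecv(buffer.data(), static_cast<int>(num_bytes), MPI_BYTE, source_rank,
+            tag, MPI_COMM_WORLD, &request);
+  MPI_Wait(&request, MPI_STATUS_IGNORE);  // block until the gradient arrives
+  return buffer;
+}
+```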
diff --git a/doc/fluid/design/dist_train/src/async_distributed_training.png b/doc/fluid/design/dist_train/src/async_distributed_training.png
new file mode 100644
index 0000000000000000000000000000000000000000..3b53ab59c0cd7b44b2956f16f1adc47fe85909d3
Binary files /dev/null and b/doc/fluid/design/dist_train/src/async_distributed_training.png differ
diff --git a/doc/fluid/design/dist_train/src/async_pserver.graffle b/doc/fluid/design/dist_train/src/async_pserver.graffle
new file mode 100644
index 0000000000000000000000000000000000000000..d2301611774fcb3866473e3e6470568d1e1312cf
Binary files /dev/null and b/doc/fluid/design/dist_train/src/async_pserver.graffle differ
diff --git a/doc/fluid/design/dist_train/src/async_pserver.png b/doc/fluid/design/dist_train/src/async_pserver.png
new file mode 100644
index 0000000000000000000000000000000000000000..7d900b0c0eb291c67537b9cf93227c671bafdc73
Binary files /dev/null and b/doc/fluid/design/dist_train/src/async_pserver.png differ
diff --git a/doc/fluid/design/dist_train/src/async_update.graffle b/doc/fluid/design/dist_train/src/async_update.graffle
new file mode 100644
index 0000000000000000000000000000000000000000..3a631888688a0d564a873fcb16d943958c91223e
Binary files /dev/null and b/doc/fluid/design/dist_train/src/async_update.graffle differ
diff --git a/doc/fluid/design/dist_train/src/async_update.png b/doc/fluid/design/dist_train/src/async_update.png
new file mode 100644
index 0000000000000000000000000000000000000000..3e8db973f45d6d9ac8dcce1dc7878067e79e6dcc
Binary files /dev/null and b/doc/fluid/design/dist_train/src/async_update.png differ
diff --git a/doc/fluid/design/dist_train/src/distributed_training.graffle b/doc/fluid/design/dist_train/src/distributed_training.graffle
new file mode 100644
index 0000000000000000000000000000000000000000..1168801bc1fadfce310a74cb3110695bd1629f6b
Binary files /dev/null and b/doc/fluid/design/dist_train/src/distributed_training.graffle differ
diff --git a/doc/fluid/design/dist_train/src/mpi_module.png b/doc/fluid/design/dist_train/src/mpi_module.png
new file mode 100644
index 0000000000000000000000000000000000000000..e6b6a3e5d6f68baeeb67d7f71154bd8d85f32b6f
Binary files /dev/null and b/doc/fluid/design/dist_train/src/mpi_module.png differ
diff --git a/doc/fluid/design/dist_train/src/sync_distributed_training.png b/doc/fluid/design/dist_train/src/sync_distributed_training.png
new file mode 100644
index 0000000000000000000000000000000000000000..e4f9a221fea4b7238e8a1d84e609c0371f6ef7a2
Binary files /dev/null and b/doc/fluid/design/dist_train/src/sync_distributed_training.png differ
diff --git a/doc/v2/api/data/data_reader.rst b/doc/v2/api/data/data_reader.rst
index 2ccfec9c284877a7576e9751526b169a4ac78d8e..d7c896a6270b488ca4449e5211d0d0879eda6ac5 100644
--- a/doc/v2/api/data/data_reader.rst
+++ b/doc/v2/api/data/data_reader.rst
@@ -6,7 +6,43 @@ Data Reader Interface
DataTypes
=========
-.. automodule:: paddle.v2.data_type
+.. autofunction:: paddle.v2.data_type.dense_array
+ :noindex:
+
+.. autofunction:: paddle.v2.data_type.integer_value
+ :noindex:
+
+.. autofunction:: paddle.v2.data_type.integer_value_sequence
+ :noindex:
+
+.. autofunction:: paddle.v2.data_type.integer_value_sub_sequence
+ :noindex:
+
+.. autofunction:: paddle.v2.data_type.sparse_binary_vector
+ :noindex:
+
+.. autofunction:: paddle.v2.data_type.sparse_binary_vector_sequence
+ :noindex:
+
+.. autofunction:: paddle.v2.data_type.sparse_binary_vector_sub_sequence
+ :noindex:
+
+.. autofunction:: paddle.v2.data_type.sparse_float_vector
+ :noindex:
+
+.. autofunction:: paddle.v2.data_type.sparse_float_vector_sequence
+ :noindex:
+
+.. autofunction:: paddle.v2.data_type.sparse_float_vector_sub_sequence
+ :noindex:
+
+.. autofunction:: paddle.v2.data_type.sparse_non_value_slot
+ :noindex:
+
+.. autofunction:: paddle.v2.data_type.sparse_value_slot
+ :noindex:
+
+.. autoclass:: paddle.v2.data_type.InputType
:members:
:noindex:
diff --git a/paddle/fluid/framework/CMakeLists.txt b/paddle/fluid/framework/CMakeLists.txt
index 1f3ca24df16cf080d325fbdc0d613a828e384b2a..340b891e41671df7e61a4a66ec538d4603bb9842 100644
--- a/paddle/fluid/framework/CMakeLists.txt
+++ b/paddle/fluid/framework/CMakeLists.txt
@@ -102,7 +102,7 @@ cc_test(init_test SRCS init_test.cc DEPS init)
cc_test(op_kernel_type_test SRCS op_kernel_type_test.cc DEPS place device_context framework_proto)
cc_test(cow_ptr_tests SRCS details/cow_ptr_test.cc)
-cc_test(channel_test SRCS channel_test.cc)
+# cc_test(channel_test SRCS channel_test.cc)
cc_test(tuple_test SRCS tuple_test.cc )
cc_test(concurrency_test SRCS concurrency_test.cc DEPS go_op channel_close_op channel_create_op
channel_send_op channel_recv_op sum_op select_op elementwise_add_op compare_op
diff --git a/paddle/fluid/framework/details/broadcast_op_handle_test.cc b/paddle/fluid/framework/details/broadcast_op_handle_test.cc
index dfc52b012f8b6bf5cf1a3feab90dc1ec7842ad6c..bcd61335be0f7fe64563ee65daaf9de0760c9b1a 100644
--- a/paddle/fluid/framework/details/broadcast_op_handle_test.cc
+++ b/paddle/fluid/framework/details/broadcast_op_handle_test.cc
@@ -77,14 +77,9 @@ struct TestBroadcastOpHandle {
local_scopes_[input_scope_idx]->Var("input");
op_handle_.reset(new BroadcastOpHandle(local_scopes_, gpu_list_));
-
- vars_.emplace_back(new VarHandle());
- VarHandle* in_var_handle = static_cast(vars_.back().get());
- in_var_handle->place_ = gpu_list_[input_scope_idx];
- in_var_handle->name_ = "input";
- in_var_handle->version_ = 1;
- in_var_handle->scope_idx_ = input_scope_idx;
- in_var_handle->generated_op_ = nullptr;
+ auto* in_var_handle =
+ new VarHandle(1, input_scope_idx, "input", gpu_list_[input_scope_idx]);
+ vars_.emplace_back(in_var_handle);
op_handle_->AddInput(in_var_handle);
// add dummy var
@@ -96,12 +91,8 @@ struct TestBroadcastOpHandle {
for (size_t j = 0; j < gpu_list_.size(); ++j) {
op_handle_->dev_ctxes_[gpu_list_[j]] = ctxs_[j].get();
- vars_.emplace_back(new VarHandle());
- VarHandle* out_var_handle = static_cast(vars_.back().get());
- out_var_handle->place_ = gpu_list_[j];
- out_var_handle->name_ = "out";
- out_var_handle->version_ = 2;
- out_var_handle->scope_idx_ = j;
+ VarHandle* out_var_handle = new VarHandle(2, j, "out", gpu_list_[j]);
+ vars_.emplace_back(out_var_handle);
op_handle_->AddOutput(out_var_handle);
}
diff --git a/paddle/fluid/framework/details/gather_op_handle_test.cc b/paddle/fluid/framework/details/gather_op_handle_test.cc
index 10839f239d59e97946575297a6d125968a1458f4..2da8c89d2df73215b748f102d9bbfc5b742cf97f 100644
--- a/paddle/fluid/framework/details/gather_op_handle_test.cc
+++ b/paddle/fluid/framework/details/gather_op_handle_test.cc
@@ -79,13 +79,8 @@ struct TestGatherOpHandle {
// add input
for (size_t j = 0; j < gpu_list_.size(); ++j) {
op_handle_->dev_ctxes_[gpu_list_[j]] = ctxs_[j].get();
- vars_.emplace_back(new VarHandle());
- VarHandle* in_var_handle = static_cast(vars_.back().get());
- in_var_handle->place_ = gpu_list_[j];
- in_var_handle->name_ = "input";
- in_var_handle->version_ = 1;
- in_var_handle->scope_idx_ = j;
- in_var_handle->generated_op_ = nullptr;
+ auto* in_var_handle = new VarHandle(1, j, "input", gpu_list_[j]);
+ vars_.emplace_back(in_var_handle);
op_handle_->AddInput(in_var_handle);
}
@@ -97,12 +92,9 @@ struct TestGatherOpHandle {
op_handle_->AddInput(in_dummy_var_handle);
// add output
- vars_.emplace_back(new VarHandle());
- VarHandle* out_var_handle = static_cast(vars_.back().get());
- out_var_handle->place_ = gpu_list_[input_scope_idx];
- out_var_handle->name_ = "out";
- out_var_handle->version_ = 2;
- out_var_handle->scope_idx_ = input_scope_idx;
+ auto* out_var_handle =
+ new VarHandle(2, input_scope_idx, "out", gpu_list_[input_scope_idx]);
+ vars_.emplace_back(out_var_handle);
op_handle_->AddOutput(out_var_handle);
// add dummy var
diff --git a/paddle/fluid/framework/details/multi_devices_graph_builder.cc b/paddle/fluid/framework/details/multi_devices_graph_builder.cc
index 5a95cbc53625888bac539f91af391ff0babec17b..d2b6a35a5d5c260b023c68ec4684da95a5b79e81 100644
--- a/paddle/fluid/framework/details/multi_devices_graph_builder.cc
+++ b/paddle/fluid/framework/details/multi_devices_graph_builder.cc
@@ -89,105 +89,25 @@ std::unique_ptr MultiDevSSAGraphBuilder::Build(
bool is_forwarding = true;
for (auto *op : program.Block(0).AllOps()) {
- bool change_forward = false;
- if (!is_forwarding) {
- // FIXME(yy): Do not hard code like this
- if (op->OutputArgumentNames().size() == 1 &&
- op->OutputArgumentNames()[0] == GradVarName(loss_var_name_)) {
- continue; // Drop fill 1. for backward coeff;
- }
- }
-
- // append send op if program is distributed trainer main program.
- // always use the first device
- if (!is_forwarding && op->Type() == "send") {
- auto &p = places_[0];
- auto *s = local_scopes_[0];
- // FIXME(wuyi): send op always copy from GPU 0
- result.ops_.emplace_back(new SendOpHandle(*op, s, p));
- // Create inputs for output on original place and no ssa output
- // is created for send op.
- CreateOpHandleIOs(&result, *op, p, 0);
- continue;
- }
-
- for (size_t i = 0; i < places_.size(); ++i) {
- auto &p = places_[i];
- auto *s = local_scopes_[i];
-
- result.ops_.emplace_back(new ComputationOpHandle(*op, s, p));
- auto *op_handle = result.ops_.back().get();
- CreateOpHandleIOs(&result, *op, p, i);
-
- auto var_names = op->OutputArgumentNames();
-
- if (is_forwarding) {
- if (var_names.size() == 1 && var_names[0] == loss_var_name_) {
-// Insert ScaleCost OpHandle
-#ifdef PADDLE_WITH_CUDA
- auto *communication_dev_ctx = nccl_ctxs_->DevCtx(p);
-#else
- auto *communication_dev_ctx =
- platform::DeviceContextPool::Instance().Get(platform::CPUPlace());
-#endif
-
- op_handle = new ScaleLossGradOpHandle(local_scopes_.size(), s, p,
- communication_dev_ctx);
- result.ops_.emplace_back(op_handle);
-
- // FIXME: Currently ScaleLossGradOp only use device_count as scale
- // factor. So it does not depend on any other operators.
- // VarHandle *loss = GetVarHandle(loss_var_name, place);
- // loss->pending_ops_.emplace_back(op_handle);
- // op_handle->inputs_.emplace_back(loss);
-
- CreateOpOutput(&result, op_handle, GradVarName(loss_var_name_), p, i);
- change_forward = true;
- }
- }
- }
-
- if (change_forward) {
+ if (op->Type() == "send") {
+ // append send op if program is distributed trainer main program.
+ // always use the first device
+ CreateSendOp(&result, *op);
+ } else if (IsScaleLossOp(*op)) {
+ CreateScaleLossGradOp(&result);
is_forwarding = false;
- }
-
- if (!is_forwarding) {
- auto var_names = op->OutputArgumentNames();
- // Currently, we assume that once gradient is generated, it can be
- // broadcast, and each gradient is only broadcast once. But there are no
- // other cases, for example, we need to adjust the gradient according to
- // the input when we get the gradient, which is not considered at present.
- for (auto &og : var_names) {
- if (grad_names_.count(og) != 0 &&
- og_has_been_broadcast.count(og) == 0) { // is param grad
- // Insert NCCL AllReduce Op
- og_has_been_broadcast.insert(og);
-#ifdef PADDLE_WITH_CUDA
- result.ops_.emplace_back(
- new NCCLAllReduceOpHandle(local_scopes_, places_, *nccl_ctxs_));
- auto *op_handle = result.ops_.back().get();
-
- for (size_t i = 0; i < places_.size(); ++i) {
- auto &p = places_[i];
- auto &vars = result.vars_[i][og];
-
- if (vars.empty()) { // This device has no data. continue.
- continue;
- }
- auto &prev_grad = vars[vars.size() - 1];
- op_handle->AddInput(prev_grad.get());
-
- vars.emplace_back(new VarHandle);
- auto &var = vars.back();
- var->place_ = p;
- var->name_ = og;
- var->version_ = vars.size() - 1;
-
- op_handle->AddOutput(var.get());
+ } else {
+ CreateComputationalOps(&result, *op);
+ if (!is_forwarding) {
+ // Currently, we assume that once gradient is generated, it can be
+ // broadcast, and each gradient is only broadcast once. But there are no
+ // other cases, for example, we need to adjust the gradient according to
+ // the input when we get the gradient, which is not considered at
+ // present.
+ for (auto &og : op->OutputArgumentNames()) {
+ if (IsParameterGradientOnce(og, &og_has_been_broadcast)) {
+ InsertNCCLAllReduceOp(&result, og);
}
-#else
- PADDLE_ENFORCE("Not implemented");
-#endif
}
}
}
@@ -211,7 +131,95 @@ std::unique_ptr MultiDevSSAGraphBuilder::Build(
}
return std::unique_ptr(graph);
-} // namespace details
+}
+
+void MultiDevSSAGraphBuilder::InsertNCCLAllReduceOp(
+ SSAGraph *result, const std::string &og) const {
+#ifdef PADDLE_WITH_CUDA
+ result->ops_.emplace_back(
+ new NCCLAllReduceOpHandle(local_scopes_, places_, *nccl_ctxs_));
+ auto *op_handle = result->ops_.back().get();
+
+ for (size_t i = 0; i < places_.size(); ++i) {
+ auto &p = places_[i];
+ auto &vars = result->vars_[i][og];
+ PADDLE_ENFORCE(!vars.empty());
+ auto &prev_grad = vars.back();
+ op_handle->AddInput(prev_grad.get());
+
+ auto var = new VarHandle(vars.size() - 1, i, og, p);
+ vars.emplace_back(var);
+ op_handle->AddOutput(var);
+ }
+#else
+ PADDLE_ENFORCE("Not implemented");
+#endif
+}
+
+bool MultiDevSSAGraphBuilder::IsParameterGradientOnce(
+ const std::string &og,
+ std::unordered_set *og_has_been_broadcast) const {
+ bool is_pg_once =
+ grad_names_.count(og) != 0 && og_has_been_broadcast->count(og) == 0;
+ if (is_pg_once) {
+ // Insert NCCL AllReduce Op
+ og_has_been_broadcast->insert(og);
+ }
+ return is_pg_once;
+}
+
+void MultiDevSSAGraphBuilder::CreateScaleLossGradOp(SSAGraph *result) const {
+ for (size_t i = 0; i < places_.size(); ++i) {
+// Insert ScaleCost OpHandle
+#ifdef PADDLE_WITH_CUDA
+ auto *communication_dev_ctx = nccl_ctxs_->DevCtx(places_[i]);
+#else
+ auto *communication_dev_ctx =
+ platform::DeviceContextPool::Instance().Get(platform::CPUPlace());
+#endif
+
+ auto *op_handle =
+ new ScaleLossGradOpHandle(local_scopes_.size(), local_scopes_[i],
+ places_[i], communication_dev_ctx);
+ result->ops_.emplace_back(op_handle);
+
+ // FIXME: Currently ScaleLossGradOp only use device_count as scale
+ // factor. So it does not depend on any other operators.
+ // VarHandle *loss = GetVarHandle(loss_var_name, place);
+ // loss->pending_ops_.emplace_back(op_handle);
+ // op_handle->inputs_.emplace_back(loss);
+
+ CreateOpOutput(result, op_handle, GradVarName(loss_var_name_), places_[i],
+ i);
+ }
+}
+
+void MultiDevSSAGraphBuilder::CreateComputationalOps(SSAGraph *result,
+ const OpDesc &op) const {
+ for (size_t scope_idx = 0; scope_idx < places_.size(); ++scope_idx) {
+ auto p = places_[scope_idx];
+ auto s = local_scopes_[scope_idx];
+ result->ops_.emplace_back(new ComputationOpHandle(op, s, p));
+ CreateOpHandleIOs(result, op, p, scope_idx);
+ }
+}
+
+void MultiDevSSAGraphBuilder::CreateSendOp(SSAGraph *result,
+ const OpDesc &op) const {
+ auto &p = places_[0];
+ auto *s = local_scopes_[0];
+ // FIXME(wuyi): send op always copy from GPU 0
+ result->ops_.emplace_back(new SendOpHandle(op, s, p));
+ // Create inputs for output on original place and no ssa output
+ // is created for send op.
+ CreateOpHandleIOs(result, op, p, 0);
+}
+
+bool MultiDevSSAGraphBuilder::IsScaleLossOp(const OpDesc &op) const {
+ // FIXME(yy): Do not hard code like this
+ return op.OutputArgumentNames().size() == 1 &&
+ op.OutputArgumentNames()[0] == GradVarName(loss_var_name_);
+}
} // namespace details
} // namespace framework
} // namespace paddle
diff --git a/paddle/fluid/framework/details/multi_devices_graph_builder.h b/paddle/fluid/framework/details/multi_devices_graph_builder.h
index f1518d75b421006db6311c3b0f602e47000ab381..b5ba2dbd3c00f23fabd993d7908664db38a31941 100644
--- a/paddle/fluid/framework/details/multi_devices_graph_builder.h
+++ b/paddle/fluid/framework/details/multi_devices_graph_builder.h
@@ -57,6 +57,20 @@ class MultiDevSSAGraphBuilder : public SSAGraphBuilder {
#ifdef PADDLE_WITH_CUDA
platform::NCCLContextMap *nccl_ctxs_;
#endif
+
+ bool IsScaleLossOp(const OpDesc &op) const;
+
+ void CreateSendOp(SSAGraph *result, const OpDesc &op) const;
+
+ void CreateComputationalOps(SSAGraph *result, const OpDesc &op) const;
+
+ void CreateScaleLossGradOp(SSAGraph *result) const;
+
+ bool IsParameterGradientOnce(
+ const std::string &og,
+ std::unordered_set *og_has_been_broadcast) const;
+
+ void InsertNCCLAllReduceOp(SSAGraph *result, const std::string &og) const;
};
} // namespace details
} // namespace framework
diff --git a/paddle/fluid/framework/details/nccl_all_reduce_op_handle.cc b/paddle/fluid/framework/details/nccl_all_reduce_op_handle.cc
index 1e48f75958a3ada4d1cd5c8d0f920da4fed2157e..e587210b357ea6caa3272903d8aa6b3e4b2e8228 100644
--- a/paddle/fluid/framework/details/nccl_all_reduce_op_handle.cc
+++ b/paddle/fluid/framework/details/nccl_all_reduce_op_handle.cc
@@ -73,8 +73,9 @@ void NCCLAllReduceOpHandle::RunImpl() {
for (size_t i = 0; i < local_scopes_.size(); ++i) {
auto *s = local_scopes_[i];
+ auto &local_scope = *s->FindVar(kLocalExecScopeName)->Get();
- auto &lod_tensor = s->FindVar(var_name)->Get();
+ auto &lod_tensor = local_scope.FindVar(var_name)->Get();
lod_tensors.emplace_back(lod_tensor);
}
@@ -110,17 +111,21 @@ void NCCLAllReduceOpHandle::RunImpl() {
}
});
} else { // Special handle CPU only Operator's gradient. Like CRF
- auto &trg =
- *this->local_scopes_[0]->Var()->GetMutable();
+ auto &trg = *this->local_scopes_[0]
+ ->FindVar(kLocalExecScopeName)
+ ->Get()
+ ->Var()
+ ->GetMutable();
// Reduce All Tensor to trg in CPU
ReduceLoDTensor func(lod_tensors, &trg);
VisitDataType(ToDataType(lod_tensors[0].type()), func);
for (size_t i = 0; i < local_scopes_.size(); ++i) {
- auto &scope = local_scopes_[i];
+ auto &scope =
+ *local_scopes_[i]->FindVar(kLocalExecScopeName)->Get();
auto &p = places_[i];
- auto *var = scope->FindVar(var_name);
+ auto *var = scope.FindVar(var_name);
auto *dev_ctx = dev_ctxes_[p];
RunAndRecordEvent(p, [&trg, var, dev_ctx, p] {
diff --git a/paddle/fluid/framework/details/scale_loss_grad_op_handle.cc b/paddle/fluid/framework/details/scale_loss_grad_op_handle.cc
index 7fb9f99a8a1bc044e2f25f373265a5ec9f7d76d5..7a65ee62c9bfc0dad2ebee3be21de825fa405d73 100644
--- a/paddle/fluid/framework/details/scale_loss_grad_op_handle.cc
+++ b/paddle/fluid/framework/details/scale_loss_grad_op_handle.cc
@@ -30,10 +30,11 @@ ScaleLossGradOpHandle::~ScaleLossGradOpHandle() {}
void ScaleLossGradOpHandle::RunImpl() {
std::string var_name = static_cast(this->outputs_[0])->name_;
+ auto &local_scope = *scope_->FindVar(kLocalExecScopeName)->Get();
- float *tmp =
- scope_->FindVar(var_name)->GetMutable()->mutable_data(
- make_ddim({1}), place_);
+ float *tmp = local_scope.FindVar(var_name)
+ ->GetMutable()
+ ->mutable_data(make_ddim({1}), place_);
if (platform::is_cpu_place(place_)) {
*tmp = coeff_;
diff --git a/paddle/fluid/framework/details/ssa_graph_builder.cc b/paddle/fluid/framework/details/ssa_graph_builder.cc
index be5fb7577581fd99b1b7b80ccdd2acb8d3a91f01..25e8c77bb489546092b2a93e052da7dd0dd5edf4 100644
--- a/paddle/fluid/framework/details/ssa_graph_builder.cc
+++ b/paddle/fluid/framework/details/ssa_graph_builder.cc
@@ -54,13 +54,8 @@ VarHandle *SSAGraphBuilder::CreateOrGetLatestVarHandle(
auto &var_holder = var_holders[each_var_name];
VarHandle *var = nullptr;
if (var_holder.empty()) {
- var_holder.emplace_back(new VarHandle);
- auto &init_var = var_holder[0];
- init_var->place_ = place;
- init_var->name_ = each_var_name;
- init_var->generated_op_ = nullptr;
- init_var->version_ = 0;
- var = init_var.get();
+ var = new VarHandle(0, place_offset, each_var_name, place);
+ var_holder.emplace_back(var);
} else {
var = var_holder.rbegin()->get();
}
@@ -73,12 +68,9 @@ void SSAGraphBuilder::CreateOpOutput(SSAGraph *graph, OpHandleBase *op_handle,
size_t place_offset) {
auto &vars = graph->vars_[place_offset][each_var_name];
size_t version = vars.size();
- vars.emplace_back(new VarHandle());
- auto &var = vars.back();
- var->version_ = version;
- var->name_ = each_var_name;
- var->place_ = place;
- op_handle->AddOutput(var.get());
+ auto var = new VarHandle(version, place_offset, each_var_name, place);
+ vars.emplace_back(var);
+ op_handle->AddOutput(var);
}
template
diff --git a/paddle/fluid/framework/details/threaded_ssa_graph_executor.cc b/paddle/fluid/framework/details/threaded_ssa_graph_executor.cc
index a371ee10fe03cda86c316f3503f9cadb8c716ae5..3d2bd633afff1d453d00faeca3b3dcf77f8dd5d7 100644
--- a/paddle/fluid/framework/details/threaded_ssa_graph_executor.cc
+++ b/paddle/fluid/framework/details/threaded_ssa_graph_executor.cc
@@ -33,13 +33,6 @@ ThreadedSSAGraphExecutor::ThreadedSSAGraphExecutor(
running_ops_(0),
allow_op_delay_(allow_op_delay) {}
-void ThreadedSSAGraphExecutor::RunDelayedOps(
- const std::unordered_set &delayed_ops) {
- for (auto op : delayed_ops) {
- op->Run(use_event_);
- }
-}
-
FeedFetchList ThreadedSSAGraphExecutor::Run(
const std::vector &fetch_tensors) {
std::unordered_map pending_ops;
@@ -51,8 +44,6 @@ FeedFetchList ThreadedSSAGraphExecutor::Run(
// together since we currently cannot overlap computation and memcpy streams.
// Should revisit it if overlapping is available.
std::unordered_set delayed_ops;
- std::unordered_set blocked_by_delayed_ops;
- std::unordered_set delayed_vars;
auto InsertPendingVar = [&pending_vars, &ready_vars](VarHandleBase &var) {
pending_vars.insert(&var);
@@ -122,24 +113,26 @@ FeedFetchList ThreadedSSAGraphExecutor::Run(
InsertPendingOp(*op);
}
- auto run_all_ready_ops = [&] {
- for (auto *op : ready_ops) {
- if (op->IsMultiDeviceTransfer() && allow_op_delay_) {
- delayed_ops.insert(op);
- delayed_vars.insert(op->outputs_.begin(), op->outputs_.end());
- ready_vars.Extend(op->outputs_);
- continue;
- }
+ auto run_all_ops = [&](std::unordered_set &set) {
+ for (auto *op : set) {
running_ops_++;
RunOp(&ready_vars, op);
}
- ready_ops.clear();
+ set.clear();
};
// Step 3. Execution
- while (!pending_vars.empty() || !ready_ops.empty() || !delayed_ops.empty()) {
+ while (!pending_vars.empty()) {
// 1. Run All Ready ops
- run_all_ready_ops();
+ // Keep loop until all vars are ready.
+ //
+ // NOTE: DelayedOps have a lower priority. It will be scheduled after all
+ // ready_ops have been performed.
+ if (ready_ops.empty() && allow_op_delay_) {
+ run_all_ops(delayed_ops);
+ } else {
+ run_all_ops(ready_ops);
+ }
// 2. Find ready variable
bool timeout;
@@ -160,29 +153,16 @@ FeedFetchList ThreadedSSAGraphExecutor::Run(
auto &deps = pending_ops[op];
--deps;
if (deps == 0) {
- if (delayed_vars.find(ready_var) != delayed_vars.end()) {
- blocked_by_delayed_ops.insert(op);
+ if (op->IsMultiDeviceTransfer() && allow_op_delay_) {
+ delayed_ops.insert(op);
} else {
ready_ops.insert(op);
}
}
}
}
- // When there are no other ops to schedule, schedule buffered delayed
- // ops and unblock other ops.
- if (ready_ops.empty() && !delayed_ops.empty() && running_ops_ == 0) {
- RunDelayedOps(delayed_ops);
- delayed_ops.clear();
- for (auto *op : blocked_by_delayed_ops) {
- ready_ops.insert(op);
- }
- blocked_by_delayed_ops.clear();
- }
- // Keep loop until all vars are ready.
}
PADDLE_ENFORCE(ready_ops.empty());
- PADDLE_ENFORCE(delayed_ops.empty());
- PADDLE_ENFORCE(blocked_by_delayed_ops.empty());
// Wait FetchOps.
if (!fetch_ops.empty()) {
diff --git a/paddle/fluid/framework/details/threaded_ssa_graph_executor.h b/paddle/fluid/framework/details/threaded_ssa_graph_executor.h
index bb5e837b135c35b5aea403496b45aab1ccc288ff..d70bbd4ef0eb02d1b473bf88e526996819aec5f9 100644
--- a/paddle/fluid/framework/details/threaded_ssa_graph_executor.h
+++ b/paddle/fluid/framework/details/threaded_ssa_graph_executor.h
@@ -88,8 +88,6 @@ class ThreadedSSAGraphExecutor : public SSAGraphExecutor {
void RunOp(BlockingQueue *ready_var_q,
details::OpHandleBase *op);
- void RunDelayedOps(const std::unordered_set &delayed_ops);
-
private:
std::unique_ptr<::ThreadPool> pool_;
std::vector local_scopes_;
diff --git a/paddle/fluid/framework/details/var_handle.h b/paddle/fluid/framework/details/var_handle.h
index 871e41343f53b801a22d3a450f0906f37fb372d1..2b887c67e6fc6ea78e42fbb9fd170f740db27d97 100644
--- a/paddle/fluid/framework/details/var_handle.h
+++ b/paddle/fluid/framework/details/var_handle.h
@@ -16,6 +16,7 @@
#include
#include
#include
+#include
#include "paddle/fluid/platform/place.h"
@@ -33,10 +34,10 @@ struct VarHandleBase {
// The operator who generate this variable. nullptr if the variable
// is a root node.
- OpHandleBase *generated_op_;
+ OpHandleBase* generated_op_{nullptr};
// Operators which depend on this variable ready.
- std::unordered_set pending_ops_;
+ std::unordered_set pending_ops_;
};
// VarHandle is actually a single version of Runtime Variable.
@@ -47,6 +48,13 @@ struct VarHandleBase {
struct VarHandle : public VarHandleBase {
std::string DebugString() const override;
+ VarHandle(size_t version, size_t scope_index, std::string name,
+ platform::Place place)
+ : version_(version),
+ scope_idx_(scope_index),
+ name_(std::move(name)),
+ place_(std::move(place)) {}
+
// version field currently is not used, however, just store the version to
// debug easily.
size_t version_;
diff --git a/paddle/fluid/framework/parallel_executor.cc b/paddle/fluid/framework/parallel_executor.cc
index c1486b527d2e06d2b3f7e0f89458bf9a22564586..106b5f866ed5225d67082310e308984d8b3f19ed 100644
--- a/paddle/fluid/framework/parallel_executor.cc
+++ b/paddle/fluid/framework/parallel_executor.cc
@@ -63,13 +63,14 @@ ParallelExecutor::ParallelExecutor(
// Step 1. Bcast the params to devs.
// Create local scopes
if (local_scopes.empty()) {
- for (size_t i = 0; i < member_->places_.size(); ++i) {
- member_->local_scopes_.push_back(&scope->NewScope());
+ member_->local_scopes_.emplace_back(member_->global_scope_);
+ for (size_t i = 1; i < member_->places_.size(); ++i) {
+ member_->local_scopes_.emplace_back(&scope->NewScope());
}
} else {
PADDLE_ENFORCE_EQ(member_->places_.size(), local_scopes.size());
for (size_t i = 0; i < member_->places_.size(); ++i) {
- member_->local_scopes_.push_back(local_scopes[i]);
+ member_->local_scopes_.emplace_back(local_scopes[i]);
}
}
@@ -155,15 +156,13 @@ void ParallelExecutor::BCastParamsToGPUs(
#endif
}
-void ParallelExecutor::Run(
- const std::vector &fetch_tensors,
- const std::string &fetched_var_name,
- const std::unordered_map &feed_tensors) {
+void ParallelExecutor::Run(const std::vector &fetch_tensors,
+ const std::string &fetched_var_name) {
platform::RecordBlock b(0);
- SplitTensorToPlaces(feed_tensors);
-
// Create local scopes.
- for (auto &scope : member_->local_scopes_) {
+ for (auto it = member_->local_scopes_.rbegin();
+ it != member_->local_scopes_.rend(); ++it) {
+ auto &scope = *it;
Scope &local_scope = scope->NewScope();
*scope->Var(details::kLocalExecScopeName)->GetMutable() =
&local_scope;
@@ -177,7 +176,7 @@ void ParallelExecutor::Run(
InitializeVariable(scope->Var(std::get<0>(name_type_pair)),
std::get<1>(name_type_pair));
} else {
- InitializeVariable(scope->Var(std::get<0>(name_type_pair)),
+ InitializeVariable(local_scope.Var(std::get<0>(name_type_pair)),
std::get<1>(name_type_pair));
}
}
@@ -195,14 +194,28 @@ void ParallelExecutor::Run(
auto &local_scope =
*scope->Var(details::kLocalExecScopeName)->GetMutable();
scope->DeleteScope(local_scope);
- local_scope = nullptr;
}
}
-void ParallelExecutor::SplitTensorToPlaces(
- const std::unordered_map &feed_tensors) {
- for (auto it : feed_tensors) {
- auto lod_tensors = it.second.SplitLoDTensor(member_->places_);
+void ParallelExecutor::FeedTensorsIntoLocalScopes(
+ const std::vector> &tensors) {
+ PADDLE_ENFORCE_EQ(member_->local_scopes_.size(), tensors.size());
+
+ for (size_t i = 0; i < tensors.size(); ++i) {
+ auto &map = tensors[i];
+ auto *scope = member_->local_scopes_[i];
+ for (auto &pair : map) {
+ auto *trg = scope->Var(pair.first)->GetMutable();
+ trg->ShareDataWith(pair.second);
+ trg->set_lod(pair.second.lod());
+ }
+ }
+}
+
+void ParallelExecutor::FeedAndSplitTensorIntoLocalScopes(
+ const std::unordered_map &tensors) {
+ for (auto pair : tensors) {
+ auto lod_tensors = pair.second.SplitLoDTensor(member_->places_);
PADDLE_ENFORCE_EQ(
member_->places_.size(), lod_tensors.size(),
"The number of samples of current batch is less than the count of "
@@ -211,7 +224,7 @@ void ParallelExecutor::SplitTensorToPlaces(
for (size_t j = 0; j < member_->places_.size(); ++j) {
// TODO(panxy0718): Do I need to delete this var?
auto t =
- member_->local_scopes_[j]->Var(it.first)->GetMutable();
+ member_->local_scopes_[j]->Var(pair.first)->GetMutable();
t->ShareDataWith(lod_tensors[j]);
t->set_lod(lod_tensors[j].lod());
}
diff --git a/paddle/fluid/framework/parallel_executor.h b/paddle/fluid/framework/parallel_executor.h
index b4f16dba858fb279ec23a8a04257dda6651148cc..303ac3bc55cfed57a03765b27d8aba581eabd1c8 100644
--- a/paddle/fluid/framework/parallel_executor.h
+++ b/paddle/fluid/framework/parallel_executor.h
@@ -44,16 +44,22 @@ class ParallelExecutor {
std::vector& GetLocalScopes();
+ /**
+ * Feed tensors to local scopes. The size of tensors should be equal to the
+ * size of local scopes.
+ */
+ void FeedTensorsIntoLocalScopes(
+ const std::vector>& tensors);
+
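+  /**
+   * Feed one batch and split it by place: each LoDTensor is split across
+   * places_ and the pieces are shared into the corresponding local scopes.
+   */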
+ void FeedAndSplitTensorIntoLocalScopes(
+ const std::unordered_map& tensors);
+
void Run(const std::vector& fetch_tensors,
- const std::string& fetched_var_name,
- const std::unordered_map& feed_tensors);
+ const std::string& fetched_var_name);
void BCastParamsToGPUs(const std::unordered_set& vars) const;
private:
- void SplitTensorToPlaces(
- const std::unordered_map& feed_tensors);
-
ParallelExecutorPrivate* member_;
};
diff --git a/paddle/fluid/inference/CMakeLists.txt b/paddle/fluid/inference/CMakeLists.txt
index 8494edee6c2c714c285c45bbb4fe1d8cb1a524aa..cc45bfe9b17d767be039cc0d8d83234b6994d6c1 100644
--- a/paddle/fluid/inference/CMakeLists.txt
+++ b/paddle/fluid/inference/CMakeLists.txt
@@ -21,7 +21,7 @@ endif()
if(WITH_TESTING)
add_subdirectory(tests/book)
- if (WITH_TENSORRT)
+ if (TENSORRT_FOUND)
add_subdirectory(tensorrt)
endif()
endif()
diff --git a/paddle/fluid/operators/beam_search_decode_op.cc b/paddle/fluid/operators/beam_search_decode_op.cc
index 718f469d38c3c6b7272c1531fae0a1e9ad2e8e3e..4a8dfd4b54227070c2143b180f8ab92753885550 100644
--- a/paddle/fluid/operators/beam_search_decode_op.cc
+++ b/paddle/fluid/operators/beam_search_decode_op.cc
@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/beam_search_decode_op.h"
+#include
#include "paddle/fluid/platform/device_context.h"
namespace paddle {
diff --git a/paddle/fluid/operators/beam_search_decode_op.h b/paddle/fluid/operators/beam_search_decode_op.h
index 3cc6ed310575473fae8e91a8507fb9146107e841..4cb0457d9285e20d4b6a2f9987b7fdb1c6ac157f 100644
--- a/paddle/fluid/operators/beam_search_decode_op.h
+++ b/paddle/fluid/operators/beam_search_decode_op.h
@@ -14,6 +14,7 @@ limitations under the License. */
#pragma once
+#include
#include "paddle/fluid/framework/lod_tensor_array.h"
#include "paddle/fluid/framework/op_registry.h"
@@ -87,7 +88,7 @@ struct BeamSearchDecoder {
*/
std::vector> PackTwoSteps(
const LoDTensor& cur_ids, const LoDTensor& cur_scores,
- std::vector>& prefixes_list,
+ std::vector>* prefixes_list,
std::vector>* sentence_vector_list) const;
/**
@@ -140,7 +141,7 @@ Sentence BeamSearchDecoder::MakeSentence(const BeamNode* node) const {
template
std::vector> BeamSearchDecoder::PackTwoSteps(
const LoDTensor& cur_ids, const LoDTensor& cur_scores,
- std::vector>& prefixes_list,
+ std::vector>* prefixes_list,
std::vector>* sentence_vector_list) const {
std::vector> result;
@@ -153,7 +154,7 @@ std::vector> BeamSearchDecoder::PackTwoSteps(
// if prefixes size is 0, it means this is the first step. In this step,
// all candidate id is the start of candidate sentences.
- if (prefixes_list.empty()) {
+ if (prefixes_list->empty()) {
PADDLE_ENFORCE_EQ(cur_ids.lod().at(kSourceLevel).back(),
cur_ids.lod().at(kSentenceLevel).back(),
"in the first step");
@@ -162,7 +163,7 @@ std::vector> BeamSearchDecoder::PackTwoSteps(
cur_ids.data()[id_idx], cur_scores.data()[id_idx])));
}
} else {
- BeamNodeVector& prefixes = prefixes_list[src_idx];
+ BeamNodeVector& prefixes = prefixes_list->at(src_idx);
SentenceVector& sentence_vector = (*sentence_vector_list)[src_idx];
PADDLE_ENFORCE_EQ(src_end - src_start, prefixes.size(),
@@ -262,7 +263,7 @@ void BeamSearchDecoder::PackAllSteps(const LoDTensorArray& step_ids,
for (size_t step_id = 0; step_id < step_num; ++step_id) {
beamnode_vector_list =
PackTwoSteps(step_ids.at(step_id), step_scores.at(step_id),
- beamnode_vector_list, &sentence_vector_list);
+ &beamnode_vector_list, &sentence_vector_list);
}
// append last beam_node to result
for (size_t src_idx = 0; src_idx < src_num; ++src_idx) {
diff --git a/paddle/fluid/operators/beam_search_decode_op_test.cc b/paddle/fluid/operators/beam_search_decode_op_test.cc
index c3faf46e09bb40d01049fd9cfd79836c1d2bd5bb..36f9594969c416c694928811012baf94332bbd91 100644
--- a/paddle/fluid/operators/beam_search_decode_op_test.cc
+++ b/paddle/fluid/operators/beam_search_decode_op_test.cc
@@ -125,7 +125,7 @@ TEST(BeamSearchDecodeOp, PackTwoStepsFistStep) {
BeamSearchDecoder helper;
beamnode_vector_list = helper.PackTwoSteps(
- ids[0], scores[0], beamnode_vector_list, &sentence_vector_list);
+ ids[0], scores[0], &beamnode_vector_list, &sentence_vector_list);
ASSERT_EQ(beamnode_vector_list.size(), 2UL);
ASSERT_EQ(beamnode_vector_list[0].size(), 2UL);
ASSERT_EQ(beamnode_vector_list[1].size(), 4UL);
@@ -167,7 +167,7 @@ TEST(BeamSearchDecodeOp, PackTwoSteps) {
BeamSearchDecoder helper1;
beamnode_vector_list = helper1.PackTwoSteps(
- ids[0], scores[0], beamnode_vector_list, &sentence_vector_list);
+ ids[0], scores[0], &beamnode_vector_list, &sentence_vector_list);
ASSERT_EQ(sentence_vector_list[0].size(), 1UL);
ASSERT_EQ(sentence_vector_list[1].size(), 0UL);
diff --git a/paddle/fluid/operators/beam_search_op.cc b/paddle/fluid/operators/beam_search_op.cc
index e848b1f12cb9f1ce1d37e0e0233bfc361dc35a33..fdab4e92f47c7c8f241d93268a73dcb8c2eb2dc6 100644
--- a/paddle/fluid/operators/beam_search_op.cc
+++ b/paddle/fluid/operators/beam_search_op.cc
@@ -14,7 +14,10 @@ limitations under the License. */
#include "paddle/fluid/operators/beam_search_op.h"
+#include
#include