diff --git a/cmake/external/grpc.cmake b/cmake/external/grpc.cmake
index 85f40585da29bab9a107f5546e64870975f4c2d3..82437a84248fece843c3659c9422d9b579b5066f 100644
--- a/cmake/external/grpc.cmake
+++ b/cmake/external/grpc.cmake
@@ -50,6 +50,7 @@ ExternalProject_Add(
UPDATE_COMMAND ""
CONFIGURE_COMMAND ""
BUILD_IN_SOURCE 1
+ PATCH_COMMAND git apply ${PADDLE_SOURCE_DIR}/patches/grpc/fix_too_early_destory.patch
# NOTE(yuyang18):
# Disable -Werror, otherwise the compile will fail in MacOS.
# It seems that we cannot configure that by make command.
diff --git a/doc/fluid/design/quantization/fixed_point_quantization.md b/doc/fluid/design/quantization/fixed_point_quantization.md
new file mode 100644
index 0000000000000000000000000000000000000000..085352fc5614d693e63a2f7241e868a9649456af
--- /dev/null
+++ b/doc/fluid/design/quantization/fixed_point_quantization.md
@@ -0,0 +1,110 @@
+Fixed-point quantization uses lower bits, for example, 2-bit, 3-bit or 8-bit fixed point to represent weights and activations, which usually are in singe-precision float-point with 32 bits. The fixed-point representation has advantages in reducing memory bandwidth, lowering power consumption and computational resources as well as the model storage requirements. It is especially important for the inference in embedded-device deployment.
+
+According to some experiments, the apporach to quantize the model trained in float point directly works effectively on the large models, like the VGG model having many parameters. But the accuracy drops a lot for the small model. In order to improve the tradeoff between accuracy and latency, many quantized training apporaches are proposed.
+
+This document is to design a quantized training framework on Fluid. The first part will introduce how to quantize, The second part will describe the quantized training framework. The last part will illustrate how to calculate the quantization scale.
+
+
+### How to quantize
+
+There are many ways to quantize the float value to fixed-point value. For example:
+
+$$ r = min(max(x, a), b)$$
+$$ s = \frac{b - a}{n - 1} $$
+$$ q = \left \lfloor \frac{r - a}{s} \right \rceil $$
+
+where, $x$ is the float value to be quantized, $[a, b]$ is the quantization range, $a$ is the minimum value and $b$ is the maximal value. $\left \lfloor \right \rceil$ denotes rounding to the nearest integer. If the quantization level is $k$, $n$ is $2^k$, for example, $k$ is 8 and $n$ is 256. $q$ is the quantized integer.
+
+
+The quantization we applied is parameterized by the number of quantization levels and maximum absolute value:
+
+$$ M = max(abs(x)) $$
+$$ q = \left \lfloor \frac{x}{M} * (n - 1) \right \rceil $$
+
+where, $x$ is the float value to be quantized, $M$ is maximum absolute value. $\left \lfloor \right \rceil$ denotes rounding to the nearest integer. For 8 bit quantization, $n=2^{8}=256$. $q$ is the quantized integer.
+
+
+Wether the *min-max* quantization or *max-abs* quantization, they also can be represent:
+
+$q = scale * r + b$
+
+We call *min-max*, *max-abs* as the quantization arguments, also call them quantization scale or quantization range.
+
+
+How to calculate the quantization scale (or maximum absolute value) for inference will be described in the last part.
+
+
+### Training Framework
+
+#### Forward pass
+
+The forward pass is simulated quantization, see Figure 1.
+
+The training framework is as following figure.
+
+
+
+Figure 1. Forward in training with simulated quantization.
+
+
+- Firstly, both input and weight will be quantized to 8-bit integers.
+- Second, do the multiplication (or convolution) operation with integers.
+- Third, dequantize the multiplication (or convolution) results to 32-bit float point.
+- Finally, do bias-addition in float type of 32 bit. Here, the bias is not quantized.
+
+For general matrix multiplication (GEMM), quantize for $X$ and $W$:
+
+$$ X_q = \left \lfloor \frac{X}{X_m} * (n - 1) \right \rceil $$
+$$ W_q = \left \lfloor \frac{W}{W_m} * (n - 1) \right \rceil $$
+
+Do GEMM:
+
+$$ Y = X_q * W_q $$
+
+
+Dequantize $Y$:
+
+$$
+\begin{align}
+Y_{dq} &=\frac{Y}{(n - 1) * (n - 1)} * X_m * W_m \\\
+ &=\frac{X_q * W_q}{(n - 1) * (n - 1)} * X_m * W_m \\\
+ &=(\frac{X_q}{n - 1} * X_m) * (\frac{W_q}{n - 1} * W_m)
+\end{align}
+$$
+
+From these formulas, dequantization also can be moved before GEMM, do dequantization for $Xq$ and $Wq$ at first, then do GEMM. The forward workflow in training is equivalent to following framework.
+
+
+
+Figure 2. Equivalent forward in training with simulated quantization.
+
+
+We use this equivalent workflow in the training. In our desigin, there is a quantization transpiler to insert the quantization operator and the de-quantization operator in the Fluid `ProgramDesc`. Since the outputs of quantization and de-quantization operator are still in floating point, they are called faked quantization and de-quantization operator. And the training framework is called simulated quantization.
+
+#### Backward pass
+
+See Figure 3. The gradients are calculated by dequantized weights and activations. All inputs and outputs are float point with 32-bit. And in the weight updating process, the gradients will be added to the original weight, not the quantized or dequantized weights.
+
+
+
+Figure 3. Backward and weight updating in training with simulated quantization.
+
+
+So the quantization transipler will change some inputs of the corresponding backward operators.
+
+### How to calculate quantization scale
+
+There are two strategies to calculate quantization scale, we call them dynamic and static strategy. The dynamic strategy calculates the quantization scale value each iteration. The static strategy keeps the quantization scale for different inputs.
+
+For weights, we apply the dynamic strategy in the training, that is to say, the quantization scale will be recalculated during each iteration until the traning is finished.
+
+For activations, the quantization scales are estimated during training, then used in inference. There are several different ways to estimate them:
+
+
+1. Calculate the mean of maximum absolute during a window.
+2. Calculate the max of maximum absolute during a window.
+3. Calculate the running mean of maximum absolute during a window, as follows:
+
+ $$ Vt = (1 - k) * V + k * V_{t-1} $$
+
+ where, $V$ is the maximum absolute value of current batch, $Vt$ is the running mean value. $k$ is a factor, such as 0.9.
diff --git a/doc/fluid/design/quantization/quantization_backward_and_optimization.png b/doc/fluid/design/quantization/quantization_backward_and_optimization.png
new file mode 100644
index 0000000000000000000000000000000000000000..84f8235ab87cb631992b691f8e05b9c0b6c93da2
Binary files /dev/null and b/doc/fluid/design/quantization/quantization_backward_and_optimization.png differ
diff --git a/doc/fluid/design/quantization/quantization_equivalent_forward.png b/doc/fluid/design/quantization/quantization_equivalent_forward.png
new file mode 100644
index 0000000000000000000000000000000000000000..df49c864537c047c785da12d24893e54ce0a5341
Binary files /dev/null and b/doc/fluid/design/quantization/quantization_equivalent_forward.png differ
diff --git a/doc/fluid/design/quantization/quantization_forward.png b/doc/fluid/design/quantization/quantization_forward.png
new file mode 100644
index 0000000000000000000000000000000000000000..0913f61621bb6533bcb10bd1d18120ccaaa96cff
Binary files /dev/null and b/doc/fluid/design/quantization/quantization_forward.png differ
diff --git a/paddle/contrib/inference/CMakeLists.txt b/paddle/contrib/inference/CMakeLists.txt
index 98c2f68a6c39ed12795bad4a905558917c0275a4..87173fc42a46c8218fbf0beb4ebf7760f69b7c24 100644
--- a/paddle/contrib/inference/CMakeLists.txt
+++ b/paddle/contrib/inference/CMakeLists.txt
@@ -45,6 +45,10 @@ endfunction(inference_api_test)
cc_library(paddle_inference_api
SRCS paddle_inference_api.cc paddle_inference_api_impl.cc
DEPS ${FLUID_CORE_MODULES} ${GLOB_OP_LIB})
+if(NOT APPLE)
+ set(LINK_FLAGS "-Wl,--retain-symbols-file ${CMAKE_CURRENT_SOURCE_DIR}/paddle_inference_api.sym")
+ set_target_properties(paddle_inference_api PROPERTIES LINK_FLAGS "${LINK_FLAGS}")
+endif()
# Here the shared library doesn't depend on other fluid libraries, or double free will occur.
cc_library(paddle_inference_api_shared SHARED
@@ -53,8 +57,19 @@ add_dependencies(paddle_inference_api_shared ${FLUID_CORE_MODULES} ${GLOB_OP_LIB
set_target_properties(paddle_inference_api_shared PROPERTIES OUTPUT_NAME paddle_inference_api)
if(NOT APPLE)
- set(LINK_FLAGS "-fPIC -fvisibility=hidden")
+ set(LINK_FLAGS "-Wl,--version-script ${CMAKE_CURRENT_SOURCE_DIR}/paddle_inference_api.map")
set_target_properties(paddle_inference_api_shared PROPERTIES LINK_FLAGS "${LINK_FLAGS}")
+ FILE(WRITE ${CMAKE_CURRENT_BINARY_DIR}/check_symbol.cmake
+ "execute_process(COMMAND bash -c \"${CMAKE_CURRENT_SOURCE_DIR}/check_symbol.sh"
+ " ${CMAKE_CURRENT_BINARY_DIR}/libpaddle_inference_api.so\" RESULT_VARIABLE symbol_res)\n"
+ "if(NOT \"\${symbol_res}\" STREQUAL \"0\")\n"
+ " message(FATAL_ERROR \"Check symbol failed.\")\n"
+ "endif()\n")
+ add_custom_command(
+ OUTPUT "${CMAKE_CURRENT_BINARY_DIR}/.check_symbol"
+ COMMAND ${CMAKE_COMMAND} -P "${CMAKE_CURRENT_BINARY_DIR}/check_symbol.cmake"
+ DEPENDS paddle_inference_api_shared)
+ add_custom_target(check_symbol ALL DEPENDS "${CMAKE_CURRENT_BINARY_DIR}/.check_symbol")
endif()
cc_test(test_paddle_inference_api
diff --git a/paddle/contrib/inference/check_symbol.sh b/paddle/contrib/inference/check_symbol.sh
new file mode 100755
index 0000000000000000000000000000000000000000..6547ca1413649968e8a0be146915e07192a99898
--- /dev/null
+++ b/paddle/contrib/inference/check_symbol.sh
@@ -0,0 +1,12 @@
+#!/bin/bash
+
+lib=$1
+if [ $# -ne 1 ]; then echo "No input library"; exit -1 ; fi
+
+num_paddle_syms=$(nm -D --defined-only ${lib} | grep paddle | wc -l)
+num_google_syms=$(nm -D --defined-only ${lib} | grep google | wc -l)
+
+if [ $num_paddle_syms -le 0 ]; then echo "Have no paddle symbols"; exit -1 ; fi
+if [ $num_google_syms -ge 1 ]; then echo "Have some google symbols"; exit -1 ; fi
+
+exit 0
diff --git a/paddle/contrib/inference/demo/CMakeLists.txt b/paddle/contrib/inference/demo/CMakeLists.txt
index ecece6fe3471ad7b89c84c3e2b67af4ae9eb3c36..2d501bf0085b1bd4c39ee1a6dfaaa9622fd72ce1 100644
--- a/paddle/contrib/inference/demo/CMakeLists.txt
+++ b/paddle/contrib/inference/demo/CMakeLists.txt
@@ -13,8 +13,6 @@
# limitations under the License.
#
-inference_api_test(simple_on_word2vec ARGS test_word2vec)
-
option(WITH_INFERENCE_DEMO "Compile with Inference demo" OFF)
if(NOT WITH_INFERENCE_DEMO)
return()
diff --git a/paddle/contrib/inference/demo_ci/CMakeLists.txt b/paddle/contrib/inference/demo_ci/CMakeLists.txt
new file mode 100644
index 0000000000000000000000000000000000000000..789bff7f23cd89bfaeba180efa95972cef6fc58c
--- /dev/null
+++ b/paddle/contrib/inference/demo_ci/CMakeLists.txt
@@ -0,0 +1,77 @@
+cmake_minimum_required(VERSION 3.0)
+
+project(cpp_inference_demo CXX C)
+
+set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -std=c++11")
+
+if(NOT DEFINED PADDLE_LIB)
+ message(FATAL_ERROR "please set PADDLE_LIB with -DPADDLE_LIB=/path/paddle/lib")
+endif()
+if(NOT DEFINED DEMO_NAME)
+ message(FATAL_ERROR "please set DEMO_NAME with -DDEMO_NAME=demo_name")
+endif()
+
+option(WITH_MKL "Compile demo with MKL/OpenBlas support, default use MKL." ON)
+option(WITH_GPU "Compile demo with GPU/CPU, default use CPU." OFF)
+option(WITH_STATIC_LIB "Compile demo with static/shared library, default use static." ON)
+
+if(WITH_GPU)
+ set(CUDA_LIB "/usr/local/cuda/lib64/" CACHE STRING "CUDA Library")
+endif()
+
+include_directories("${PADDLE_LIB}")
+include_directories("${PADDLE_LIB}/third_party/install/protobuf/include")
+include_directories("${PADDLE_LIB}/third_party/install/glog/include")
+include_directories("${PADDLE_LIB}/third_party/install/gflags/include")
+include_directories("${PADDLE_LIB}/third_party/install/snappy/include")
+include_directories("${PADDLE_LIB}/third_party/install/snappystream/include")
+include_directories("${PADDLE_LIB}/third_party/install/zlib/include")
+
+include_directories("${PADDLE_LIB}/third_party/boost")
+include_directories("${PADDLE_LIB}/third_party/eigen3")
+
+link_directories("${PADDLE_LIB}/third_party/install/snappy/lib")
+link_directories("${PADDLE_LIB}/third_party/install/snappystream/lib")
+link_directories("${PADDLE_LIB}/third_party/install/protobuf/lib")
+link_directories("${PADDLE_LIB}/third_party/install/glog/lib")
+link_directories("${PADDLE_LIB}/third_party/install/gflags/lib")
+link_directories("${PADDLE_LIB}/third_party/install/zlib/lib")
+
+add_executable(${DEMO_NAME} ${DEMO_NAME}.cc)
+
+if(WITH_MKL)
+ include_directories("${PADDLE_LIB}/third_party/install/mklml/include")
+ set(MATH_LIB ${PADDLE_LIB}/third_party/install/mklml/lib/libmklml_intel.so
+ ${PADDLE_LIB}/third_party/install/mklml/lib/libiomp5.so)
+ set(MKLDNN_PATH "${PADDLE_LIB}/third_party/install/mkldnn")
+ if(EXISTS ${MKLDNN_PATH})
+ include_directories("${MKLDNN_PATH}/include")
+ set(MKLDNN_LIB ${MKLDNN_PATH}/lib/libmkldnn.so.0)
+ endif()
+else()
+ set(MATH_LIB ${PADDLE_LIB}/third_party/install/openblas/lib/libopenblas.a)
+endif()
+
+if(WITH_STATIC_LIB)
+ set(DEPS
+ "-Wl,--whole-archive"
+ ${PADDLE_LIB}/paddle/fluid/inference/libpaddle_fluid.a
+ "-Wl,--no-whole-archive"
+ ${PADDLE_LIB}/contrib/inference/libpaddle_inference_api.a)
+else()
+ # Note: libpaddle_inference_api.so must put before libpaddle_fluid.so
+ set(DEPS
+ ${PADDLE_LIB}/contrib/inference/libpaddle_inference_api.so
+ ${PADDLE_LIB}/paddle/fluid/inference/libpaddle_fluid.so)
+endif()
+set(EXTERNAL_LIB "-lrt -ldl -lpthread")
+
+set(DEPS ${DEPS}
+ ${MATH_LIB} ${MKLDNN_LIB}
+ glog gflags protobuf snappystream snappy z
+ ${EXTERNAL_LIB})
+if(WITH_GPU)
+ set(DEPS ${DEPS} ${CUDA_LIB}/libcudart.so)
+endif()
+
+target_link_libraries(${DEMO_NAME} ${DEPS})
diff --git a/paddle/contrib/inference/demo_ci/run.sh b/paddle/contrib/inference/demo_ci/run.sh
new file mode 100755
index 0000000000000000000000000000000000000000..e3a7269af795b05c296423cb2dc92b753397c6b3
--- /dev/null
+++ b/paddle/contrib/inference/demo_ci/run.sh
@@ -0,0 +1,34 @@
+set -x
+PADDLE_ROOT=$1
+WITH_MKL=$2
+WITH_GPU=$3
+if [ $3 == "ON" ]; then
+ use_gpu_list='true false'
+else
+ use_gpu_list='false'
+fi
+
+mkdir -p build
+cd build
+
+for WITH_STATIC_LIB in false; do
+ rm -rf *
+ cmake .. -DPADDLE_LIB=${PADDLE_ROOT}/build/fluid_install_dir/ \
+ -DWITH_MKL=$WITH_MKL \
+ -DDEMO_NAME=simple_on_word2vec \
+ -DWITH_GPU=$WITH_GPU \
+ -DWITH_STATIC_LIB=$WITH_STATIC_LIB
+ make
+ for use_gpu in $use_gpu_list; do
+ ./simple_on_word2vec \
+ --dirname=${PADDLE_ROOT}/build/python/paddle/fluid/tests/book/word2vec.inference.model \
+ --use_gpu=$use_gpu
+ done
+done
+if [ $? -eq 0 ]; then
+ exit 0
+else
+ echo "inference demo runs fail."
+ exit 1
+fi
+set +x
diff --git a/paddle/contrib/inference/demo/simple_on_word2vec.cc b/paddle/contrib/inference/demo_ci/simple_on_word2vec.cc
similarity index 68%
rename from paddle/contrib/inference/demo/simple_on_word2vec.cc
rename to paddle/contrib/inference/demo_ci/simple_on_word2vec.cc
index c253014642f39a042430992548a285cc7078a959..9713837f86d40383da946af1681e1945c84336b0 100644
--- a/paddle/contrib/inference/demo/simple_on_word2vec.cc
+++ b/paddle/contrib/inference/demo_ci/simple_on_word2vec.cc
@@ -16,21 +16,27 @@ limitations under the License. */
* This file contains a simple demo for how to take a model for inference.
*/
+#include
#include
-#include
#include
#include
-#include "paddle/contrib/inference/paddle_inference_api.h"
+#include "contrib/inference/paddle_inference_api.h"
+#include "paddle/fluid/platform/enforce.h"
+
+DEFINE_string(dirname, "", "Directory of the inference model.");
+DEFINE_bool(use_gpu, false, "Whether use gpu.");
namespace paddle {
namespace demo {
-DEFINE_string(dirname, "", "Directory of the inference model.");
-
void Main(bool use_gpu) {
//# 1. Create PaddlePredictor with a config.
NativeConfig config;
- config.model_dir = FLAGS_dirname + "word2vec.inference.model";
+ if (FLAGS_dirname.empty()) {
+ LOG(INFO) << "Usage: ./simple_on_word2vec --dirname=path/to/your/model";
+ exit(1);
+ }
+ config.model_dir = FLAGS_dirname;
config.use_gpu = use_gpu;
config.fraction_of_gpu_memory = 0.15;
config.device = 0;
@@ -54,12 +60,16 @@ void Main(bool use_gpu) {
CHECK(predictor->Run(slots, &outputs));
//# 4. Get output.
- ASSERT_EQ(outputs.size(), 1UL);
- LOG(INFO) << "output buffer size: " << outputs.front().data.length();
+ PADDLE_ENFORCE(outputs.size(), 1UL);
+ // Check the output buffer size and result of each tid.
+ PADDLE_ENFORCE(outputs.front().data.length(), 33168UL);
+ float result[5] = {
+ 0.00129761, 0.00151112, 0.000423564, 0.00108815, 0.000932706};
const size_t num_elements = outputs.front().data.length() / sizeof(float);
// The outputs' buffers are in CPU memory.
for (size_t i = 0; i < std::min(5UL, num_elements); i++) {
- LOG(INFO) << static_cast(outputs.front().data.data())[i];
+ PADDLE_ENFORCE(static_cast(outputs.front().data.data())[i],
+ result[i]);
}
}
}
@@ -68,7 +78,7 @@ void MainThreads(int num_threads, bool use_gpu) {
// Multi-threads only support on CPU
// 0. Create PaddlePredictor with a config.
NativeConfig config;
- config.model_dir = FLAGS_dirname + "word2vec.inference.model";
+ config.model_dir = FLAGS_dirname;
config.use_gpu = use_gpu;
config.fraction_of_gpu_memory = 0.15;
config.device = 0;
@@ -94,14 +104,17 @@ void MainThreads(int num_threads, bool use_gpu) {
CHECK(predictor->Run(inputs, &outputs));
// 4. Get output.
- ASSERT_EQ(outputs.size(), 1UL);
- LOG(INFO) << "TID: " << tid << ", "
- << "output buffer size: " << outputs.front().data.length();
+ PADDLE_ENFORCE(outputs.size(), 1UL);
+ // Check the output buffer size and result of each tid.
+ PADDLE_ENFORCE(outputs.front().data.length(), 33168UL);
+ float result[5] = {
+ 0.00129761, 0.00151112, 0.000423564, 0.00108815, 0.000932706};
const size_t num_elements =
outputs.front().data.length() / sizeof(float);
// The outputs' buffers are in CPU memory.
for (size_t i = 0; i < std::min(5UL, num_elements); i++) {
- LOG(INFO) << static_cast(outputs.front().data.data())[i];
+ PADDLE_ENFORCE(static_cast(outputs.front().data.data())[i],
+ result[i]);
}
}
});
@@ -111,15 +124,18 @@ void MainThreads(int num_threads, bool use_gpu) {
}
}
-TEST(demo, word2vec_cpu) { Main(false /*use_gpu*/); }
-TEST(demo_multi_threads, word2vec_cpu_1) { MainThreads(1, false /*use_gpu*/); }
-TEST(demo_multi_threads, word2vec_cpu_4) { MainThreads(4, false /*use_gpu*/); }
-
-#ifdef PADDLE_WITH_CUDA
-TEST(demo, word2vec_gpu) { Main(true /*use_gpu*/); }
-TEST(demo_multi_threads, word2vec_gpu_1) { MainThreads(1, true /*use_gpu*/); }
-TEST(demo_multi_threads, word2vec_gpu_4) { MainThreads(4, true /*use_gpu*/); }
-#endif
-
} // namespace demo
} // namespace paddle
+
+int main(int argc, char** argv) {
+ google::ParseCommandLineFlags(&argc, &argv, true);
+ paddle::demo::Main(false /* use_gpu*/);
+ paddle::demo::MainThreads(1, false /* use_gpu*/);
+ paddle::demo::MainThreads(4, false /* use_gpu*/);
+ if (FLAGS_use_gpu) {
+ paddle::demo::Main(true /*use_gpu*/);
+ paddle::demo::MainThreads(1, true /*use_gpu*/);
+ paddle::demo::MainThreads(4, true /*use_gpu*/);
+ }
+ return 0;
+}
diff --git a/paddle/contrib/inference/paddle_inference_api.map b/paddle/contrib/inference/paddle_inference_api.map
new file mode 100644
index 0000000000000000000000000000000000000000..5203784dc1fcb672eb6a26d9dfd3ffbe02e08038
--- /dev/null
+++ b/paddle/contrib/inference/paddle_inference_api.map
@@ -0,0 +1,6 @@
+{
+ global:
+ *paddle*;
+ local:
+ *;
+};
diff --git a/paddle/contrib/inference/paddle_inference_api.sym b/paddle/contrib/inference/paddle_inference_api.sym
new file mode 100644
index 0000000000000000000000000000000000000000..ef2a04d788aa86b7f6a61c4af479d70d1137f374
--- /dev/null
+++ b/paddle/contrib/inference/paddle_inference_api.sym
@@ -0,0 +1 @@
+*paddle*
diff --git a/paddle/fluid/framework/details/multi_devices_graph_builder.cc b/paddle/fluid/framework/details/multi_devices_graph_builder.cc
index b82c2ef4082110f1621eb38d50361396511a4825..6f5d4471a97cc4efc73b9df68040ab9eccde0b1c 100644
--- a/paddle/fluid/framework/details/multi_devices_graph_builder.cc
+++ b/paddle/fluid/framework/details/multi_devices_graph_builder.cc
@@ -276,13 +276,22 @@ std::unique_ptr MultiDevSSAGraphBuilder::Build(
}
}
- // Insert BCast Ops
- for (size_t dev_id = 0; dev_id < bcast_var_name_set.size(); ++dev_id) {
- auto &to_bcast_set = bcast_var_name_set[dev_id];
- for (auto &bcast_name : to_bcast_set) {
- CreateBroadcastOp(&result, bcast_name, dev_id);
+ bool use_gpu = false;
+#ifdef PADDLE_WITH_CUDA
+ use_gpu = nccl_ctxs_ != nullptr;
+#endif
+
+ if (use_gpu ||
+ strategy_.reduce_ == BuildStrategy::ReduceStrategy::kAllReduce) {
+ // Insert BCast Ops
+ for (size_t dev_id = 0; dev_id < bcast_var_name_set.size(); ++dev_id) {
+ auto &to_bcast_set = bcast_var_name_set[dev_id];
+ for (auto &bcast_name : to_bcast_set) {
+ CreateBroadcastOp(&result, bcast_name, dev_id);
+ }
}
}
+
/*
Dependency graph has been constructed. However, there are still data
hazards need to be handled.
@@ -412,14 +421,19 @@ int MultiDevSSAGraphBuilder::GetOpDeviceID(const OpDesc &op) const {
if (strategy_.reduce_ != BuildStrategy::ReduceStrategy::kReduce) {
return -1;
}
-
- for (auto &varname : op.InputArgumentNames()) {
- int dev_id = GetVarDeviceID(varname);
- if (dev_id != -1) {
- return dev_id;
- }
+ int op_role = boost::get(
+ op.GetAttr(framework::OpProtoAndCheckerMaker::OpRoleAttrName()));
+ if (op_role != static_cast(framework::OpRole::kOptimize)) {
+ return -1;
}
- return -1;
+ auto param_grad = boost::get>(
+ op.GetAttr(OpProtoAndCheckerMaker::OpRoleVarAttrName()));
+
+ PADDLE_ENFORCE_EQ(param_grad.size(), 2U);
+ int dev_id = GetVarDeviceID(param_grad[1]);
+ PADDLE_ENFORCE_NE(dev_id, -1, "dev_id should not be -1.[%s, %s]", op.Type(),
+ param_grad[0]);
+ return dev_id;
}
int MultiDevSSAGraphBuilder::GetVarDeviceID(const std::string &varname) const {
diff --git a/paddle/fluid/framework/details/scope_buffered_ssa_graph_executor.cc b/paddle/fluid/framework/details/scope_buffered_ssa_graph_executor.cc
index eb4e7ec52f907f9403e21ec2734d61824f51a58b..1d80bab90f513139f807b57258177c6b2ac53ac0 100644
--- a/paddle/fluid/framework/details/scope_buffered_ssa_graph_executor.cc
+++ b/paddle/fluid/framework/details/scope_buffered_ssa_graph_executor.cc
@@ -13,6 +13,7 @@
// limitations under the License.
#include "paddle/fluid/framework/details/scope_buffered_ssa_graph_executor.h"
+#include
#include
#include
#include "paddle/fluid/framework/executor.h"
@@ -53,8 +54,14 @@ FeedFetchList ScopeBufferedSSAGraphExecutor::Run(
}
}
}
+ std::vector fetch_data;
+ std::exception_ptr eptr;
+ try {
+ fetch_data = underlying_executor_->Run(fetch_tensors);
+ } catch (...) {
+ eptr = std::current_exception();
+ }
- auto fetch_data = underlying_executor_->Run(fetch_tensors);
drop_scope_counter_ += 1;
if (!fetch_tensors.empty() ||
drop_scope_counter_ == strategy_.num_iteration_per_drop_scope_) {
@@ -69,7 +76,11 @@ FeedFetchList ScopeBufferedSSAGraphExecutor::Run(
scope->DeleteScope(local_scope);
}
}
- return fetch_data;
+ if (eptr) {
+ std::rethrow_exception(eptr);
+ } else {
+ return fetch_data;
+ }
}
} // namespace details
} // namespace framework
diff --git a/paddle/fluid/framework/details/threaded_ssa_graph_executor.cc b/paddle/fluid/framework/details/threaded_ssa_graph_executor.cc
index 99b10254a7961bf7b27b256acaece573a71c4115..07097c7e75c6ce638549716cd6523f387cdefd92 100644
--- a/paddle/fluid/framework/details/threaded_ssa_graph_executor.cc
+++ b/paddle/fluid/framework/details/threaded_ssa_graph_executor.cc
@@ -78,6 +78,10 @@ FeedFetchList ThreadedSSAGraphExecutor::Run(
set.clear();
};
+ // Clean run context
+ run_op_futures_.clear();
+ exception_.reset();
+
// Step 3. Execution
while (!pending_vars.empty()) {
// 1. Run All Ready ops
@@ -96,16 +100,19 @@ FeedFetchList ThreadedSSAGraphExecutor::Run(
auto cur_ready_vars = ready_vars.PopAll(1, &timeout);
if (timeout) {
- std::lock_guard l(exception_mu_);
+ std::unique_lock l(exception_mu_);
if (exception_) {
+ l.unlock();
+ for (auto &run_op_future : run_op_futures_) {
+ run_op_future.wait();
+ }
+ l.lock();
std::exception *exp = exception_.get();
if (dynamic_cast(exp)) {
auto e = *static_cast(exp);
- exception_.reset();
throw e;
} else if (dynamic_cast(exp)) {
auto e = *static_cast(exp);
- exception_.reset();
throw e;
} else {
LOG(FATAL) << "Unknown exception.";
@@ -222,7 +229,7 @@ void ThreadedSSAGraphExecutor::RunOp(
}
};
if (pool_) {
- pool_->enqueue(op_run);
+ run_op_futures_.emplace_back(pool_->enqueue(op_run));
} else {
op_run();
}
diff --git a/paddle/fluid/framework/details/threaded_ssa_graph_executor.h b/paddle/fluid/framework/details/threaded_ssa_graph_executor.h
index c69e0487e2e503a0d445300aa2fd6bb9c30b06c9..09973b7a72881464ad9e7776d4aad3d2261a118d 100644
--- a/paddle/fluid/framework/details/threaded_ssa_graph_executor.h
+++ b/paddle/fluid/framework/details/threaded_ssa_graph_executor.h
@@ -15,6 +15,7 @@
#pragma once
#include
+#include
#include
#include
#include
@@ -77,6 +78,8 @@ class ThreadedSSAGraphExecutor : public SSAGraphExecutor {
private:
ExecutionStrategy strategy_;
+ // use std::list because clear(), push_back, and for_each are O(1)
+ std::list> run_op_futures_;
};
} // namespace details
diff --git a/paddle/fluid/framework/parallel_executor.cc b/paddle/fluid/framework/parallel_executor.cc
index b53a6f43fbd1f23e69d23ad0fcc54d5c25d352a3..58be61362cabf22a3543af364f1b0bd180df826a 100644
--- a/paddle/fluid/framework/parallel_executor.cc
+++ b/paddle/fluid/framework/parallel_executor.cc
@@ -45,6 +45,7 @@ class ParallelExecutorPrivate {
#endif
bool own_local_scope_;
bool use_cuda_;
+ bool use_all_reduce_;
};
std::vector &ParallelExecutor::GetLocalScopes() {
@@ -62,6 +63,14 @@ ParallelExecutor::ParallelExecutor(
: member_(new ParallelExecutorPrivate(places)) {
member_->global_scope_ = scope;
member_->use_cuda_ = exec_strategy.use_cuda_;
+ member_->use_all_reduce_ =
+ build_strategy.reduce_ == BuildStrategy::ReduceStrategy::kAllReduce;
+
+ if (!member_->use_all_reduce_) {
+ PADDLE_ENFORCE(places.size() > 1,
+ "If you set build_strategy.reduce with 'Reduce',"
+ "the number of places must be greater than 1.");
+ }
// Step 1. Bcast the params to devs.
// Create local scopes
@@ -95,7 +104,7 @@ ParallelExecutor::ParallelExecutor(
}
if (member_->local_scopes_.size() != 1 && local_scopes.empty()) {
- BCastParamsToGPUs(bcast_vars);
+ BCastParamsToDevs(bcast_vars);
}
// Startup Program has been run. All local scopes has correct parameters.
@@ -117,7 +126,7 @@ ParallelExecutor::ParallelExecutor(
#ifdef PADDLE_WITH_CUDA
builder_factory.SetNCCLContextMap(member_->nccl_ctxs_.get());
#else
- PADDLE_THROW("Not compiled with CUDA");
+ PADDLE_THROW("Not compiled with CUDA.");
#endif
}
@@ -131,9 +140,9 @@ ParallelExecutor::ParallelExecutor(
member_->places_, std::move(member_->executor_)));
}
-void ParallelExecutor::BCastParamsToGPUs(
+void ParallelExecutor::BCastParamsToDevs(
const std::unordered_set &vars) const {
- // the the initializing bcast, all vars would be bcast from device(0),
+ // the initializing bcast, all vars would be bcast from device(0),
// otherwise
// bcast from the specified device.
bool initializing = builder_.get() == nullptr ? true : false;
@@ -202,12 +211,20 @@ void ParallelExecutor::BCastParamsToGPUs(
#endif
} else {
platform::CPUPlace cpu;
- for (size_t i = 1; i < member_->places_.size(); ++i) {
+ for (size_t i = 0; i < member_->places_.size(); ++i) {
+ if ((initializing && i == 0) ||
+ (!initializing && static_cast(i) == var_dev_id))
+ continue;
+
auto local_scope = member_->local_scopes_[i];
auto *t = local_scope->Var(var)->GetMutable();
- t->Resize(dims);
- t->mutable_data(cpu, main_tensor.type());
- paddle::framework::TensorCopy(main_tensor, cpu, t);
+ if (member_->use_all_reduce_ || member_->use_cuda_) {
+ t->Resize(dims);
+ t->mutable_data(cpu, main_tensor.type());
+ paddle::framework::TensorCopy(main_tensor, cpu, t);
+ } else {
+ t->ShareDataWith(main_tensor);
+ }
}
}
}
diff --git a/paddle/fluid/framework/parallel_executor.h b/paddle/fluid/framework/parallel_executor.h
index 058f83f07c26224e3180d140630c08a24c40cd80..6985b6540690c6218bcee51ba0e69f3d34812bfc 100644
--- a/paddle/fluid/framework/parallel_executor.h
+++ b/paddle/fluid/framework/parallel_executor.h
@@ -66,7 +66,7 @@ class ParallelExecutor {
void Run(const std::vector &fetch_tensors,
const std::string &fetched_var_name);
- void BCastParamsToGPUs(const std::unordered_set &vars) const;
+ void BCastParamsToDevs(const std::unordered_set &vars) const;
private:
ParallelExecutorPrivate *member_;
diff --git a/paddle/fluid/framework/reader.h b/paddle/fluid/framework/reader.h
index 6c4432cb7a70853e19460b1980d621c02caed970..a8d04feb42456607159bcbede0574fe90dfe995c 100644
--- a/paddle/fluid/framework/reader.h
+++ b/paddle/fluid/framework/reader.h
@@ -29,11 +29,11 @@ enum ReaderStatus { kRunning, kStopped };
class ReaderBase {
public:
- void ReadNext(std::vector* out);
+ virtual void ReadNext(std::vector* out);
- void Shutdown();
+ virtual void Shutdown();
- void Start();
+ virtual void Start();
// Return the readers which are the end of decorating chain. Basically
// they are readers just before read op.
@@ -42,7 +42,7 @@ class ReaderBase {
virtual ~ReaderBase();
protected:
- virtual void ReadNextImpl(std::vector* out) = 0;
+ virtual void ReadNextImpl(std::vector* out) {}
virtual void ShutdownImpl() {}
diff --git a/paddle/fluid/inference/CMakeLists.txt b/paddle/fluid/inference/CMakeLists.txt
index 1895aea7f98cb1ad12b2ce16545339252349ea37..b1c33c3415f49f9b1160655034350087432d0cb0 100644
--- a/paddle/fluid/inference/CMakeLists.txt
+++ b/paddle/fluid/inference/CMakeLists.txt
@@ -13,6 +13,12 @@ endif()
# Create static library
cc_library(paddle_fluid DEPS ${fluid_modules} paddle_fluid_api)
+if(NOT APPLE)
+ # TODO(liuyiqu: Temporarily disable the link flag because it is not support on Mac.
+ set(LINK_FLAGS "-Wl,--retain-symbols-file ${CMAKE_CURRENT_SOURCE_DIR}/paddle_fluid.sym")
+ set_target_properties(paddle_fluid PROPERTIES LINK_FLAGS "${LINK_FLAGS}")
+endif()
+
# Create shared library
cc_library(paddle_fluid_shared SHARED
SRCS io.cc
diff --git a/paddle/fluid/inference/analysis/data_flow_graph.cc b/paddle/fluid/inference/analysis/data_flow_graph.cc
index d09bf3ed161703b0cf273522921e157c7360a0bc..bd24e8a7d9c20b8cd9c4e41a76ffc33a004a9a69 100644
--- a/paddle/fluid/inference/analysis/data_flow_graph.cc
+++ b/paddle/fluid/inference/analysis/data_flow_graph.cc
@@ -90,6 +90,20 @@ std::string DataFlowGraph::DotString() const {
return dot.Build();
}
+std::string DataFlowGraph::HumanReadableInfo(bool show_values,
+ bool show_functions) const {
+ std::stringstream values, functions;
+ for (auto &n : nodes.nodes()) {
+ if (show_values && n->IsValue()) {
+ values << n->repr() << "\n";
+ }
+ if (show_functions && n->IsFunction()) {
+ functions << n->repr() << "\n";
+ }
+ }
+ return "Values:\n" + values.str() + "\n\n" + "Functions:\n" + functions.str();
+}
+
//
// NodesBFSIterator
//
@@ -146,7 +160,7 @@ bool GraphTraits::NodesBFSIterator::operator==(
if ((!queue_.empty()) && (!other.queue_.empty())) {
return queue_.front() == other.queue_.front() &&
visited_.size() == other.visited_.size(); // here need to check the
- // equality of queue and
+ // equality of queue and
// visited. Just a light but week implementation.
}
return false;
@@ -208,6 +222,76 @@ Node *GraphTraits::NodesDFSIterator::operator->() {
return stack_.top();
}
+GraphTraits::NodesTSIterator::NodesTSIterator(
+ const std::vector &source) {
+ PADDLE_ENFORCE(!source.empty(),
+ "Start points of topological sorting should not be empty!");
+ std::unordered_set visited;
+ std::unordered_set to_visit{source.begin(), source.end()};
+
+ std::vector inlink_visited;
+ while (!to_visit.empty()) {
+ std::vector queue(to_visit.begin(), to_visit.end());
+ for (auto *p : queue) {
+ inlink_visited.clear();
+
+ std::copy_if(p->inlinks.begin(), p->inlinks.end(),
+ std::back_inserter(inlink_visited),
+ [&](Node *x) { return visited.count(x); });
+
+ if (inlink_visited.size() == p->inlinks.size()) {
+ sorted_.push_back(p);
+ for (auto *_ : p->outlinks) {
+ if (!visited.count(_)) {
+ to_visit.insert(_);
+ }
+ }
+
+ to_visit.erase(p);
+ visited.insert(p);
+ }
+ }
+ }
+}
+
+GraphTraits::NodesTSIterator::NodesTSIterator(
+ const paddle::inference::analysis::GraphTraits<
+ DataFlowGraph>::NodesTSIterator &other)
+ : sorted_(other.sorted_), cursor_(other.cursor_) {}
+
+Node &GraphTraits::NodesTSIterator::operator*() {
+ PADDLE_ENFORCE_LT(cursor_, sorted_.size());
+ return *sorted_[cursor_];
+}
+
+paddle::inference::analysis::GraphTraits::NodesTSIterator
+ &GraphTraits::NodesTSIterator::operator++() {
+ if (++cursor_ >= sorted_.size()) {
+ sorted_.clear();
+ cursor_ = 0;
+ }
+ return *this;
+}
+paddle::inference::analysis::GraphTraits::NodesTSIterator &
+GraphTraits::NodesTSIterator::operator=(
+ const paddle::inference::analysis::GraphTraits<
+ DataFlowGraph>::NodesTSIterator &other) {
+ cursor_ = other.cursor_;
+ sorted_ = other.sorted_;
+ return *this;
+}
+
+bool GraphTraits::NodesTSIterator::operator==(
+ const paddle::inference::analysis::GraphTraits<
+ DataFlowGraph>::NodesTSIterator &other) {
+ return sorted_ == other.sorted_ && cursor_ == other.cursor_;
+}
+
+Node *GraphTraits::NodesTSIterator::operator->() {
+ PADDLE_ENFORCE_LT(cursor_, sorted_.size());
+ return sorted_[cursor_];
+}
+
} // namespace analysis
} // namespace inference
} // namespace paddle
diff --git a/paddle/fluid/inference/analysis/data_flow_graph.h b/paddle/fluid/inference/analysis/data_flow_graph.h
index a4fefc83e0c551d52bec87299bcbc966e7a2dbf7..5dd914d1971bfb5bcc0b1db41d73e2b67120bc06 100644
--- a/paddle/fluid/inference/analysis/data_flow_graph.h
+++ b/paddle/fluid/inference/analysis/data_flow_graph.h
@@ -48,6 +48,9 @@ struct DataFlowGraph {
// Output a DOT graph file for debug.
std::string DotString() const;
+ std::string HumanReadableInfo(bool show_values = true,
+ bool show_functions = true) const;
+
private:
// Remove duplicate edges and so on.
void Clean();
@@ -107,6 +110,32 @@ struct GraphTraits {
std::unordered_set visited_;
};
+ // Topological sorting iterator on nodes.
+ struct NodesTSIterator
+ : public std::iterator {
+ NodesTSIterator() = default;
+ explicit NodesTSIterator(const std::vector &source);
+ NodesTSIterator(NodesTSIterator &&other)
+ : sorted_(std::move(other.sorted_)), cursor_(other.cursor_) {
+ other.cursor_ = 0;
+ }
+ NodesTSIterator(const NodesTSIterator &other);
+
+ Node &operator*();
+ NodesTSIterator &operator++();
+ // TODO(Superjomn) current implementation just compare the first
+ // element, need to compare the graph and all the elements in the queue and
+ // set.
+ NodesTSIterator &operator=(const NodesTSIterator &other);
+ bool operator==(const NodesTSIterator &other);
+ bool operator!=(const NodesTSIterator &other) { return !(*this == other); }
+ Node *operator->();
+
+ private:
+ std::vector sorted_;
+ int cursor_{0};
+ };
+
explicit GraphTraits(DataFlowGraph *graph) : graph_(graph) {}
// default use BFS to visit the nodes.
@@ -119,17 +148,24 @@ struct GraphTraits {
iterator_range nodes_in_DFS() {
return iterator_range(nodes_dfs_begin(), nodes_dfs_end());
}
+ iterator_range nodes_in_TS() {
+ return iterator_range(nodes_ts_begin(), nodes_ts_end());
+ }
private:
NodesBFSIterator nodes_bfs_begin() {
return NodesBFSIterator(graph_->inputs);
}
NodesBFSIterator nodes_bfs_end() { return NodesBFSIterator(); }
+
NodesDFSIterator nodes_dfs_begin() {
return NodesDFSIterator(graph_->inputs);
}
NodesDFSIterator nodes_dfs_end() { return NodesDFSIterator(); }
+ NodesTSIterator nodes_ts_begin() { return NodesTSIterator(graph_->inputs); }
+ NodesTSIterator nodes_ts_end() { return NodesTSIterator(); }
+
private:
DataFlowGraph *graph_;
};
diff --git a/paddle/fluid/inference/analysis/data_flow_graph_tester.cc b/paddle/fluid/inference/analysis/data_flow_graph_tester.cc
index 9d7cceeb65888b8ba3fdf39e88fc2877abd82d11..7912f8d7f17ae3c79e8f73f36b7095fd52c9ac86 100644
--- a/paddle/fluid/inference/analysis/data_flow_graph_tester.cc
+++ b/paddle/fluid/inference/analysis/data_flow_graph_tester.cc
@@ -24,11 +24,11 @@ TEST(DataFlowGraph, BFS) {
auto dfg = ProgramDescToDFG(desc);
dfg.Build();
- for (auto* in : dfg.inputs) {
+ for (auto *in : dfg.inputs) {
LOG(INFO) << "inputs: " << in->name() << " "
<< static_cast(in->type());
}
- for (auto* out : dfg.outputs) {
+ for (auto *out : dfg.outputs) {
LOG(INFO) << "outputs: " << out->name() << " "
<< static_cast(out->type());
}
@@ -57,6 +57,71 @@ TEST(DataFlowGraph, DFS) {
ASSERT_EQ(count, dfg.nodes.size());
}
+// Topological sorting.
+/*
+ * Graph topology
+ * inputs: 0, 1, 2
+ * 0 -> 4
+ * 0 -> 5
+ * 1 -> 6
+ * 2 -> 7
+ * 4 -> 5
+ * 4 -> 7
+ * 4 -> 3
+ * 7 -> 3
+ */
+TEST(DataFlowGraph, TS) {
+ DataFlowGraph graph;
+
+ for (int i = 0; i < 8; i++) {
+ auto *node = graph.nodes.Create(Node::Type::kValue);
+ node->SetName("node-" + std::to_string(i));
+ }
+
+ auto add_link = [&](int i, int j) {
+ Node *source = graph.nodes.GetMutable(i);
+ Node *target = graph.nodes.GetMutable(j);
+ target->inlinks.push_back(source);
+ source->outlinks.push_back(target);
+ };
+
+ graph.inputs.push_back(graph.nodes.GetMutable(0));
+ graph.inputs.push_back(graph.nodes.GetMutable(1));
+ graph.inputs.push_back(graph.nodes.GetMutable(2));
+
+ add_link(0, 4);
+ add_link(0, 5);
+ add_link(1, 6);
+ add_link(2, 7);
+ add_link(4, 5);
+ add_link(4, 7);
+ add_link(4, 3);
+ add_link(7, 3);
+
+ auto its = GraphTraits(&graph).nodes_in_TS();
+ std::vector sorted_ids;
+ for (auto it = its.begin(); it != its.end(); ++it) {
+ LOG(INFO) << it->name();
+ sorted_ids.push_back(it->id());
+ }
+
+ // Assert a occurs prior to b in the sorted_ids.
+ auto assert_positive_sequence_pair = [&](int a, int b) {
+ auto a_offset = std::find(sorted_ids.begin(), sorted_ids.end(), a);
+ auto b_offset = std::find(sorted_ids.begin(), sorted_ids.end(), b);
+ ASSERT_LT(a_offset, b_offset);
+ };
+
+ assert_positive_sequence_pair(2, 7);
+ assert_positive_sequence_pair(7, 3);
+ assert_positive_sequence_pair(4, 3);
+ assert_positive_sequence_pair(0, 4);
+ assert_positive_sequence_pair(0, 5);
+ assert_positive_sequence_pair(1, 6);
+ assert_positive_sequence_pair(4, 5);
+ assert_positive_sequence_pair(4, 7);
+}
+
} // namespace analysis
} // namespace inference
} // namespace paddle
diff --git a/paddle/fluid/inference/paddle_fluid.sym b/paddle/fluid/inference/paddle_fluid.sym
new file mode 100644
index 0000000000000000000000000000000000000000..ef2a04d788aa86b7f6a61c4af479d70d1137f374
--- /dev/null
+++ b/paddle/fluid/inference/paddle_fluid.sym
@@ -0,0 +1 @@
+*paddle*
diff --git a/paddle/fluid/operators/CMakeLists.txt b/paddle/fluid/operators/CMakeLists.txt
index ab1d2143330fb8cbfd535758a83bc71de939c4e0..d265150f25419509126028e36e629aee3ee6bd0f 100644
--- a/paddle/fluid/operators/CMakeLists.txt
+++ b/paddle/fluid/operators/CMakeLists.txt
@@ -259,12 +259,15 @@ op_library(max_sequence_len_op DEPS lod_rank_table)
op_library(sequence_conv_op DEPS context_project)
op_library(sequence_pool_op DEPS sequence_pooling)
op_library(lstm_op DEPS sequence2batch lstm_compute)
+op_library(hierarchical_sigmoid_op DEPS matrix_bit_code)
op_library(lstmp_op DEPS sequence2batch lstm_compute)
op_library(gru_op DEPS sequence2batch gru_compute)
op_library(recurrent_op DEPS executor)
op_library(warpctc_op DEPS dynload_warpctc sequence_padding sequence_scale)
op_library(cos_sim_op DEPS cos_sim_functor)
op_library(parallel_do_op DEPS executor)
+op_library(unsqueeze_op DEPS reshape_op)
+op_library(squeeze_op DEPS reshape_op)
if (WITH_GPU)
op_library(conv_op DEPS vol2col depthwise_conv im2col)
diff --git a/paddle/fluid/operators/conv_mkldnn_op.cc b/paddle/fluid/operators/conv_mkldnn_op.cc
index 6b06913d1c83f4534238ac3dd22ac4035c0f0fbf..5bfa1aaa696d5cbe8bdcb94d708746259952740f 100644
--- a/paddle/fluid/operators/conv_mkldnn_op.cc
+++ b/paddle/fluid/operators/conv_mkldnn_op.cc
@@ -29,6 +29,79 @@ using mkldnn::stream;
using platform::to_void_cast;
using platform::GetMKLDNNFormat;
+class ConvMKLDNNHandler : public platform::MKLDNNHandler {
+ public:
+ ConvMKLDNNHandler(
+ std::shared_ptr conv_pd,
+ const platform::MKLDNNDeviceContext& dev_ctx, mkldnn::engine engine,
+ const std::string& base_key)
+ : platform::MKLDNNHandler(dev_ctx, engine, base_key) {
+ conv_pd_ = conv_pd;
+ }
+
+ std::shared_ptr AcquireDstMemoryFromPrimitive(void* ptr) {
+ return this->AcquireMemoryFromPrimitive(conv_pd_->dst_primitive_desc(), ptr,
+ "@dst_mem_p");
+ }
+
+ std::shared_ptr AcquireSrcMemoryFromPrimitive(
+ const std::shared_ptr user_memory_p,
+ std::vector& pipeline) {
+ auto src_pd = conv_pd_->src_primitive_desc();
+ auto user_pd = user_memory_p->get_primitive_desc();
+ return this->AcquireMemory(src_pd, user_pd, user_memory_p, "@src_mem_p",
+ pipeline);
+ }
+
+ std::shared_ptr AcquireWeightsMemoryFromPrimitive(
+ const std::shared_ptr user_weights_memory_p,
+ std::vector& pipeline) {
+ auto user_weights_pd = user_weights_memory_p->get_primitive_desc();
+ auto weights_pd = conv_pd_->weights_primitive_desc();
+ return this->AcquireMemory(weights_pd, user_weights_pd,
+ user_weights_memory_p, "@weights_mem_p",
+ pipeline);
+ }
+
+ std::shared_ptr AcquireConvolution(
+ std::shared_ptr src_memory_p,
+ std::shared_ptr weights_memory_p,
+ std::shared_ptr dst_memory_p) {
+ auto prim_key = key_ + "@conv_p";
+ auto prim_desc_key = key_ + "@conv_pd";
+ auto conv_p = std::static_pointer_cast(
+ dev_ctx_.GetBlob(prim_key));
+ PADDLE_ENFORCE((conv_p != nullptr) || (is_reusing_ == false),
+ "Fail to find convolution primitive in device context");
+ if (conv_p == nullptr) {
+ conv_p = std::make_shared(
+ *conv_pd_, *(src_memory_p), *(weights_memory_p.get()),
+ *(dst_memory_p.get()));
+
+ dev_ctx_.SetBlob(prim_key, conv_p);
+ } else {
+ is_reusing_ = true;
+ }
+ return conv_p;
+ }
+
+ // Generate keys for storing/retriving primitives for this operator
+ // TODO(jczaja): Make hashing function more optimial
+ static std::string GetHash(memory::dims& input_dims,
+ memory::dims& weights_dims,
+ std::vector& strides,
+ std::vector& paddings,
+ std::vector& dilations, int groups,
+ const std::string& suffix) {
+ return dims2str(input_dims) + dims2str(weights_dims) + dims2str(strides) +
+ dims2str(paddings) + dims2str(dilations) + std::to_string(groups) +
+ suffix;
+ }
+
+ private:
+ std::shared_ptr conv_pd_;
+};
+
template
class ConvMKLDNNOpKernel : public paddle::framework::OpKernel {
public:
@@ -36,10 +109,6 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel {
PADDLE_ENFORCE(paddle::platform::is_cpu_place(ctx.GetPlace()),
"It must use CPUPlace.");
- // Get unique name for index
- const std::string key = ctx.op().Output("Output");
- const std::string key_conv_pd = key + "@conv_pd";
-
auto& dev_ctx =
ctx.template device_context();
const auto& mkldnn_engine = dev_ctx.GetEngine();
@@ -80,68 +149,62 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel {
paddle::framework::vectorize2int(filter->dims());
std::vector dst_tz = paddle::framework::vectorize2int(output->dims());
- // create mkldnn memory from input tensors (data/weights)
- auto user_src_memory = memory(
- {{{src_tz}, memory::data_type::f32, input->format()}, mkldnn_engine},
- to_void_cast(input_data));
- auto user_weights_memory =
- memory({{{weights_tz}, memory::data_type::f32, filter->format()},
- mkldnn_engine},
- to_void_cast(filter_data));
+ // Get unique name for storing MKLDNN primitives
+ const std::string key = ConvMKLDNNHandler::GetHash(
+ src_tz, weights_tz, strides, paddings, dilations, groups,
+ ctx.op().Output("Output"));
+ const std::string key_conv_pd = key + "@conv_pd";
+
+ std::vector pipeline;
+
+ auto user_src_md = platform::MKLDNNMemDesc(
+ {src_tz}, platform::MKLDNNGetDataType(), input->format());
+ auto user_weights_md = platform::MKLDNNMemDesc(
+ {weights_tz}, platform::MKLDNNGetDataType(), filter->format());
/* create memory descriptor for convolution without specified format
* ('any') which lets a primitive (convolution in this case) choose
* the memory format preferred for best performance
*/
- auto src_md = platform::MKLDNNMemDesc(src_tz, memory::data_type::f32,
- memory::format::any);
+ auto src_md = platform::MKLDNNMemDesc(
+ src_tz, platform::MKLDNNGetDataType(), memory::format::any);
auto weights_md = platform::MKLDNNMemDesc(
- weights_tz, memory::data_type::f32, memory::format::any);
- auto dst_md = platform::MKLDNNMemDesc(dst_tz, memory::data_type::f32,
- memory::format::any);
+ weights_tz, platform::MKLDNNGetDataType(), memory::format::any);
+ auto dst_md = platform::MKLDNNMemDesc(
+ dst_tz, platform::MKLDNNGetDataType(), memory::format::any);
// create a conv primitive descriptor and save it for usage in backward
std::shared_ptr conv_pd = ConvFwdPrimitiveDesc(
src_md, weights_md, dst_md, strides, paddings, mkldnn_engine);
+ // Save conv_pd/src_memory/weights_memory for backward pass
+ dev_ctx.SetBlob(key_conv_pd, conv_pd);
- // create reorder primitive if the input format is not the preferred one
- auto src_memory = user_src_memory;
- primitive reorder_src;
- bool is_src_reordered = false;
- if (memory::primitive_desc(conv_pd->src_primitive_desc()) !=
- user_src_memory.get_primitive_desc()) {
- src_memory = memory(conv_pd->src_primitive_desc());
- reorder_src = reorder(user_src_memory, src_memory);
- is_src_reordered = true;
- }
- auto weights_memory = user_weights_memory;
- primitive reorder_weights;
- bool is_weights_reordered = false;
- if (memory::primitive_desc(conv_pd->weights_primitive_desc()) !=
- user_weights_memory.get_primitive_desc()) {
- weights_memory = memory(conv_pd->weights_primitive_desc());
- reorder_weights = reorder(user_weights_memory, weights_memory);
- is_weights_reordered = true;
- }
+ ConvMKLDNNHandler handler(conv_pd, dev_ctx, mkldnn_engine, key);
- // create memory primitive for conv dst
- auto dst_memory = memory(conv_pd->dst_primitive_desc(), output_data);
+ // create mkldnn memory from input tensors (data/weights)
+ auto user_src_memory_p =
+ handler.AcquireSrcMemory(user_src_md, to_void_cast(input_data));
+ auto user_weights_memory_p = handler.AcquireWeightsMemory(
+ user_weights_md, to_void_cast(filter_data));
+
+ // create reorder primitive if the input format is not the preferred one
+ auto src_memory_p =
+ handler.AcquireSrcMemoryFromPrimitive(user_src_memory_p, pipeline);
+ auto weights_memory_p = handler.AcquireWeightsMemoryFromPrimitive(
+ user_weights_memory_p, pipeline);
+ auto dst_memory_p =
+ handler.AcquireDstMemoryFromPrimitive(to_void_cast(output_data));
// create convolution op primitive
- auto conv_prim = conv_fwd(*conv_pd, src_memory, weights_memory, dst_memory);
+ auto conv_p = handler.AcquireConvolution(src_memory_p, weights_memory_p,
+ dst_memory_p);
// push primitive to stream and wait until it's executed
- std::vector pipeline;
- if (is_src_reordered) pipeline.push_back(reorder_src);
- if (is_weights_reordered) pipeline.push_back(reorder_weights);
- pipeline.push_back(conv_prim);
+ pipeline.push_back(*conv_p);
stream(stream::kind::eager).submit(pipeline).wait();
- // Save conv_pd/src_memory/weights_memory for backward pass
- dev_ctx.SetBlob(key_conv_pd, conv_pd);
-
output->set_layout(DataLayout::kMKLDNN);
- output->set_format(GetMKLDNNFormat(dst_memory));
+ output->set_format(GetMKLDNNFormat(*dst_memory_p));
}
private:
@@ -197,13 +260,10 @@ class ConvMKLDNNGradOpKernel : public paddle::framework::OpKernel {
if (!input_grad && !filter_grad) return;
- // Get an unique name from "argument" name of "Output" variable
- // This name will be used as key when saving info into device context
- const std::string key = ctx.op().Input("Output");
- const std::string key_conv_pd = key + "@conv_pd";
-
std::vector strides = ctx.Attr>("strides");
std::vector paddings = ctx.Attr>("paddings");
+ std::vector dilations = ctx.Attr>("dilations");
+ int groups = ctx.Attr("groups");
const T* input_data = input->data();
const T* filter_data = filter->data();
@@ -223,6 +283,14 @@ class ConvMKLDNNGradOpKernel : public paddle::framework::OpKernel {
paddle::framework::vectorize2int(filter->dims());
std::vector dst_tz = paddle::framework::vectorize2int(output->dims());
+ // Get an unique name from "argument" name of "Output" variable
+ // This name will be used as key when saving info into device context
+ const std::string key =
+ ConvMKLDNNHandler::GetHash(src_tz, weights_tz, strides, paddings,
+ dilations, groups, ctx.op().Input("Output"));
+
+ const std::string key_conv_pd = key + "@conv_pd";
+
// create mkldnn memory from input tensors (input/weights/output_grad)
auto user_src_memory = memory(
{{{src_tz}, memory::data_type::f32, input->format()}, mkldnn_engine},
diff --git a/paddle/fluid/operators/detection/CMakeLists.txt b/paddle/fluid/operators/detection/CMakeLists.txt
index 6d296ff7bf14de9175dc589dfa8b46c534127ca1..a44d84cd7b99107fef09a6b4dfa60172fabd718b 100644
--- a/paddle/fluid/operators/detection/CMakeLists.txt
+++ b/paddle/fluid/operators/detection/CMakeLists.txt
@@ -27,7 +27,8 @@ anchor_generator_op.cu)
detection_library(target_assign_op SRCS target_assign_op.cc
target_assign_op.cu)
detection_library(polygon_box_transform_op SRCS polygon_box_transform_op.cc
- polygon_box_transform_op.cu)
+polygon_box_transform_op.cu)
+detection_library(rpn_target_assign_op SRCS rpn_target_assign_op.cc)
# Export local libraries to parent
set(DETECTION_LIBRARY ${LOCAL_DETECTION_LIBS} PARENT_SCOPE)
diff --git a/paddle/fluid/operators/detection/prior_box_op.cc b/paddle/fluid/operators/detection/prior_box_op.cc
index 4e35c38e4e03d4d0f00601812fdc4803519b89ae..b5cb6a724c095eb849f3a184f13843e1a0cca92f 100644
--- a/paddle/fluid/operators/detection/prior_box_op.cc
+++ b/paddle/fluid/operators/detection/prior_box_op.cc
@@ -149,6 +149,13 @@ class PriorBoxOpMaker : public framework::OpProtoAndCheckerMaker {
"(float) "
"Prior boxes center offset.")
.SetDefault(0.5);
+ AddAttr(
+ "min_max_aspect_ratios_order",
+ "(bool) If set True, the output prior box is in order of"
+ "[min, max, aspect_ratios], which is consistent with Caffe."
+ "Please note, this order affects the weights order of convolution layer"
+ "followed by and does not affect the final detection results.")
+ .SetDefault(false);
AddComment(R"DOC(
Prior box operator
Generate prior boxes for SSD(Single Shot MultiBox Detector) algorithm.
diff --git a/paddle/fluid/operators/detection/prior_box_op.cu b/paddle/fluid/operators/detection/prior_box_op.cu
index f67e6ca91c0852b5a3be35d23246884d1157caa4..1ea8cfc1d2af8cc6c332768a467cdcd4c0166319 100644
--- a/paddle/fluid/operators/detection/prior_box_op.cu
+++ b/paddle/fluid/operators/detection/prior_box_op.cu
@@ -28,8 +28,8 @@ __global__ void GenPriorBox(T* out, const T* aspect_ratios, const int height,
const int im_width, const int as_num,
const T offset, const T step_width,
const T step_height, const T* min_sizes,
- const T* max_sizes, const int min_num,
- bool is_clip) {
+ const T* max_sizes, const int min_num, bool is_clip,
+ bool min_max_aspect_ratios_order) {
int num_priors = max_sizes ? as_num * min_num + min_num : as_num * min_num;
int box_num = height * width * num_priors;
for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < box_num;
@@ -44,14 +44,28 @@ __global__ void GenPriorBox(T* out, const T* aspect_ratios, const int height,
T min_size = min_sizes[m];
if (max_sizes) {
int s = p % (as_num + 1);
- if (s < as_num) {
- T ar = aspect_ratios[s];
- bw = min_size * sqrt(ar) / 2.;
- bh = min_size / sqrt(ar) / 2.;
+ if (!min_max_aspect_ratios_order) {
+ if (s < as_num) {
+ T ar = aspect_ratios[s];
+ bw = min_size * sqrt(ar) / 2.;
+ bh = min_size / sqrt(ar) / 2.;
+ } else {
+ T max_size = max_sizes[m];
+ bw = sqrt(min_size * max_size) / 2.;
+ bh = bw;
+ }
} else {
- T max_size = max_sizes[m];
- bw = sqrt(min_size * max_size) / 2.;
- bh = bw;
+ if (s == 0) {
+ bw = bh = min_size / 2.;
+ } else if (s == 1) {
+ T max_size = max_sizes[m];
+ bw = sqrt(min_size * max_size) / 2.;
+ bh = bw;
+ } else {
+ T ar = aspect_ratios[s - 1];
+ bw = min_size * sqrt(ar) / 2.;
+ bh = min_size / sqrt(ar) / 2.;
+ }
}
} else {
int s = p % as_num;
@@ -94,6 +108,8 @@ class PriorBoxOpCUDAKernel : public framework::OpKernel {
auto variances = ctx.Attr>("variances");
auto flip = ctx.Attr("flip");
auto clip = ctx.Attr("clip");
+ auto min_max_aspect_ratios_order =
+ ctx.Attr("min_max_aspect_ratios_order");
std::vector aspect_ratios;
ExpandAspectRatios(input_aspect_ratio, flip, &aspect_ratios);
@@ -149,7 +165,7 @@ class PriorBoxOpCUDAKernel : public framework::OpKernel {
GenPriorBox<<>>(
boxes->data(), r.data(), height, width, im_height, im_width,
aspect_ratios.size(), offset, step_width, step_height, min.data(),
- max_data, min_num, clip);
+ max_data, min_num, clip, min_max_aspect_ratios_order);
framework::Tensor v;
framework::TensorFromVector(variances, ctx.device_context(), &v);
diff --git a/paddle/fluid/operators/detection/prior_box_op.h b/paddle/fluid/operators/detection/prior_box_op.h
index 1c62fd8d2c4d4e4deba4ca6442efbaff83e36c35..4e226abbb51c271502f0ca5419d488643b5a1a82 100644
--- a/paddle/fluid/operators/detection/prior_box_op.h
+++ b/paddle/fluid/operators/detection/prior_box_op.h
@@ -68,6 +68,8 @@ class PriorBoxOpKernel : public framework::OpKernel {
auto variances = ctx.Attr>("variances");
auto flip = ctx.Attr("flip");
auto clip = ctx.Attr("clip");
+ auto min_max_aspect_ratios_order =
+ ctx.Attr("min_max_aspect_ratios_order");
std::vector aspect_ratios;
ExpandAspectRatios(input_aspect_ratio, flip, &aspect_ratios);
@@ -108,26 +110,59 @@ class PriorBoxOpKernel : public framework::OpKernel {
int idx = 0;
for (size_t s = 0; s < min_sizes.size(); ++s) {
auto min_size = min_sizes[s];
- // priors with different aspect ratios
- for (size_t r = 0; r < aspect_ratios.size(); ++r) {
- float ar = aspect_ratios[r];
- box_width = min_size * sqrt(ar) / 2.;
- box_height = min_size / sqrt(ar) / 2.;
- e_boxes(h, w, idx, 0) = (center_x - box_width) / img_width;
- e_boxes(h, w, idx, 1) = (center_y - box_height) / img_height;
- e_boxes(h, w, idx, 2) = (center_x + box_width) / img_width;
- e_boxes(h, w, idx, 3) = (center_y + box_height) / img_height;
- idx++;
- }
- if (max_sizes.size() > 0) {
- auto max_size = max_sizes[s];
- // square prior with size sqrt(minSize * maxSize)
- box_width = box_height = sqrt(min_size * max_size) / 2.;
+ if (min_max_aspect_ratios_order) {
+ box_width = box_height = min_size / 2.;
e_boxes(h, w, idx, 0) = (center_x - box_width) / img_width;
e_boxes(h, w, idx, 1) = (center_y - box_height) / img_height;
e_boxes(h, w, idx, 2) = (center_x + box_width) / img_width;
e_boxes(h, w, idx, 3) = (center_y + box_height) / img_height;
idx++;
+ if (max_sizes.size() > 0) {
+ auto max_size = max_sizes[s];
+ // square prior with size sqrt(minSize * maxSize)
+ box_width = box_height = sqrt(min_size * max_size) / 2.;
+ e_boxes(h, w, idx, 0) = (center_x - box_width) / img_width;
+ e_boxes(h, w, idx, 1) = (center_y - box_height) / img_height;
+ e_boxes(h, w, idx, 2) = (center_x + box_width) / img_width;
+ e_boxes(h, w, idx, 3) = (center_y + box_height) / img_height;
+ idx++;
+ }
+ // priors with different aspect ratios
+ for (size_t r = 0; r < aspect_ratios.size(); ++r) {
+ float ar = aspect_ratios[r];
+ if (fabs(ar - 1.) < 1e-6) {
+ continue;
+ }
+ box_width = min_size * sqrt(ar) / 2.;
+ box_height = min_size / sqrt(ar) / 2.;
+ e_boxes(h, w, idx, 0) = (center_x - box_width) / img_width;
+ e_boxes(h, w, idx, 1) = (center_y - box_height) / img_height;
+ e_boxes(h, w, idx, 2) = (center_x + box_width) / img_width;
+ e_boxes(h, w, idx, 3) = (center_y + box_height) / img_height;
+ idx++;
+ }
+ } else {
+ // priors with different aspect ratios
+ for (size_t r = 0; r < aspect_ratios.size(); ++r) {
+ float ar = aspect_ratios[r];
+ box_width = min_size * sqrt(ar) / 2.;
+ box_height = min_size / sqrt(ar) / 2.;
+ e_boxes(h, w, idx, 0) = (center_x - box_width) / img_width;
+ e_boxes(h, w, idx, 1) = (center_y - box_height) / img_height;
+ e_boxes(h, w, idx, 2) = (center_x + box_width) / img_width;
+ e_boxes(h, w, idx, 3) = (center_y + box_height) / img_height;
+ idx++;
+ }
+ if (max_sizes.size() > 0) {
+ auto max_size = max_sizes[s];
+ // square prior with size sqrt(minSize * maxSize)
+ box_width = box_height = sqrt(min_size * max_size) / 2.;
+ e_boxes(h, w, idx, 0) = (center_x - box_width) / img_width;
+ e_boxes(h, w, idx, 1) = (center_y - box_height) / img_height;
+ e_boxes(h, w, idx, 2) = (center_x + box_width) / img_width;
+ e_boxes(h, w, idx, 3) = (center_y + box_height) / img_height;
+ idx++;
+ }
}
}
}
diff --git a/paddle/fluid/operators/detection/rpn_target_assign_op.cc b/paddle/fluid/operators/detection/rpn_target_assign_op.cc
new file mode 100644
index 0000000000000000000000000000000000000000..9a1643d5b35c067ba9064286bab32019fb34fbe8
--- /dev/null
+++ b/paddle/fluid/operators/detection/rpn_target_assign_op.cc
@@ -0,0 +1,283 @@
+/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#include
+#include "paddle/fluid/framework/op_registry.h"
+#include "paddle/fluid/operators/math/math_function.h"
+
+namespace paddle {
+namespace operators {
+
+using Tensor = framework::Tensor;
+using LoDTensor = framework::LoDTensor;
+template
+using EigenMatrix = framework::EigenMatrix;
+
+class RpnTargetAssignOp : public framework::OperatorWithKernel {
+ public:
+ using framework::OperatorWithKernel::OperatorWithKernel;
+
+ void InferShape(framework::InferShapeContext* ctx) const override {
+ PADDLE_ENFORCE(ctx->HasInput("DistMat"),
+ "Input(DistMat) of RpnTargetAssignOp should not be null");
+
+ PADDLE_ENFORCE(
+ ctx->HasOutput("LocationIndex"),
+ "Output(LocationIndex) of RpnTargetAssignOp should not be null");
+ PADDLE_ENFORCE(
+ ctx->HasOutput("ScoreIndex"),
+ "Output(ScoreIndex) of RpnTargetAssignOp should not be null");
+ PADDLE_ENFORCE(
+ ctx->HasOutput("TargetLabel"),
+ "Output(TargetLabel) of RpnTargetAssignOp should not be null");
+
+ auto in_dims = ctx->GetInputDim("DistMat");
+ PADDLE_ENFORCE_EQ(in_dims.size(), 2,
+ "The rank of Input(DistMat) must be 2.");
+ }
+};
+
+template
+class RpnTargetAssignKernel : public framework::OpKernel {
+ public:
+ void ScoreAssign(const T* dist_data, const Tensor& anchor_to_gt_max,
+ const int row, const int col, const float pos_threshold,
+ const float neg_threshold, int64_t* target_label_data,
+ std::vector* fg_inds, std::vector* bg_inds) const {
+ int fg_offset = fg_inds->size();
+ int bg_offset = bg_inds->size();
+ for (int64_t i = 0; i < row; ++i) {
+ const T* v = dist_data + i * col;
+ T max_dist = *std::max_element(v, v + col);
+ for (int64_t j = 0; j < col; ++j) {
+ T val = dist_data[i * col + j];
+ if (val == max_dist) target_label_data[j] = 1;
+ }
+ }
+
+ // Pick the fg/bg and count the number
+ for (int64_t j = 0; j < col; ++j) {
+ if (anchor_to_gt_max.data()[j] > pos_threshold) {
+ target_label_data[j] = 1;
+ } else if (anchor_to_gt_max.data()[j] < neg_threshold) {
+ target_label_data[j] = 0;
+ }
+ if (target_label_data[j] == 1) {
+ fg_inds->push_back(fg_offset + j);
+ } else if (target_label_data[j] == 0) {
+ bg_inds->push_back(bg_offset + j);
+ }
+ }
+ }
+
+ void ReservoirSampling(const int num, const int offset,
+ std::minstd_rand engine,
+ std::vector* inds) const {
+ std::uniform_real_distribution uniform(0, 1);
+ const int64_t size = static_cast(inds->size());
+ if (size > num) {
+ for (int64_t i = num; i < size; ++i) {
+ int rng_ind = std::floor(uniform(engine) * i);
+ if (rng_ind < num)
+ std::iter_swap(inds->begin() + rng_ind + offset,
+ inds->begin() + i + offset);
+ }
+ }
+ }
+
+ void RpnTargetAssign(const framework::ExecutionContext& ctx,
+ const Tensor& dist, const float pos_threshold,
+ const float neg_threshold, const int rpn_batch_size,
+ const int fg_num, std::minstd_rand engine,
+ std::vector* fg_inds, std::vector* bg_inds,
+ int64_t* target_label_data) const {
+ auto* dist_data = dist.data();
+ int64_t row = dist.dims()[0];
+ int64_t col = dist.dims()[1];
+ int fg_offset = fg_inds->size();
+ int bg_offset = bg_inds->size();
+
+ // Calculate the max IoU between anchors and gt boxes
+ Tensor anchor_to_gt_max;
+ anchor_to_gt_max.mutable_data(
+ framework::make_ddim({static_cast(col), 1}),
+ platform::CPUPlace());
+ auto& place = *ctx.template device_context()
+ .eigen_device();
+ auto x = EigenMatrix::From(dist);
+ auto x_col_max = EigenMatrix::From(anchor_to_gt_max);
+ x_col_max.device(place) =
+ x.maximum(Eigen::DSizes(0))
+ .reshape(Eigen::DSizes(static_cast(col), 1));
+ // Follow the Faster RCNN's implementation
+ ScoreAssign(dist_data, anchor_to_gt_max, row, col, pos_threshold,
+ neg_threshold, target_label_data, fg_inds, bg_inds);
+ // Reservoir Sampling
+ ReservoirSampling(fg_num, fg_offset, engine, fg_inds);
+ int bg_num = rpn_batch_size - fg_inds->size();
+ ReservoirSampling(bg_num, bg_offset, engine, bg_inds);
+ }
+
+ void Compute(const framework::ExecutionContext& context) const override {
+ auto* dist = context.Input("DistMat");
+ auto* loc_index = context.Output("LocationIndex");
+ auto* score_index = context.Output("ScoreIndex");
+ auto* tgt_lbl = context.Output("TargetLabel");
+
+ auto col = dist->dims()[1];
+ int64_t n = dist->lod().size() == 0UL
+ ? 1
+ : static_cast(dist->lod().back().size() - 1);
+ if (dist->lod().size()) {
+ PADDLE_ENFORCE_EQ(dist->lod().size(), 1UL,
+ "Only support 1 level of LoD.");
+ }
+ int rpn_batch_size = context.Attr("rpn_batch_size_per_im");
+ float pos_threshold = context.Attr("rpn_positive_overlap");
+ float neg_threshold = context.Attr("rpn_negative_overlap");
+ float fg_fraction = context.Attr("fg_fraction");
+
+ int fg_num = static_cast(rpn_batch_size * fg_fraction);
+
+ int64_t* target_label_data =
+ tgt_lbl->mutable_data({n * col, 1}, context.GetPlace());
+
+ auto& dev_ctx = context.device_context();
+ math::SetConstant iset;
+ iset(dev_ctx, tgt_lbl, static_cast(-1));
+
+ std::vector fg_inds;
+ std::vector bg_inds;
+ std::random_device rnd;
+ std::minstd_rand engine;
+ int seed =
+ context.Attr("fix_seed") ? context.Attr("seed") : rnd();
+ engine.seed(seed);
+
+ if (n == 1) {
+ RpnTargetAssign(context, *dist, pos_threshold, neg_threshold,
+ rpn_batch_size, fg_num, engine, &fg_inds, &bg_inds,
+ target_label_data);
+ } else {
+ auto lod = dist->lod().back();
+ for (size_t i = 0; i < lod.size() - 1; ++i) {
+ Tensor one_ins = dist->Slice(lod[i], lod[i + 1]);
+ RpnTargetAssign(context, one_ins, pos_threshold, neg_threshold,
+ rpn_batch_size, fg_num, engine, &fg_inds, &bg_inds,
+ target_label_data + i * col);
+ }
+ }
+ int* loc_index_data = loc_index->mutable_data(
+ {static_cast(fg_inds.size())}, context.GetPlace());
+ int* score_index_data = score_index->mutable_data(
+ {static_cast(fg_inds.size() + bg_inds.size())},
+ context.GetPlace());
+ memcpy(loc_index_data, reinterpret_cast(&fg_inds[0]),
+ fg_inds.size() * sizeof(int));
+ memcpy(score_index_data, reinterpret_cast(&fg_inds[0]),
+ fg_inds.size() * sizeof(int));
+ memcpy(score_index_data + fg_inds.size(),
+ reinterpret_cast(&bg_inds[0]), bg_inds.size() * sizeof(int));
+ }
+};
+
+class RpnTargetAssignOpMaker : public framework::OpProtoAndCheckerMaker {
+ public:
+ void Make() override {
+ AddInput(
+ "DistMat",
+ "(LoDTensor or Tensor) this input is a 2-D LoDTensor with shape "
+ "[K, M]. It is pair-wise distance matrix between the entities "
+ "represented by each row and each column. For example, assumed one "
+ "entity is A with shape [K], another entity is B with shape [M]. The "
+ "DistMat[i][j] is the distance between A[i] and B[j]. The bigger "
+ "the distance is, the better macthing the pairs are. Please note, "
+ "This tensor can contain LoD information to represent a batch of "
+ "inputs. One instance of this batch can contain different numbers of "
+ "entities.");
+ AddAttr(
+ "rpn_positive_overlap",
+ "Minimum overlap required between an anchor and ground-truth "
+ "box for the (anchor, gt box) pair to be a positive example.")
+ .SetDefault(0.7);
+ AddAttr(
+ "rpn_negative_overlap",
+ "Maximum overlap allowed between an anchor and ground-truth "
+ "box for the (anchor, gt box) pair to be a negative examples.")
+ .SetDefault(0.3);
+ AddAttr(
+ "fg_fraction",
+ "Target fraction of RoI minibatch that "
+ "is labeled foreground (i.e. class > 0), 0-th class is background.")
+ .SetDefault(0.25);
+ AddAttr("rpn_batch_size_per_im",
+ "Total number of RPN examples per image.")
+ .SetDefault(256);
+ AddAttr("fix_seed",
+ "A flag indicating whether to use a fixed seed to generate "
+ "random mask. NOTE: DO NOT set this flag to true in "
+ "training. Setting this flag to true is only useful in "
+ "unittest.")
+ .SetDefault(false);
+ AddAttr("seed", "RpnTargetAssign random seed.").SetDefault(0);
+ AddOutput(
+ "LocationIndex",
+ "(Tensor), The indexes of foreground anchors in all RPN anchors, the "
+ "shape of the LocationIndex is [F], F depends on the value of input "
+ "tensor and attributes.");
+ AddOutput(
+ "ScoreIndex",
+ "(Tensor), The indexes of foreground and background anchors in all "
+ "RPN anchors(The rest anchors are ignored). The shape of the "
+ "ScoreIndex is [F + B], F and B depend on the value of input "
+ "tensor and attributes.");
+ AddOutput("TargetLabel",
+ "(Tensor), The target labels of each anchor with shape "
+ "[K * M, 1], "
+ "K and M is the same as they are in DistMat.");
+ AddComment(R"DOC(
+This operator can be, for given the IoU between the ground truth bboxes and the
+anchors, to assign classification and regression targets to each prediction.
+The Score index and LocationIndex will be generated according to the DistMat.
+The rest anchors would not contibute to the RPN training loss
+
+ScoreIndex is composed of foreground anchor indexes(positive labels) and
+background anchor indexes(negative labels). LocationIndex is exactly same
+as the foreground anchor indexes since we can not assign regression target to
+the background anchors.
+
+The classification targets(TargetLabel) is a binary class label (of being
+an object or not). Following the paper of Faster-RCNN, the positive labels
+are two kinds of anchors: (i) the anchor/anchors with the highest IoU
+overlap with a ground-truth box, or (ii) an anchor that has an IoU overlap
+higher than rpn_positive_overlap(0.7) with any ground-truth box. Note that
+a single ground-truth box may assign positive labels to multiple anchors.
+A non-positive anchor is when its IoU ratio is lower than rpn_negative_overlap
+(0.3) for all ground-truth boxes. Anchors that are neither positive nor
+negative do not contribute to the training objective.
+
+)DOC");
+ }
+};
+
+} // namespace operators
+} // namespace paddle
+
+namespace ops = paddle::operators;
+REGISTER_OPERATOR(rpn_target_assign, ops::RpnTargetAssignOp,
+ ops::RpnTargetAssignOpMaker,
+ paddle::framework::EmptyGradOpMaker);
+REGISTER_OP_CPU_KERNEL(rpn_target_assign, ops::RpnTargetAssignKernel,
+ ops::RpnTargetAssignKernel);
diff --git a/paddle/fluid/operators/distributed/grpc_client.cc b/paddle/fluid/operators/distributed/grpc_client.cc
index 4a09f3870d64d8e14b2db41ff3ea7c2f9e67b558..35318a805898de645c844a2224f6df8c458d346c 100644
--- a/paddle/fluid/operators/distributed/grpc_client.cc
+++ b/paddle/fluid/operators/distributed/grpc_client.cc
@@ -59,7 +59,9 @@ GRPCClient::~GRPCClient() {
for (auto& it : channels_) {
it.second.reset();
}
+ channels_.clear();
}
+
client_thread_->join();
}
diff --git a/paddle/fluid/operators/fake_quantize_op.cc b/paddle/fluid/operators/fake_quantize_op.cc
new file mode 100644
index 0000000000000000000000000000000000000000..a91e0f520e93c01bc5af09b691af2d5a6deda9f2
--- /dev/null
+++ b/paddle/fluid/operators/fake_quantize_op.cc
@@ -0,0 +1,112 @@
+/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#include "paddle/fluid/operators/fake_quantize_op.h"
+#include
+
+namespace paddle {
+namespace operators {
+
+class FakeQuantizeOp : public framework::OperatorWithKernel {
+ public:
+ FakeQuantizeOp(const std::string &type,
+ const framework::VariableNameMap &inputs,
+ const framework::VariableNameMap &outputs,
+ const framework::AttributeMap &attrs)
+ : OperatorWithKernel(type, inputs, outputs, attrs) {}
+
+ void InferShape(framework::InferShapeContext *ctx) const override {
+ PADDLE_ENFORCE(ctx->HasInput("X"),
+ "Input(X) of FakeQuantizeOp should not be null.");
+ PADDLE_ENFORCE(ctx->HasOutput("Out"),
+ "Output(Out) of FakeQuantizeOp should not be null.");
+ PADDLE_ENFORCE(ctx->HasOutput("OutMovingScale"),
+ "OutMovingScale(Out) of FakeQuantizeOp should not be null");
+ // if (ctx->HasInput("InMovingScale")) {
+ ctx->SetOutputDim("OutMovingScale", ctx->GetInputDim("InMovingScale"));
+ //}
+ // if (ctx->HasInput("InScales")) {
+ PADDLE_ENFORCE(ctx->HasOutput("OutScales"),
+ "OutScales(Out) of FakeQuantizeOp should not be null");
+ ctx->SetOutputDim("OutScales", ctx->GetInputDim("InScales"));
+ // PADDLE_ENFORCE_EQ(ctx->Inputs("InScales")[0],
+ // ctx->Outputs("OutScales")[0],
+ // "Mean and MeanOut should share the same memory");
+ //}
+ ctx->SetOutputDim("Out", ctx->GetInputDim("X"));
+ ctx->ShareLoD("X", /*->*/ "Out");
+ }
+};
+
+class FakeQuantizeOpMaker : public framework::OpProtoAndCheckerMaker {
+ public:
+ void Make() override {
+ AddInput("X", "(Tensor) Input tensor of scale operator.");
+ AddInput("InScales", "(Tensor) scale buffer, used in static quantization.")
+ .AsDispensable();
+ AddInput("InMovingScale", "Last scale, used in static quantization.")
+ .AsDispensable();
+ AddInput("InCurrentIter",
+ "Last iteration number, used in static quantization.")
+ .AsDispensable();
+ AddOutput("Out", "(Tensor) Output of quantized low level tensor.");
+ AddOutput("OutScales",
+ "(Tensor) scale buffer, used in static quantization.")
+ .AsDispensable();
+ AddOutput("OutMovingScale", " Current scale");
+ AddOutput("OutCurrentIter", "Current iteration number.").AsDispensable();
+ AddAttr("quantize_type",
+ "(string, default abs_max)"
+ "The scaling tpe of the quantize operator.")
+ .SetDefault("abs_max");
+ AddAttr("window_size", "(int, default 10000)").SetDefault(10000);
+ AddAttr("bit_length", "(int, default 8)")
+ .SetDefault(8)
+ .AddCustomChecker([](const int &bit_length) {
+ PADDLE_ENFORCE(bit_length >= 1 && bit_length <= 16,
+ "'bit_length' should be between 1 and 16.");
+ });
+ AddAttr("is_test", "").SetDefault(false);
+ AddComment(R"DOC(
+FakeQuantize operator
+
+quantize_type = abs_max:
+
+ $$scale = max(abs(x))$$
+
+quantize_type = range_abs_max:
+
+ $$scale = max(max(abs(x)), history_abs_max)$$
+
+quantize_type = moving_average_abs_max:
+
+ $$scale = 0.1*scale+0.9*new_abs_max)$$
+
+$$Out = scale*X$$
+
+)DOC");
+ }
+};
+
+} // namespace operators
+} // namespace paddle
+
+namespace ops = paddle::operators;
+
+REGISTER_OPERATOR(fake_quantize, ops::FakeQuantizeOp, ops::FakeQuantizeOpMaker,
+ paddle::framework::EmptyGradOpMaker);
+REGISTER_OP_CPU_KERNEL(
+ fake_quantize,
+ ops::FakeQuantizeKernel,
+ ops::FakeQuantizeKernel);
diff --git a/paddle/fluid/operators/fake_quantize_op.cu b/paddle/fluid/operators/fake_quantize_op.cu
new file mode 100644
index 0000000000000000000000000000000000000000..be0c6730a5119090600a27c66510b2a095c54583
--- /dev/null
+++ b/paddle/fluid/operators/fake_quantize_op.cu
@@ -0,0 +1,272 @@
+/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#include
+#include "paddle/fluid/operators/fake_quantize_op.h"
+#include "paddle/fluid/platform/cuda_primitives.h"
+
+namespace paddle {
+namespace operators {
+
+template
+__global__ void FindAbsMaxKernel(const int n, const T* in, T* out) {
+ int bid = threadIdx.x + blockIdx.x * blockDim.x;
+ int tid = threadIdx.x;
+
+ extern __shared__ T shared_max_data[];
+ if (gridDim.x > 1) {
+ shared_max_data[tid] = T(0);
+ for (int i = bid; i < n; i += blockDim.x * gridDim.x) {
+ T tmp = fabs(in[i]);
+ if (tmp > shared_max_data[tid]) {
+ shared_max_data[tid] = tmp;
+ }
+ }
+ } else {
+ if (bid < n) {
+ shared_max_data[tid] = fabs(in[bid]);
+ } else {
+ shared_max_data[tid] = T(0);
+ }
+ }
+ __syncthreads();
+
+ for (int i = blockDim.x / 2; i > 0; i >>= 1) {
+ if (tid < i && shared_max_data[tid] < shared_max_data[tid + i]) {
+ shared_max_data[tid] = shared_max_data[tid + i];
+ }
+ __syncthreads();
+ }
+ if (tid == 0) {
+ out[blockIdx.x] = shared_max_data[0];
+ }
+}
+
+float FindAbsMaxGpu(const platform::CUDADeviceContext& ctx, const float* array,
+ int length) {
+ float host_max;
+ int kNumTheads = 1024;
+ int gridDimx = (kNumTheads - 1 + length) / kNumTheads;
+ gridDimx = (gridDimx > kNumTheads) ? kNumTheads : gridDimx;
+ framework::Tensor t;
+ float* device_max = t.mutable_data(framework::make_ddim({gridDimx}),
+ platform::CUDAPlace());
+ FindAbsMaxKernel<<>>(length, array, device_max);
+ FindAbsMaxKernel<
+ float><<<1, kNumTheads, kNumTheads * sizeof(float), ctx.stream()>>>(
+ gridDimx, device_max, device_max);
+ PADDLE_ENFORCE_EQ(
+ cudaMemcpy(&host_max, device_max, sizeof(float), cudaMemcpyDeviceToHost),
+ cudaSuccess, "cudaMemcpy failed");
+ return host_max;
+}
+
+template
+__global__ void ApplySaturateKernel(const int n, const T* in, T* out,
+ int* num_saturate, const T min,
+ const T max) {
+ int bid = threadIdx.x + blockIdx.x * blockDim.x;
+ int tid = threadIdx.x;
+
+ extern __shared__ int shared_count[];
+ shared_count[tid] = 0;
+ for (int i = bid; i < n; i += blockDim.x * gridDim.x) {
+ if (in[i] > max) {
+ out[i] = max;
+ shared_count[tid] += 1;
+ } else if (in[i] < min) {
+ out[i] = min;
+ shared_count[tid] += 1;
+ } else {
+ out[i] = in[i];
+ }
+ }
+ __syncthreads();
+
+ for (int i = blockDim.x / 2; i > 0; i >>= 1) {
+ if (tid < i) {
+ shared_count[tid] += shared_count[tid + i];
+ }
+ __syncthreads();
+ }
+ if (tid == 0) {
+ num_saturate[blockIdx.x] = shared_count[0];
+ }
+}
+
+template
+__global__ void ReduceKernel(const int n, const T* in, T* out) {
+ int tid = threadIdx.x;
+ extern __shared__ T shared_sum[];
+ if (tid < n) {
+ shared_sum[tid] = in[tid];
+ } else {
+ shared_sum[tid] = T(0);
+ }
+ __syncthreads();
+ // blockDim.x must >= n
+ for (int i = (n + 1) / 2; i > 0; i >>= 1) {
+ if (tid < i) {
+ shared_sum[tid] += shared_sum[tid + i];
+ }
+ __syncthreads();
+ }
+ if (tid == 0) {
+ out[0] = shared_sum[0];
+ }
+}
+
+template
+int ApplySaturateGpu(const platform::CUDADeviceContext& ctx, const int n,
+ const T* in, T* out, const T min, const T max) {
+ int host_num_saturate;
+ int kNumTheads = 1024;
+ int gridDimx = (n + kNumTheads - 1) / kNumTheads;
+ gridDimx = (gridDimx > kNumTheads) ? kNumTheads : gridDimx;
+ framework::Tensor t;
+ int* device_num_saturate = t.mutable_data(
+ framework::make_ddim({gridDimx}), platform::CUDAPlace());
+ ApplySaturateKernel<
+ T><<>>(
+ n, in, out, device_num_saturate, min, max);
+ ReduceKernel<<<1, kNumTheads, kNumTheads * sizeof(T), ctx.stream()>>>(
+ gridDimx, device_num_saturate, device_num_saturate);
+ PADDLE_ENFORCE_EQ(cudaSuccess,
+ cudaMemcpy(&host_num_saturate, device_num_saturate,
+ sizeof(int), cudaMemcpyDeviceToHost),
+ "cudaMemcpy failed");
+ return host_num_saturate;
+}
+
+template
+class FakeQuantizeCUDAKernel : public framework::OpKernel {
+ public:
+ T FindRangeAbsMax(const platform::CUDADeviceContext& ctx,
+ framework::Tensor* scale_list, framework::Tensor* out_scale,
+ const T& cur_scale, int window_size,
+ int current_iter) const {
+ T* sl = scale_list->mutable_data(platform::CPUPlace());
+ T remove_tmp = sl[current_iter];
+ sl[current_iter] = cur_scale;
+ T& max_scale = out_scale->mutable_data(platform::CPUPlace())[0];
+ if (max_scale < cur_scale) {
+ max_scale = cur_scale;
+ } else if (fabs(remove_tmp - max_scale) < 1e-6) {
+ int size = (current_iter > window_size) ? window_size : current_iter;
+ max_scale = T(FindAbsMaxGpu(ctx, scale_list->data(), size));
+ }
+ return max_scale;
+ }
+
+ T FindMovingAverageAbsMmax(framework::Tensor* in_scale,
+ framework::Tensor* out_scale,
+ const T& cur_scale) const {
+ T* ins = in_scale->mutable_data(platform::CPUPlace());
+ T* outs = out_scale->mutable_data(platform::CPUPlace());
+ outs[0] = 0.9 * cur_scale + 0.1 * ins[0];
+ return T(outs[0]);
+ }
+
+ virtual void Compute(const framework::ExecutionContext& context) const {
+ PADDLE_ENFORCE(platform::is_gpu_place(context.GetPlace()),
+ "This kernel only runs on GPU device.");
+ auto& device_ctx = context.cuda_device_context();
+ auto* tensor = context.Output("Out");
+ auto* in = context.Input("X");
+ const bool is_test = context.Attr("is_test");
+ tensor->mutable_data(in->place());
+ context.Output("OutMovingScale")
+ ->mutable_data(
+ context.Input("InMovingScale")->place());
+ auto quantize_type =
+ static_cast(context.Attr("quantize_type"));
+ if (quantize_type == std::string("range_abs_max")) {
+ context.Output("OutScales")
+ ->mutable_data(
+ context.Input("InScales")->place());
+ context.Output("OutCurrentIter")
+ ->mutable_data(
+ context.Input("InCurrentIter")->place());
+ }
+
+ T scale = T(1);
+ int window_size = context.Attr("window_size");
+ T bin_cnt = (T)((1 << (context.Attr("bit_length") - 1)) - 1);
+ if (quantize_type == std::string("abs_max")) {
+ auto* saving_scale = context.Output("OutMovingScale");
+ scale = (T)FindAbsMaxGpu(device_ctx, in->data(), in->numel());
+ saving_scale->mutable_data(platform::CPUPlace())[0] = scale;
+
+ auto& device_ctx = context.template device_context();
+ auto* scale_list = context.Output("OutScales");
+ math::SetConstant scalar;
+ scale_list->mutable_data(context.GetPlace());
+ scalar(device_ctx, scale_list, static_cast(0));
+ auto* iter = context.Output("OutCurrentIter");
+ iter->mutable_data(context.GetPlace());
+ scalar(device_ctx, iter, static_cast(0));
+ } else if (quantize_type == std::string("range_abs_max")) {
+ auto* moving_scale = const_cast(
+ context.Input("InMovingScale"));
+ if (is_test) {
+ scale = moving_scale->mutable_data(platform::CPUPlace())[0];
+ } else {
+ auto* it = const_cast(
+ context.Input("InCurrentIter"));
+ auto* iter = context.Output("OutCurrentIter");
+ int* last_iter = it->mutable_data