diff --git a/doc/v2/howto/capi/compile_paddle_lib_en.md b/doc/v2/howto/capi/compile_paddle_lib_en.md
index 11d69b9b79c1a41898d3060d3fe25a31330334a3..6212a3081116d988630706e83d2349dd200b73ab 100644
--- a/doc/v2/howto/capi/compile_paddle_lib_en.md
+++ b/doc/v2/howto/capi/compile_paddle_lib_en.md
@@ -1,3 +1,175 @@
## Install and Build
-TBD
+### Download & Install
+
+ Download the latest C-API development package from CI system and install. You can find the required version in the table below:
+
+
+### From source
+
+ Users can also compile the C-API library from PaddlePaddle source code by compiling with the following compilation options:
+
+
+
+
+Options |
+Value |
+
+
+
+
+WITH_C_API |
+ON |
+
+
+WITH_PYTHON |
+OFF(recommended) |
+
+
+WITH_SWIG_PY |
+OFF(recommended) |
+
+
+WITH_GOLANG |
+OFF(recommended) |
+
+
+WITH_GPU |
+ON/OFF |
+
+
+WITH_MKL |
+ON/OFF |
+
+
+It is best to set up with recommended values to avoid linking with unnecessary libraries. Set other compilation options as you need.
+
+Pull the latest following code snippet from github, and configure compilation options(replace PADDLE_ROOT with the installation path of the PaddlePaddle C-API inference library):
+
+```shell
+PADDLE_ROOT=/path/of/capi
+git clone https://github.com/PaddlePaddle/Paddle.git
+cd Paddle
+mkdir build
+cd build
+cmake -DCMAKE_INSTALL_PREFIX=$PADDLE_ROOT \
+ -DCMAKE_BUILD_TYPE=Release \
+ -DWITH_C_API=ON \
+ -DWITH_SWIG_PY=OFF \
+ -DWITH_GOLANG=OFF \
+ -DWITH_PYTHON=OFF \
+ -DWITH_MKL=OFF \
+ -DWITH_GPU=OFF \
+ ..
+```
+
+After running the above code to generate Makefile , run: `make && make install`. After successful compilation, the dependencies required by C-API(includes: (1)PaddlePaddle inference library and header files; (2) Third-party libraries and header files) will be stored in the `PADDLE_ROOT` directory.
+
+If the compilation is successful, see the following directory structure under `PADDLE_ROOT`(includes PaddlePaddle header files and libraries, and third-party libraries and header files(determined by the link methods if necessary)):
+
+```text
+├── include
+│ └── paddle
+│ ├── arguments.h
+│ ├── capi.h
+│ ├── capi_private.h
+│ ├── config.h
+│ ├── error.h
+│ ├── gradient_machine.h
+│ ├── main.h
+│ ├── matrix.h
+│ ├── paddle_capi.map
+│ └── vector.h
+├── lib
+│ ├── libpaddle_capi_engine.a
+│ ├── libpaddle_capi_layers.a
+│ ├── libpaddle_capi_shared.so
+│ └── libpaddle_capi_whole.a
+└── third_party
+ ├── gflags
+ │ ├── include
+ │ │ └── gflags
+ │ │ ├── gflags_completions.h
+ │ │ ├── gflags_declare.h
+ │ │ ...
+ │ └── lib
+ │ └── libgflags.a
+ ├── glog
+ │ ├── include
+ │ │ └── glog
+ │ │ ├── config.h
+ │ │ ...
+ │ └── lib
+ │ └── libglog.a
+ ├── openblas
+ │ ├── include
+ │ │ ├── cblas.h
+ │ │ ...
+ │ └── lib
+ │ ...
+ ├── protobuf
+ │ ├── include
+ │ │ └── google
+ │ │ └── protobuf
+ │ │ ...
+ │ └── lib
+ │ └── libprotobuf-lite.a
+ └── zlib
+ ├── include
+ │ ...
+ └── lib
+ ...
+
+```
+
+### Linking Description:
+
+There are three kinds of linking methods:
+
+1. Linking with dynamic library `libpaddle_capi_shared.so`(This way is much more convenient and easier, **Without special requirements, it is recommended**), refer to the following:
+ 1. Compiling with CPU version and using `OpenBLAS`; only need to link one library named `libpaddle_capi_shared.so` to develop prediction program through C-API.
+ 1. Compiling with CPU version and using `MKL` lib, you need to link MKL library directly to develop prediction program through PaddlePaddle C-API, due to `MKL` has its own dynamic library.
+ 1. Compiling with GPU version, CUDA library will be loaded dynamically on prediction program run-time, and also set CUDA library to `LD_LIBRARY_PATH` environment variable.
+
+2. Linking with static library `libpaddle_capi_whole.a`,refer to the following:
+ 1. Specify `-Wl,--whole-archive` linking options.
+ 1. Explicitly link third-party libraries such as `gflags`、`glog`、`libz`、`protobuf` .etc, you can find them under `PADDLE_ROOT/third_party` directory.
+ 1. Use OpenBLAS library if compiling C-API,must explicitly link `libopenblas.a`.
+ 1. Use MKL when compiling C-API, must explicitly link MKL dynamic library.
+
+3. Linking with static library `libpaddle_capi_layers.a` and `libpaddle_capi_engine.a`,refer to the following:
+ 1. This linking methods is mainly used for mobile prediction.
+ 1. Split `libpaddle_capi_whole.a` into two static linking library at least to reduce the size of linking libraries.
+ 1. Specify `-Wl,--whole-archive -lpaddle_capi_layers` and `-Wl,--no-whole-archive -lpaddle_capi_engine` for linking.
+ 1. The third-party dependencies need explicitly link same as method 2 above.
diff --git a/paddle/fluid/framework/parallel_executor.cc b/paddle/fluid/framework/parallel_executor.cc
index 74945fb4f2f745b6ca9c48adb0c8b9e6ae1e94a4..f393105fe82bfad70246952deada8e296c851ef5 100644
--- a/paddle/fluid/framework/parallel_executor.cc
+++ b/paddle/fluid/framework/parallel_executor.cc
@@ -174,12 +174,17 @@ void ParallelExecutor::SplitTensorToPlaces(
const std::unordered_map &feed_tensors) {
for (auto it : feed_tensors) {
auto lod_tensors = it.second.SplitLoDTensor(member_->places_);
+ PADDLE_ENFORCE_EQ(
+ member_->places_.size(), lod_tensors.size(),
+ "The number of samples of current batch is less than the count of "
+ "devices, currently, it is not allowed. (%d vs %d)",
+ member_->places_.size(), lod_tensors.size());
for (size_t j = 0; j < member_->places_.size(); ++j) {
// TODO(panxy0718): Do I need to delete this var?
- member_->local_scopes_[j]
- ->Var(it.first)
- ->GetMutable()
- ->ShareDataWith(lod_tensors[j]);
+ auto t =
+ member_->local_scopes_[j]->Var(it.first)->GetMutable();
+ t->ShareDataWith(lod_tensors[j]);
+ t->set_lod(lod_tensors[j].lod());
}
}
}
diff --git a/paddle/fluid/operators/batch_norm_op.cc b/paddle/fluid/operators/batch_norm_op.cc
index 36049ee6a4a0d2a251b6d10cf1ff05a9d9845089..c9939e8602ed341d37784ca292a55326899e8e65 100644
--- a/paddle/fluid/operators/batch_norm_op.cc
+++ b/paddle/fluid/operators/batch_norm_op.cc
@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/batch_norm_op.h"
+#include
#include "paddle/fluid/framework/data_layout.h"
namespace paddle {
diff --git a/paddle/fluid/operators/batch_norm_op.cu.cc b/paddle/fluid/operators/batch_norm_op.cu.cc
index 6ceacc39924a7558e380aaf563aaf234f1bf30a5..eecb58e11ef57b550c79c040e6933ed6e52e2e87 100644
--- a/paddle/fluid/operators/batch_norm_op.cu.cc
+++ b/paddle/fluid/operators/batch_norm_op.cu.cc
@@ -13,9 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/batch_norm_op.h"
-#include "paddle/fluid/framework/data_layout.h"
-
#include
+#include "paddle/fluid/framework/data_layout.h"
#include "paddle/fluid/operators/math/math_function.h"
#include "paddle/fluid/platform/cudnn_helper.h"
#include "paddle/fluid/platform/float16.h"
diff --git a/paddle/fluid/operators/batch_size_like.h b/paddle/fluid/operators/batch_size_like.h
index 0bdf27e620a3a7c7b62b955f708a5e2aad1a6986..dd51a11fbe6ad5e528197b67536518c4b31fa355 100644
--- a/paddle/fluid/operators/batch_size_like.h
+++ b/paddle/fluid/operators/batch_size_like.h
@@ -13,7 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
-
+#include
+#include
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/math/math_function.h"
diff --git a/paddle/fluid/operators/box_coder_op.h b/paddle/fluid/operators/box_coder_op.h
index 3c7cac1cd17042994287effc31a918ebd4353c4c..77fc6c2b62af42e6526b889aeef2d9bab795baec 100644
--- a/paddle/fluid/operators/box_coder_op.h
+++ b/paddle/fluid/operators/box_coder_op.h
@@ -10,6 +10,7 @@ See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
+#include
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/math/math_function.h"
diff --git a/paddle/fluid/operators/compare_op.cc b/paddle/fluid/operators/compare_op.cc
index 9a139ab27ec53395a8d1ab1347dbce93ea68fd8e..3a6a357e81949014a70e5bae1ee0e1c8b9d0c2ce 100644
--- a/paddle/fluid/operators/compare_op.cc
+++ b/paddle/fluid/operators/compare_op.cc
@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/compare_op.h"
+#include
#include "paddle/fluid/framework/op_registry.h"
namespace paddle {
diff --git a/paddle/fluid/operators/concat_op.cc b/paddle/fluid/operators/concat_op.cc
index 0eedd8ee51ebfff6f553d8e19e97c3a45a95fa6a..d65a7b34678cda38d5f8beb9154d61928f517ce0 100644
--- a/paddle/fluid/operators/concat_op.cc
+++ b/paddle/fluid/operators/concat_op.cc
@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/concat_op.h"
+#include
#include
namespace paddle {
diff --git a/paddle/fluid/operators/cond_op.h b/paddle/fluid/operators/cond_op.h
index a04fae2182005d4eb08305e943449977bfb637f9..d3888923dbdeee122fb3045a839c0ba639b892b1 100644
--- a/paddle/fluid/operators/cond_op.h
+++ b/paddle/fluid/operators/cond_op.h
@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
+#include
#include
#include "glog/logging.h"
#include "paddle/fluid/framework/ddim.h"
diff --git a/paddle/fluid/operators/conv_transpose_op.cc b/paddle/fluid/operators/conv_transpose_op.cc
index b2a3cfc89f18eff24c941c664b1184b4485ab895..08f5939d42a41d235a94eff16cf2f558068d6aaa 100644
--- a/paddle/fluid/operators/conv_transpose_op.cc
+++ b/paddle/fluid/operators/conv_transpose_op.cc
@@ -13,6 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/conv_transpose_op.h"
+#include
+#include
namespace paddle {
namespace operators {
diff --git a/paddle/fluid/operators/conv_transpose_op.h b/paddle/fluid/operators/conv_transpose_op.h
index d4e4b641ece9ed120904ded6f8baed65a2666213..bfc0177c2a0da1627fbca532764fdae8167b6b2a 100644
--- a/paddle/fluid/operators/conv_transpose_op.h
+++ b/paddle/fluid/operators/conv_transpose_op.h
@@ -13,7 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
-
+#include
#include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/math/im2col.h"
diff --git a/paddle/fluid/operators/crf_decoding_op.h b/paddle/fluid/operators/crf_decoding_op.h
index 2b2a733fb9f162755e5c548fec617937d86689dd..3f5fab3b382bea97f43e4bc1b2cd436c956ba264 100644
--- a/paddle/fluid/operators/crf_decoding_op.h
+++ b/paddle/fluid/operators/crf_decoding_op.h
@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
+#include
#include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/math/math_function.h"
diff --git a/paddle/fluid/operators/crop_op.h b/paddle/fluid/operators/crop_op.h
index c5ac6849789587f2f41588f79bd538f7b79a7478..f05c2e23284e3a24cf48442996f671ec6084c391 100644
--- a/paddle/fluid/operators/crop_op.h
+++ b/paddle/fluid/operators/crop_op.h
@@ -13,7 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
-
+#include
+#include
#include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/strided_memcpy.h"
diff --git a/paddle/fluid/operators/math/math_function.cu b/paddle/fluid/operators/math/math_function.cu
index 82e12943148a806bae719c722944d6a9d5236b7c..e53183603fec54ceef68873cfd97b4b985b0d437 100644
--- a/paddle/fluid/operators/math/math_function.cu
+++ b/paddle/fluid/operators/math/math_function.cu
@@ -39,13 +39,14 @@ void gemm(
cublasOperation_t cuTransB =
(transB == CblasNoTrans) ? CUBLAS_OP_N : CUBLAS_OP_T;
- float h_alpha = static_cast(alpha);
- float h_beta = static_cast(beta);
-
// TODO(kexinzhao): add processing code for compute capability < 53 case
PADDLE_ENFORCE_GE(context.GetComputeCapability(), 53,
"cublas fp16 gemm requires GPU compute capability >= 53");
+#if CUDA_VERSION >= 8000
+ float h_alpha = static_cast(alpha);
+ float h_beta = static_cast(beta);
+
cublasGemmAlgo_t algo = CUBLAS_GEMM_DFALT;
#if CUDA_VERSION >= 9000
if (context.GetComputeCapability() >= 70) {
@@ -56,7 +57,7 @@ void gemm(
PADDLE_ENFORCE(platform::dynload::cublasSetMathMode(context.cublas_handle(),
CUBLAS_DEFAULT_MATH));
}
-#endif
+#endif // CUDA_VERSION >= 9000
// cublasHgemm does true FP16 computation which is slow for non-Volta
// GPUs. So use cublasGemmEx instead which does pesudo FP16 computation:
@@ -66,6 +67,18 @@ void gemm(
context.cublas_handle(), cuTransB, cuTransA, N, M, K, &h_alpha, B,
CUDA_R_16F, ldb, A, CUDA_R_16F, lda, &h_beta, C, CUDA_R_16F, N,
CUDA_R_32F, algo));
+#else
+ // CUDA 7.5 does not support cublasGemmEx, hence we fall back to use hgemm
+ const half h_alpha = static_cast(alpha);
+ const half h_beta = static_cast(beta);
+ const half* h_A = reinterpret_cast(A);
+ const half* h_B = reinterpret_cast(B);
+ half* h_C = reinterpret_cast(C);
+
+ PADDLE_ENFORCE(platform::dynload::cublasHgemm(
+ context.cublas_handle(), cuTransB, cuTransA, N, M, K, &h_alpha, h_B, ldb,
+ h_A, lda, &h_beta, h_C, N));
+#endif // CUDA_VERSION >= 8000
}
template <>
diff --git a/paddle/fluid/platform/dynload/cublas.cc b/paddle/fluid/platform/dynload/cublas.cc
index eb541579a136de2a84ecc9773e0c312b405f7e86..361d3439b844e9f68d3fba0a0e41ec457118a4a9 100644
--- a/paddle/fluid/platform/dynload/cublas.cc
+++ b/paddle/fluid/platform/dynload/cublas.cc
@@ -28,6 +28,10 @@ CUBLAS_BLAS_ROUTINE_EACH(DEFINE_WRAP);
CUBLAS_BLAS_ROUTINE_EACH_R2(DEFINE_WRAP);
#endif
+#ifdef CUBLAS_BLAS_ROUTINE_EACH_R3
+CUBLAS_BLAS_ROUTINE_EACH_R3(DEFINE_WRAP);
+#endif
+
} // namespace dynload
} // namespace platform
} // namespace paddle
diff --git a/paddle/fluid/platform/dynload/cublas.h b/paddle/fluid/platform/dynload/cublas.h
index a41018d350e89881888d5e31089c2b9ecd76f6c0..1ab55d6b9bf8fdbd14c9c2bd978e3e99dba3e73e 100644
--- a/paddle/fluid/platform/dynload/cublas.h
+++ b/paddle/fluid/platform/dynload/cublas.h
@@ -71,7 +71,6 @@ extern void *cublas_dso_handle;
__macro(cublasDgemm_v2); \
__macro(cublasHgemm); \
__macro(cublasSgemmEx); \
- __macro(cublasGemmEx); \
__macro(cublasSgeam_v2); \
__macro(cublasDgeam_v2); \
__macro(cublasCreate_v2); \
@@ -83,11 +82,6 @@ extern void *cublas_dso_handle;
__macro(cublasDgemmBatched); \
__macro(cublasCgemmBatched); \
__macro(cublasZgemmBatched); \
- __macro(cublasSgemmStridedBatched); \
- __macro(cublasDgemmStridedBatched); \
- __macro(cublasCgemmStridedBatched); \
- __macro(cublasZgemmStridedBatched); \
- __macro(cublasHgemmStridedBatched); \
__macro(cublasSgetrfBatched); \
__macro(cublasSgetriBatched); \
__macro(cublasDgetrfBatched); \
@@ -95,10 +89,24 @@ extern void *cublas_dso_handle;
CUBLAS_BLAS_ROUTINE_EACH(DECLARE_DYNAMIC_LOAD_CUBLAS_WRAP)
+// APIs available after CUDA 8.0
+#if CUDA_VERSION >= 8000
+#define CUBLAS_BLAS_ROUTINE_EACH_R2(__macro) \
+ __macro(cublasGemmEx); \
+ __macro(cublasSgemmStridedBatched); \
+ __macro(cublasDgemmStridedBatched); \
+ __macro(cublasCgemmStridedBatched); \
+ __macro(cublasZgemmStridedBatched); \
+ __macro(cublasHgemmStridedBatched);
+
+CUBLAS_BLAS_ROUTINE_EACH_R2(DECLARE_DYNAMIC_LOAD_CUBLAS_WRAP)
+#endif
+
// APIs available after CUDA 9.0
#if CUDA_VERSION >= 9000
-#define CUBLAS_BLAS_ROUTINE_EACH_R2(__macro) __macro(cublasSetMathMode);
-CUBLAS_BLAS_ROUTINE_EACH_R2(DECLARE_DYNAMIC_LOAD_CUBLAS_WRAP)
+#define CUBLAS_BLAS_ROUTINE_EACH_R3(__macro) __macro(cublasSetMathMode);
+
+CUBLAS_BLAS_ROUTINE_EACH_R3(DECLARE_DYNAMIC_LOAD_CUBLAS_WRAP)
#endif
#undef DECLARE_DYNAMIC_LOAD_CUBLAS_WRAP
diff --git a/python/paddle/fluid/parallel_executor.py b/python/paddle/fluid/parallel_executor.py
index c709f364c12a8260f6e161f475d799d709b5eac3..5ce2aa1fc4d0b275b502af0f97e4a0f83e85de5b 100644
--- a/python/paddle/fluid/parallel_executor.py
+++ b/python/paddle/fluid/parallel_executor.py
@@ -87,7 +87,8 @@ class ParallelExecutor(object):
# performance. Worth tunning for other models in the future.
num_threads = len(self._places)
else:
- min(len(self._places) * 2, multiprocessing.cpu_count())
+ num_threads = min(
+ len(self._places) * 2, multiprocessing.cpu_count())
main = main_program
main = main if main else framework.default_main_program()