diff --git a/CMakeLists.txt b/CMakeLists.txt
index 23bbe829ac16180088bfa37df66e23f19b021ea3..030bd19b3fd2f561a847bbc4613e5d2030812a92 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -25,7 +25,6 @@ message(STATUS "CXX compiler: ${CMAKE_CXX_COMPILER}, version: "
message(STATUS "C compiler: ${CMAKE_C_COMPILER}, version: "
"${CMAKE_C_COMPILER_ID} ${CMAKE_C_COMPILER_VERSION}")
-find_package(Sphinx)
if(NOT CMAKE_CROSSCOMPILING)
find_package(CUDA QUIET)
endif(NOT CMAKE_CROSSCOMPILING)
@@ -226,5 +225,7 @@ if(WITH_PYTHON)
endif()
if(WITH_DOC)
+ find_package(Sphinx REQUIRED)
+ find_python_module(recommonmark REQUIRED)
add_subdirectory(doc)
endif()
diff --git a/doc/fluid/design/dist_train/async_update.md b/doc/fluid/design/dist_train/async_update.md
index 6a0835b761b69030ba30697e6e8863928efbf57f..248d2ec18dafdecac9184527638754b6ba4d85b8 100644
--- a/doc/fluid/design/dist_train/async_update.md
+++ b/doc/fluid/design/dist_train/async_update.md
@@ -4,34 +4,37 @@
For the typical synchronous distributed training, some significant steps are as follows:
-1. A Trainer will compute the gradients and SEND them to the Parameter Server(PServer) nodes.
-1. After the PServer node received gradients came from all the Trainers, It will aggregate the
+1. A trainer process will compute the gradients and **send** them to the parameter server (PS) nodes.
+1. After the PS node received gradients came from all the Trainers, It will aggregate the
gradient variables for the same parameter into one gradient variable and then apply the aggregated
gradient to the respective parameter, finally using an optimize algorithms(SGD, Monument...)
to update the parameters.
-1. The Trainer would wait for the PServers finished the optimize stage, and GET the parameters from PServer,
+1. The Trainer would wait for the PS finished the optimize stage, and GET the parameters from PS,
so all the Trainers would get the same parameters.
-In the synchronously distributed training, there should be a `Barrier` to synchronise the
-parameters after the optimizing stage. The performance of a distributed training job would
-depend on the slowest node if there were hundreds or thousands of training nodes in a
-Job, the performance of synchronously distributed training might be very poor because of
-the slow node. So this design doc would introduce an approach to implement
-*asynchronously* distributed training in PaddlePaddle Fluid.
+In Synchronous Distributed Training, there is a **barrier** on each PS to wait until all trainers processes
+have completed running current mini-batch. After that, all trainers can continue to run the next
+mini-batch. So, we can find that the overall performance of Synchronous Distributed Training depends
+on the slowest node.
+
+In Asynchronous Distributed Training, we don't need to wait for a global mini-bach, the optimizer on
+the PS will run immediately when the gradient is uploaded to the PS from one trainer. This mode would
+train such models that achieve scaling, better throughput. In this design doc, we will introduce how to
+implement the Asynchronous Distributed Training base on PaddlePaddle Fluid.
## Design
-As the figure above, we describe a global view of asynchronously update process and use
+As the figure above, we describe a global view of the asynchronous update process and use
the parameter `w1` as an example to introduce the steps:
1. For each gradient variables, they may distribute on different GPU card and aggregate
them while they are all calculated.
-1. Split the gradient variable into multiple blocks according to the number of PServer
+1. Split the gradient variable into multiple blocks according to the number of PS
instances and then send them.
-1. PServer would run an `Optimize Block` using a specified optimize algorithm to update
+1. PS would run an `Optimize Block` using a specified optimize algorithm to update
the specified parameter.
-1. The trainer will fetch latest parameter from PServer before running forward Op which depends
+1. The trainer will fetch the latest parameter from PS before running forward Op which depends
on the specified parameter.
1. Broadcast the received variable into multiple GPU cards and continue to run the next
mini-batch.
@@ -40,8 +43,8 @@ mini-batch.
- For the multiple devices distributed training, we need to aggregate the gradient
variables which placed on different devices firstly and then schedule a `SendVars` Operator to
-send the gradient variables to the multiple PServer instances.
-- Schedule `FetchVars` operator to fetch the latest parameter from PServer before running
+send the gradient variables to the multiple PS instances.
+- Schedule `FetchVars` operator to fetch the latest parameter from PS before running
the forward ops.
- There could be a large number of gradient variables to be sent, so we need to use another
thread pool(IO Threadpool) whose a number of the schedulable threads is larger than the
diff --git a/doc/v2/build_and_install/build_from_source_cn.rst b/doc/v2/build_and_install/build_from_source_cn.rst
index 115b92a33888abf1e1be400e1abbb58b632a2976..f846928954dd3a05e11054ce2ff2ff839fbefd4b 100644
--- a/doc/v2/build_and_install/build_from_source_cn.rst
+++ b/doc/v2/build_and_install/build_from_source_cn.rst
@@ -19,8 +19,9 @@
----------------
PaddlePaddle需要使用Docker环境完成编译,这样可以免去单独安装编译依赖的步骤,可选的不同编译环境Docker镜像
-可以在 `这里 `_ 找到。或者
-参考下述可选步骤,从源码中构建用于编译PaddlePaddle的Docker镜像。
+可以在 `这里 `_ 找到,您也可以
+在 `这里 `_ 找到 paddle_manylinux_devel
+镜像的编译以及使用方法。或者参考下述可选步骤,从源码中构建用于编译PaddlePaddle的Docker镜像。
如果您选择不使用Docker镜像,则需要在本机安装下面章节列出的 `编译依赖`_ 之后才能开始编译的步骤。
diff --git a/doc/v2/build_and_install/build_from_source_en.rst b/doc/v2/build_and_install/build_from_source_en.rst
index 8fef9e7347e8d924026999bfda985381750c6b51..d1b5b88dff81d4c5cee3dd13a7dccbc333ab6a17 100644
--- a/doc/v2/build_and_install/build_from_source_en.rst
+++ b/doc/v2/build_and_install/build_from_source_en.rst
@@ -22,6 +22,8 @@ How To Build
You need to use Docker to build PaddlePaddle
to avoid installing dependencies by yourself. We have several pre-built
Docker images `here `_ ,
+you can also find how to build and use paddle_manylinux_devel Docker image from
+`here `_
Or you can build your own image from source as the optional step below:
.. code-block:: bash
diff --git a/paddle/fluid/framework/CMakeLists.txt b/paddle/fluid/framework/CMakeLists.txt
index ab71e0e63ce18e4f221a046eeb2c39499c1c3816..ed1e70c6460b513c1d2e1add18ac037f71d36944 100644
--- a/paddle/fluid/framework/CMakeLists.txt
+++ b/paddle/fluid/framework/CMakeLists.txt
@@ -5,11 +5,11 @@ proto_library(framework_proto SRCS framework.proto)
cc_library(ddim SRCS ddim.cc DEPS eigen3 boost)
cc_test(ddim_test SRCS ddim_test.cc DEPS ddim)
nv_test(dim_test SRCS dim_test.cu DEPS ddim)
-
+cc_library(data_type SRCS data_type.cc DEPS framework_proto ddim device_context)
if(WITH_GPU)
- nv_library(tensor SRCS tensor.cc tensor_util.cu DEPS ddim place memory device_context framework_proto)
+ nv_library(tensor SRCS tensor.cc tensor_util.cu DEPS place memory data_type)
else()
- cc_library(tensor SRCS tensor.cc tensor_util.cc DEPS ddim place memory device_context framework_proto)
+ cc_library(tensor SRCS tensor.cc tensor_util.cc DEPS place memory data_type)
endif()
cc_test(tensor_test SRCS tensor_test.cc DEPS tensor)
diff --git a/paddle/fluid/framework/data_type.cc b/paddle/fluid/framework/data_type.cc
new file mode 100644
index 0000000000000000000000000000000000000000..b9c90cb0c32f337ba82ce1eaa5b43199540491ef
--- /dev/null
+++ b/paddle/fluid/framework/data_type.cc
@@ -0,0 +1,101 @@
+// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "paddle/fluid/framework/data_type.h"
+#include
+#include
+#include
+
+namespace paddle {
+namespace framework {
+
+struct DataTypeMap {
+ std::unordered_map cpp_to_proto_;
+ std::unordered_map proto_to_cpp_;
+ std::unordered_map proto_to_str_;
+ std::unordered_map cpp_to_size_;
+};
+
+static DataTypeMap* InitDataTypeMap();
+static DataTypeMap& gDataTypeMap() {
+ static DataTypeMap* g_data_type_map_ = InitDataTypeMap();
+ return *g_data_type_map_;
+}
+
+template
+static inline void RegisterType(DataTypeMap* map,
+ proto::VarType::Type proto_type,
+ const std::string& name) {
+ map->proto_to_cpp_.emplace(static_cast(proto_type), typeid(T));
+ map->cpp_to_proto_.emplace(typeid(T), proto_type);
+ map->proto_to_str_.emplace(static_cast(proto_type), name);
+ map->cpp_to_size_.emplace(typeid(T), sizeof(T));
+}
+
+static DataTypeMap* InitDataTypeMap() {
+ auto retv = new DataTypeMap();
+
+#define RegType(cc_type, proto_type) \
+ RegisterType(retv, proto_type, #cc_type)
+
+ // NOTE: Add your customize type here.
+ RegType(platform::float16, proto::VarType::FP16);
+ RegType(float, proto::VarType::FP32);
+ RegType(double, proto::VarType::FP64);
+ RegType(int, proto::VarType::INT32);
+ RegType(int64_t, proto::VarType::INT64);
+ RegType(bool, proto::VarType::BOOL);
+ RegType(size_t, proto::VarType::SIZE_T);
+ RegType(int16_t, proto::VarType::INT16);
+
+#undef RegType
+ return retv;
+}
+
+proto::VarType::Type ToDataType(std::type_index type) {
+ auto it = gDataTypeMap().cpp_to_proto_.find(type);
+ if (it != gDataTypeMap().cpp_to_proto_.end()) {
+ return it->second;
+ }
+ PADDLE_THROW("Not support %s as tensor type", type.name());
+}
+
+std::type_index ToTypeIndex(proto::VarType::Type type) {
+ auto it = gDataTypeMap().proto_to_cpp_.find(static_cast(type));
+ if (it != gDataTypeMap().proto_to_cpp_.end()) {
+ return it->second;
+ }
+ PADDLE_THROW("Not support proto::VarType::Type(%d) as tensor type",
+ static_cast(type));
+}
+
+std::string DataTypeToString(const proto::VarType::Type type) {
+ auto it = gDataTypeMap().proto_to_str_.find(static_cast(type));
+ if (it != gDataTypeMap().proto_to_str_.end()) {
+ return it->second;
+ }
+ PADDLE_THROW("Not support proto::VarType::Type(%d) as tensor type",
+ static_cast(type));
+}
+
+size_t SizeOfType(std::type_index type) {
+ auto it = gDataTypeMap().cpp_to_size_.find(type);
+ if (it != gDataTypeMap().cpp_to_size_.end()) {
+ return it->second;
+ }
+ PADDLE_THROW("Not support %s as tensor type", type.name());
+}
+
+} // namespace framework
+} // namespace paddle
diff --git a/paddle/fluid/framework/data_type.h b/paddle/fluid/framework/data_type.h
index 2a528eb3aa562568c92059250f2c9bc5a75ec103..4b9f572ec5f1cda71c8b8dd8fae54b42e9f16f7a 100644
--- a/paddle/fluid/framework/data_type.h
+++ b/paddle/fluid/framework/data_type.h
@@ -17,51 +17,14 @@ limitations under the License. */
#include
#include "paddle/fluid/framework/framework.pb.h"
#include "paddle/fluid/platform/enforce.h"
+
#include "paddle/fluid/platform/float16.h"
namespace paddle {
namespace framework {
-inline proto::VarType::Type ToDataType(std::type_index type) {
- if (typeid(platform::float16).hash_code() == type.hash_code()) {
- return proto::VarType::FP16;
- } else if (typeid(const float).hash_code() == type.hash_code()) {
- // CPPLint complains Using C-style cast. Use static_cast() instead
- // One fix to this is to replace float with const float because
- // typeid(T) == typeid(const T)
- // http://en.cppreference.com/w/cpp/language/typeid
- return proto::VarType::FP32;
- } else if (typeid(const double).hash_code() == type.hash_code()) {
- return proto::VarType::FP64;
- } else if (typeid(const int).hash_code() == type.hash_code()) {
- return proto::VarType::INT32;
- } else if (typeid(const int64_t).hash_code() == type.hash_code()) {
- return proto::VarType::INT64;
- } else if (typeid(const bool).hash_code() == type.hash_code()) {
- return proto::VarType::BOOL;
- } else {
- PADDLE_THROW("Not supported");
- }
-}
-
-inline std::type_index ToTypeIndex(proto::VarType::Type type) {
- switch (type) {
- case proto::VarType::FP16:
- return typeid(platform::float16);
- case proto::VarType::FP32:
- return typeid(float);
- case proto::VarType::FP64:
- return typeid(double);
- case proto::VarType::INT32:
- return typeid(int);
- case proto::VarType::INT64:
- return typeid(int64_t);
- case proto::VarType::BOOL:
- return typeid(bool);
- default:
- PADDLE_THROW("Not support type %d", type);
- }
-}
+extern proto::VarType::Type ToDataType(std::type_index type);
+extern std::type_index ToTypeIndex(proto::VarType::Type type);
template
inline void VisitDataType(proto::VarType::Type type, Visitor visitor) {
@@ -89,32 +52,12 @@ inline void VisitDataType(proto::VarType::Type type, Visitor visitor) {
}
}
-inline std::string DataTypeToString(const proto::VarType::Type type) {
- switch (type) {
- case proto::VarType::FP16:
- return "float16";
- case proto::VarType::FP32:
- return "float32";
- case proto::VarType::FP64:
- return "float64";
- case proto::VarType::INT16:
- return "int16";
- case proto::VarType::INT32:
- return "int32";
- case proto::VarType::INT64:
- return "int64";
- case proto::VarType::BOOL:
- return "bool";
- default:
- PADDLE_THROW("Not support type %d", type);
- }
-}
-
+extern std::string DataTypeToString(const proto::VarType::Type type);
+extern size_t SizeOfType(std::type_index type);
inline std::ostream& operator<<(std::ostream& out,
const proto::VarType::Type& type) {
out << DataTypeToString(type);
return out;
}
-
} // namespace framework
} // namespace paddle
diff --git a/paddle/fluid/framework/framework.proto b/paddle/fluid/framework/framework.proto
index 96f53dc1bc8747e1b8ea84166614f98ff363ae5e..d2558f111f49139b33f921f7260b41830279edc8 100644
--- a/paddle/fluid/framework/framework.proto
+++ b/paddle/fluid/framework/framework.proto
@@ -101,6 +101,8 @@ message VarType {
FP16 = 4;
FP32 = 5;
FP64 = 6;
+ // Tensor is used in C++.
+ SIZE_T = 19;
// Other types that may need additional descriptions
LOD_TENSOR = 7;
diff --git a/paddle/fluid/framework/op_kernel_type_test.cc b/paddle/fluid/framework/op_kernel_type_test.cc
index d37ce149ce3df63692b41289bb03448d54e392f5..db95861c510b52a5b52229541434e6437d3fb9f4 100644
--- a/paddle/fluid/framework/op_kernel_type_test.cc
+++ b/paddle/fluid/framework/op_kernel_type_test.cc
@@ -27,7 +27,7 @@ TEST(OpKernelType, ToString) {
LibraryType::kCUDNN);
ASSERT_EQ(paddle::framework::KernelTypeToString(op_kernel_type),
- "data_type[float32]:data_layout[NCHW]:place[CPUPlace]:library_type["
+ "data_type[float]:data_layout[NCHW]:place[CPUPlace]:library_type["
"CUDNN]");
}
diff --git a/paddle/fluid/framework/operator.h b/paddle/fluid/framework/operator.h
index d373c48b1a75c5f75c7520b56f230bc2c146b174..a4eb6f706edab9479cbce436311eb96da8845646 100644
--- a/paddle/fluid/framework/operator.h
+++ b/paddle/fluid/framework/operator.h
@@ -192,6 +192,10 @@ class ExecutionContext {
return op_.Attr(name);
}
+ bool HasInput(const std::string& name) const { return op_.HasInputs(name); }
+
+ bool HasOutput(const std::string& name) const { return op_.HasOutputs(name); }
+
size_t InputSize(const std::string& name) const {
return op_.Inputs(name).size();
}
diff --git a/paddle/fluid/framework/parallel_executor.cc b/paddle/fluid/framework/parallel_executor.cc
index 20ef7e09f630140c44774147aa727780df6333fa..95e807c0afa45bc4f4feb84d450b2d0584bc3b28 100644
--- a/paddle/fluid/framework/parallel_executor.cc
+++ b/paddle/fluid/framework/parallel_executor.cc
@@ -58,7 +58,8 @@ ParallelExecutor::ParallelExecutor(
const std::unordered_set &bcast_vars,
const ProgramDesc &main_program, const std::string &loss_var_name,
Scope *scope, const std::vector &local_scopes, bool allow_op_delay,
- bool use_default_grad_scale, bool balance_parameter_opt_between_cards)
+ bool use_default_grad_scale, bool balance_parameter_opt_between_cards,
+ size_t num_trainers, size_t trainer_id)
: member_(new ParallelExecutorPrivate(places)) {
member_->global_scope_ = scope;
@@ -80,7 +81,13 @@ ParallelExecutor::ParallelExecutor(
// Bcast Parameters to all GPUs
#ifdef PADDLE_WITH_CUDA
- member_->nccl_ctxs_.reset(new platform::NCCLContextMap(member_->places_));
+ auto *nccl_id_var = scope->FindVar(NCCL_ID_VARNAME);
+ ncclUniqueId *nccl_id = nullptr;
+ if (nccl_id_var != nullptr) {
+ nccl_id = nccl_id_var->GetMutable();
+ }
+ member_->nccl_ctxs_.reset(new platform::NCCLContextMap(
+ member_->places_, nccl_id, num_trainers, trainer_id));
#endif
if (platform::is_gpu_place(places[0]) && member_->local_scopes_.size() != 1 &&
local_scopes.empty()) { // Is CUDA
diff --git a/paddle/fluid/framework/parallel_executor.h b/paddle/fluid/framework/parallel_executor.h
index b251fc91417a1c00e61e9c3c952460e6268d2819..9e279876cfeef20a1921f8bd1c27046a477b9f56 100644
--- a/paddle/fluid/framework/parallel_executor.h
+++ b/paddle/fluid/framework/parallel_executor.h
@@ -41,7 +41,8 @@ class ParallelExecutor {
const std::string& loss_var_name, Scope* scope,
const std::vector& local_scopes,
bool allow_op_delay, bool use_default_grad_scale,
- bool balance_parameter_opt_between_cards);
+ bool balance_parameter_opt_between_cards,
+ size_t num_trainers = 1, size_t trainer_id = 0);
~ParallelExecutor();
diff --git a/paddle/fluid/framework/tensor_impl.h b/paddle/fluid/framework/tensor_impl.h
index f49d1a47a325b2aac6185073203df124be18b54d..0a1db7758bd9ec0dac133efcbf495de1d690021d 100644
--- a/paddle/fluid/framework/tensor_impl.h
+++ b/paddle/fluid/framework/tensor_impl.h
@@ -13,54 +13,14 @@ See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
+#include "paddle/fluid/framework/data_type.h"
#include "paddle/fluid/memory/memcpy.h"
#include "paddle/fluid/platform/enforce.h"
#include "paddle/fluid/platform/float16.h"
namespace paddle {
namespace framework {
-
-template
-struct SizeOfTypeFunctor;
-
-template
-struct SizeOfTypeFunctor {
- size_t operator()(std::type_index type) const {
- if (typeid(T).hash_code() == type.hash_code()) {
- return sizeof(T);
- } else {
- return 0UL;
- }
- }
-};
-
-template <>
-struct SizeOfTypeFunctor<> {
- size_t operator()(std::type_index type) const { return 0UL; }
-};
-
-template
-struct SizeOfTypeFunctor {
- size_t operator()(std::type_index type) const {
- SizeOfTypeFunctor head;
- size_t head_size = head(type);
- if (head_size != 0) {
- return head_size;
- }
- SizeOfTypeFunctor tail;
- return tail(type);
- }
-};
-
-static inline size_t SizeOfType(std::type_index type) {
- SizeOfTypeFunctor
- functor;
- size_t size = functor(type);
- PADDLE_ENFORCE(size != 0UL, "Cannot get size of type %s", type.name());
- return size;
-}
-
+extern size_t SizeOfType(std::type_index type);
inline void Tensor::check_memory_size() const {
PADDLE_ENFORCE_NOT_NULL(
holder_, "Tensor holds no memory. Call Tensor::mutable_data first.");
diff --git a/paddle/fluid/inference/analysis/dot.h b/paddle/fluid/inference/analysis/dot.h
index 3359987874f2d74d7e4646baa38790431c4b28fd..4bf1840fdda8508b52d7274a338c5b1c95baf354 100644
--- a/paddle/fluid/inference/analysis/dot.h
+++ b/paddle/fluid/inference/analysis/dot.h
@@ -21,6 +21,7 @@
#include
#include
+#include
#include
#include
diff --git a/paddle/fluid/inference/engine.h b/paddle/fluid/inference/engine.h
index de0375551e16ec53b90414c7446234fda98bf706..ce2b8161715a3fa2278ce950dbac82c6d0042bef 100644
--- a/paddle/fluid/inference/engine.h
+++ b/paddle/fluid/inference/engine.h
@@ -14,6 +14,7 @@ limitations under the License. */
#pragma once
+#include
#include "paddle/fluid/framework/framework.pb.h"
namespace paddle {
@@ -58,8 +59,8 @@ class EngineBase {
struct Buffer {
void* buffer{nullptr}; // buffer should be allocated only once.
- int max_size; // buffer allocated space.
- int size; // data size.
+ size_t max_size; // buffer allocated space.
+ size_t size; // data size.
DeviceType device{DeviceType::UNK}; // tells which device this buffer is on.
};
diff --git a/paddle/fluid/inference/tensorrt/CMakeLists.txt b/paddle/fluid/inference/tensorrt/CMakeLists.txt
index 677b3e04af8e7f5662a15fb32e3b03f45d262733..b52d083f280e5e7713600a7b748dedd37aca0a1e 100644
--- a/paddle/fluid/inference/tensorrt/CMakeLists.txt
+++ b/paddle/fluid/inference/tensorrt/CMakeLists.txt
@@ -1,5 +1,4 @@
nv_library(tensorrt_engine SRCS engine.cc DEPS framework_proto)
nv_test(test_tensorrt SRCS test_tensorrt.cc DEPS dynload_cuda device_context dynamic_loader)
nv_test(test_tensorrt_engine SRCS test_engine.cc DEPS dynload_cuda tensorrt_engine)
-
add_subdirectory(convert)
diff --git a/paddle/fluid/inference/tensorrt/convert/CMakeLists.txt b/paddle/fluid/inference/tensorrt/convert/CMakeLists.txt
index 286abf736e8ff8a357482419e85ad1258a6c6acd..4fb4511d99179e4ea14cde66feb13bc9e114581a 100644
--- a/paddle/fluid/inference/tensorrt/convert/CMakeLists.txt
+++ b/paddle/fluid/inference/tensorrt/convert/CMakeLists.txt
@@ -1,4 +1,4 @@
-nv_test(test_op_converter SRCS test_op_converter.cc mul_op.cc conv2d_op.cc op_converter.h DEPS ${FLUID_CORE_MODULES})
-nv_test(test_trt_activation_op SRCS test_activation_op.cc activation_op.cc
+nv_test(test_op_converter SRCS test_op_converter.cc mul_op.cc conv2d_op.cc DEPS ${FLUID_CORE_MODULES})
+nv_test(test_trt_activation_op SRCS test_activation_op.cc activation_op.cc io_converter.cc
DEPS ${FLUID_CORE_MODULES} activation_op tensorrt_engine)
nv_test(test_io_converter SRCS test_io_converter.cc io_converter.cc DEPS dynload_cuda dynamic_loader lod_tensor)
diff --git a/paddle/fluid/inference/tensorrt/convert/io_converter.cc b/paddle/fluid/inference/tensorrt/convert/io_converter.cc
index 32e8631fde3f748669d2008b4a060455a37e154e..854f434d93e81237dc85c5df62debcf3b3824b78 100644
--- a/paddle/fluid/inference/tensorrt/convert/io_converter.cc
+++ b/paddle/fluid/inference/tensorrt/convert/io_converter.cc
@@ -23,26 +23,42 @@ namespace tensorrt {
using platform::is_gpu_place;
using platform::is_cpu_place;
-class DefaultInputConverter : public EngineInputConverter {
+class DefaultIOConverter : public EngineIOConverter {
public:
- DefaultInputConverter() {}
+ DefaultIOConverter() {}
// NOTE out is GPU memory.
virtual void operator()(const LoDTensor& in, void* out,
size_t max_size) override {
PADDLE_ENFORCE(out != nullptr);
- PADDLE_ENFORCE_LE(in.memory_size(), max_size);
+ PADDLE_ENFORCE(stream_ != nullptr);
const auto& place = in.place();
+ size_t size = in.memory_size();
+ PADDLE_ENFORCE_LE(size, max_size);
if (is_cpu_place(place)) {
- PADDLE_ENFORCE(stream_ != nullptr);
- PADDLE_ENFORCE_EQ(0,
- cudaMemcpyAsync(out, in.data(), in.memory_size(),
- cudaMemcpyHostToDevice, *stream_));
-
+ PADDLE_ENFORCE_EQ(0, cudaMemcpyAsync(out, in.data(), size,
+ cudaMemcpyHostToDevice, *stream_));
} else if (is_gpu_place(place)) {
- PADDLE_ENFORCE_EQ(0,
- cudaMemcpyAsync(out, in.data(), in.memory_size(),
- cudaMemcpyHostToHost, *stream_));
-
+ PADDLE_ENFORCE_EQ(0, cudaMemcpyAsync(out, in.data(), size,
+ cudaMemcpyDeviceToDevice, *stream_));
+ } else {
+ PADDLE_THROW("Unknown device for converter");
+ }
+ cudaStreamSynchronize(*stream_);
+ }
+ // NOTE in is GPU memory.
+ virtual void operator()(const void* in, LoDTensor* out,
+ size_t max_size) override {
+ PADDLE_ENFORCE(in != nullptr);
+ PADDLE_ENFORCE(stream_ != nullptr);
+ const auto& place = out->place();
+ size_t size = out->memory_size();
+ PADDLE_ENFORCE_LE(size, max_size);
+ if (is_cpu_place(place)) {
+ PADDLE_ENFORCE_EQ(0, cudaMemcpyAsync(out->data(), in, size,
+ cudaMemcpyDeviceToHost, *stream_));
+ } else if (is_gpu_place(place)) {
+ PADDLE_ENFORCE_EQ(0, cudaMemcpyAsync(out->data(), in, size,
+ cudaMemcpyDeviceToDevice, *stream_));
} else {
PADDLE_THROW("Unknown device for converter");
}
@@ -50,7 +66,8 @@ class DefaultInputConverter : public EngineInputConverter {
}
};
-REGISTER_TENSORRT_INPUT_CONVERTER(default, DefaultInputConverter);
+// fluid LodTensor <-> tensorrt ITensor
+REGISTER_TENSORRT_IO_CONVERTER(default, DefaultIOConverter);
} // namespace tensorrt
} // namespace inference
diff --git a/paddle/fluid/inference/tensorrt/convert/io_converter.h b/paddle/fluid/inference/tensorrt/convert/io_converter.h
index 8972dae92be2c2d261a13c48d98e675f64e51d31..71c48e085d25d2bc6720d93735f661f9e3af7b40 100644
--- a/paddle/fluid/inference/tensorrt/convert/io_converter.h
+++ b/paddle/fluid/inference/tensorrt/convert/io_converter.h
@@ -14,6 +14,7 @@ limitations under the License. */
#pragma once
+#include
#include
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/inference/utils/singleton.h"
@@ -25,43 +26,57 @@ namespace tensorrt {
using framework::LoDTensor;
/*
- * Convert Input from Fluid to an Engine.
- * TensorRT's ITensor follows row major, NCHW. Fluid is also row major, so in
- * most cases just need to copy the data.
+ * Convert Input from Fluid to TensorRT Engine.
+ * Convert Output from TensorRT Engine to Fluid.
+ *
+ * Note that TensorRT's ITensor follows row major, NCHW. Fluid is also row
+ * major,
+ * so in the default case just need to copy the data.
*/
-class EngineInputConverter {
+class EngineIOConverter {
public:
- EngineInputConverter() {}
+ EngineIOConverter() {}
virtual void operator()(const LoDTensor& in, void* out, size_t max_size) {}
+ virtual void operator()(const void* in, LoDTensor* out, size_t max_size) {}
void SetStream(cudaStream_t* stream) { stream_ = stream; }
- static void Run(const std::string& in_op_type, const LoDTensor& in, void* out,
- size_t max_size, cudaStream_t* stream) {
+ static void ConvertInput(const std::string& op_type, const LoDTensor& in,
+ void* out, size_t max_size, cudaStream_t* stream) {
PADDLE_ENFORCE(stream != nullptr);
- auto* converter = Registry::Lookup(
- in_op_type, "default" /* default_type */);
+ auto* converter = Registry::Lookup(
+ op_type, "default" /* default_type */);
PADDLE_ENFORCE_NOT_NULL(converter);
converter->SetStream(stream);
(*converter)(in, out, max_size);
}
- virtual ~EngineInputConverter() {}
+ static void ConvertOutput(const std::string& op_type, const void* in,
+ LoDTensor* out, size_t max_size,
+ cudaStream_t* stream) {
+ PADDLE_ENFORCE(stream != nullptr);
+ auto* converter = Registry::Lookup(
+ op_type, "default" /* default_type */);
+ PADDLE_ENFORCE_NOT_NULL(converter);
+ converter->SetStream(stream);
+ (*converter)(in, out, max_size);
+ }
+
+ virtual ~EngineIOConverter() {}
protected:
cudaStream_t* stream_{nullptr};
};
+#define REGISTER_TENSORRT_IO_CONVERTER(op_type__, Converter__) \
+ struct trt_io_##op_type__##_converter { \
+ trt_io_##op_type__##_converter() { \
+ Registry::Register(#op_type__); \
+ } \
+ }; \
+ trt_io_##op_type__##_converter trt_io_##op_type__##_converter__;
+
} // namespace tensorrt
} // namespace inference
} // namespace paddle
-
-#define REGISTER_TENSORRT_INPUT_CONVERTER(in_op_type__, Converter__) \
- struct trt_input_##in_op_type__##_converter { \
- trt_input_##in_op_type__##_converter() { \
- ::paddle::inference::Registry::Register< \
- Converter__>(#in_op_type__); \
- } \
- }; \
- trt_input_##in_op_type__##_converter trt_input_##in_op_type__##_converter__;
diff --git a/paddle/fluid/inference/tensorrt/convert/test_activation_op.cc b/paddle/fluid/inference/tensorrt/convert/test_activation_op.cc
index 669fba1eb81c5caacea039522ea70a2d0523d022..ec33f97c8240dfc09a203d68599bffe78a4abb12 100644
--- a/paddle/fluid/inference/tensorrt/convert/test_activation_op.cc
+++ b/paddle/fluid/inference/tensorrt/convert/test_activation_op.cc
@@ -16,6 +16,7 @@ limitations under the License. */
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/program_desc.h"
+#include "paddle/fluid/inference/tensorrt/convert/io_converter.h"
#include "paddle/fluid/inference/tensorrt/convert/op_converter.h"
#include "paddle/fluid/platform/device_context.h"
#include "paddle/fluid/platform/place.h"
@@ -26,7 +27,7 @@ namespace paddle {
namespace inference {
namespace tensorrt {
-void Compare(float input, float expect) {
+void Compare(const std::string op_type, float input, float expect) {
framework::Scope scope;
platform::CUDAPlace place;
platform::CUDADeviceContext ctx(place);
@@ -35,6 +36,7 @@ void Compare(float input, float expect) {
auto x_var = scope.Var("X");
auto x_tensor = x_var->GetMutable();
x_tensor->Resize({1, 1});
+ x_tensor->mutable_data(place);
std::vector init;
init.push_back(input);
framework::TensorFromVector(init, ctx, x_tensor);
@@ -45,14 +47,15 @@ void Compare(float input, float expect) {
out_tensor->mutable_data(place);
framework::OpDesc op_desc;
- op_desc.SetType("relu");
+ op_desc.SetType(op_type);
op_desc.SetInput("X", {"X"});
op_desc.SetOutput("Out", {"Out"});
- auto relu_op = framework::OpRegistry::CreateOp(*op_desc.Proto());
+ auto op = framework::OpRegistry::CreateOp(*op_desc.Proto());
// run fluid op
- relu_op->Run(scope, place);
+ op->Run(scope, place);
+ // get fluid output
std::vector out1;
framework::TensorToVector(*out_tensor, ctx, &out1);
@@ -63,21 +66,28 @@ void Compare(float input, float expect) {
engine->InitNetwork();
engine->DeclareInput("X", nvinfer1::DataType::kFLOAT,
nvinfer1::DimsCHW{1, 1, 1});
-
+ // convert op
OpConverter op_converter;
op_converter.ConvertOp(*op_desc.Proto(), engine);
engine->DeclareOutput("Out");
engine->FreezeNetwork();
- engine->SetInputFromCPU("X", &input, 1 * sizeof(float));
- // run tensorrt op
+ // convert LoDTensor to ITensor
+ size_t size = x_tensor->memory_size();
+ EngineIOConverter::ConvertInput(op_type, *x_tensor,
+ engine->buffer("X").buffer, size, &stream);
+ // run tensorrt Outp
engine->Execute(1);
-
- float out2;
- engine->GetOutputInCPU("Out", &out2, 1 * sizeof(float));
-
- ASSERT_EQ(out1[0], out2);
+ // convert ITensor to LoDTensor
+ EngineIOConverter::ConvertOutput(op_type, engine->buffer("Out").buffer,
+ out_tensor, size, &stream);
+ // get tensorrt output
+ std::vector out2;
+ framework::TensorToVector(*out_tensor, ctx, &out2);
+
+ // compare
+ ASSERT_EQ(out1[0], out2[0]);
ASSERT_EQ(out1[0], expect);
delete engine;
@@ -85,8 +95,8 @@ void Compare(float input, float expect) {
}
TEST(OpConverter, ConvertRelu) {
- Compare(1, 1); // relu(1) = 1
- Compare(-5, 0); // relu(-5) = 0
+ Compare("relu", 1, 1); // relu(1) = 1
+ Compare("relu", -5, 0); // relu(-5) = 0
}
} // namespace tensorrt
diff --git a/paddle/fluid/inference/tensorrt/convert/test_io_converter.cc b/paddle/fluid/inference/tensorrt/convert/test_io_converter.cc
index afcc516e6b76d58e37ce0e60746704cf3933fac7..8f91309a0a00d5131268f026c319e25ba3cb964a 100644
--- a/paddle/fluid/inference/tensorrt/convert/test_io_converter.cc
+++ b/paddle/fluid/inference/tensorrt/convert/test_io_converter.cc
@@ -12,40 +12,63 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
+#include
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/inference/tensorrt/convert/io_converter.h"
-#include
-
namespace paddle {
namespace inference {
namespace tensorrt {
-class EngineInputConverterTester : public ::testing::Test {
- public:
- void SetUp() override { tensor.Resize({10, 10}); }
+void IOConverterTester(const platform::DeviceContext& ctx) {
+ cudaStream_t stream;
+ ASSERT_EQ(0, cudaStreamCreate(&stream));
- framework::LoDTensor tensor;
-};
+ // init fluid in_tensor
+ framework::LoDTensor in_tensor;
+ in_tensor.Resize({10, 10});
+ auto place = ctx.GetPlace();
+ in_tensor.mutable_data(place);
+ std::vector init;
+ for (int64_t i = 0; i < 10 * 10; ++i) {
+ init.push_back(i);
+ }
+ framework::TensorFromVector(init, ctx, &in_tensor);
-TEST_F(EngineInputConverterTester, DefaultCPU) {
+ // init tensorrt buffer
void* buffer;
- tensor.mutable_data(platform::CPUPlace());
- ASSERT_EQ(cudaMalloc(&buffer, tensor.memory_size()), 0);
+ size_t size = in_tensor.memory_size();
+ ASSERT_EQ(cudaMalloc(&buffer, size), 0);
- cudaStream_t stream;
- EngineInputConverter::Run("test", tensor, buffer, tensor.memory_size(),
- &stream);
+ // convert fluid in_tensor to tensorrt buffer
+ EngineIOConverter::ConvertInput("test", in_tensor, buffer, size, &stream);
+
+ // convert tensorrt buffer to fluid out_tensor
+ framework::LoDTensor out_tensor;
+ out_tensor.Resize({10, 10});
+ out_tensor.mutable_data(place);
+ EngineIOConverter::ConvertOutput("test", buffer, &out_tensor, size, &stream);
+
+ // compare in_tensor and out_tensor
+ std::vector result;
+ framework::TensorToVector(out_tensor, ctx, &result);
+ EXPECT_EQ(init.size(), result.size());
+ for (size_t i = 0; i < init.size(); i++) {
+ EXPECT_EQ(init[i], result[i]);
+ }
+ cudaStreamDestroy(stream);
}
-TEST_F(EngineInputConverterTester, DefaultGPU) {
- void* buffer;
- tensor.mutable_data(platform::CUDAPlace());
- ASSERT_EQ(cudaMalloc(&buffer, tensor.memory_size()), 0);
+TEST(EngineIOConverterTester, DefaultCPU) {
+ platform::CPUPlace place;
+ platform::CPUDeviceContext ctx(place);
+ IOConverterTester(ctx);
+}
- cudaStream_t stream;
- EngineInputConverter::Run("test", tensor, buffer, tensor.memory_size(),
- &stream);
+TEST(EngineIOConverterTester, DefaultGPU) {
+ platform::CUDAPlace place;
+ platform::CUDADeviceContext ctx(place);
+ IOConverterTester(ctx);
}
} // namespace tensorrt
diff --git a/paddle/fluid/inference/tests/book/test_inference_image_classification.cc b/paddle/fluid/inference/tests/book/test_inference_image_classification.cc
index c4fd1e298b0daea85db2a407d04ad2d7bcdee0f0..60c761c5281e2f535aab0200c93fb738addcdb87 100644
--- a/paddle/fluid/inference/tests/book/test_inference_image_classification.cc
+++ b/paddle/fluid/inference/tests/book/test_inference_image_classification.cc
@@ -16,7 +16,6 @@ limitations under the License. */
#include "gtest/gtest.h"
#include "paddle/fluid/inference/tests/test_helper.h"
-DEFINE_string(data_set, "cifar10", "Data set to test");
DEFINE_string(dirname, "", "Directory of the inference model.");
DEFINE_string(fp16_dirname, "", "Directory of the float16 inference model.");
DEFINE_int32(batch_size, 1, "Batch size of input data");
@@ -35,19 +34,19 @@ TEST(inference, image_classification) {
// 0. Call `paddle::framework::InitDevices()` initialize all the devices
// In unittests, this is done in paddle/testing/paddle_gtest_main.cc
+ const bool is_combined = false;
+ std::vector> feed_target_shapes =
+ GetFeedTargetShapes(dirname, is_combined);
+
paddle::framework::LoDTensor input;
// Use normilized image pixels as input data,
// which should be in the range [0.0, 1.0].
- if (FLAGS_data_set == "cifar10") {
- SetupTensor(&input, {FLAGS_batch_size, 3, 32, 32},
- static_cast(0), static_cast(1));
- } else if (FLAGS_data_set == "imagenet") {
- SetupTensor(&input, {FLAGS_batch_size, 3, 224, 224},
- static_cast(0), static_cast(1));
- } else {
- LOG(FATAL) << "Only cifar10 or imagenet is supported.";
- }
-
+ feed_target_shapes[0][0] = FLAGS_batch_size;
+ paddle::framework::DDim input_dims =
+ paddle::framework::make_ddim(feed_target_shapes[0]);
+ LOG(INFO) << input_dims;
+ SetupTensor(&input, input_dims, static_cast(0),
+ static_cast(1));
std::vector cpu_feeds;
cpu_feeds.push_back(&input);
@@ -60,7 +59,7 @@ TEST(inference, image_classification) {
LOG(INFO) << "--- CPU Runs: ---";
LOG(INFO) << "Batch size is " << FLAGS_batch_size;
TestInference(
- dirname, cpu_feeds, cpu_fetchs1, FLAGS_repeat);
+ dirname, cpu_feeds, cpu_fetchs1, FLAGS_repeat, is_combined);
LOG(INFO) << output1.dims();
}
@@ -73,7 +72,7 @@ TEST(inference, image_classification) {
LOG(INFO) << "--- GPU Runs: ---";
LOG(INFO) << "Batch size is " << FLAGS_batch_size;
TestInference(
- dirname, cpu_feeds, cpu_fetchs2, FLAGS_repeat);
+ dirname, cpu_feeds, cpu_fetchs2, FLAGS_repeat, is_combined);
LOG(INFO) << output2.dims();
if (!FLAGS_skip_cpu) {
diff --git a/paddle/fluid/inference/tests/test_helper.h b/paddle/fluid/inference/tests/test_helper.h
index af2a7a5620487a10c1df6152fc4e4bf67b150752..b02e5c99f00eaf03c3753e43575cbc67e834774e 100644
--- a/paddle/fluid/inference/tests/test_helper.h
+++ b/paddle/fluid/inference/tests/test_helper.h
@@ -89,6 +89,50 @@ void CheckError(const paddle::framework::LoDTensor& output1,
EXPECT_EQ(count, 0U) << "There are " << count << " different elements.";
}
+std::unique_ptr InitProgram(
+ paddle::framework::Executor* executor, paddle::framework::Scope* scope,
+ const std::string& dirname, const bool is_combined = false) {
+ std::unique_ptr inference_program;
+ if (is_combined) {
+ // All parameters are saved in a single file.
+ // Hard-coding the file names of program and parameters in unittest.
+ // The file names should be consistent with that used in Python API
+ // `fluid.io.save_inference_model`.
+ std::string prog_filename = "__model_combined__";
+ std::string param_filename = "__params_combined__";
+ inference_program =
+ paddle::inference::Load(executor, scope, dirname + "/" + prog_filename,
+ dirname + "/" + param_filename);
+ } else {
+ // Parameters are saved in separate files sited in the specified
+ // `dirname`.
+ inference_program = paddle::inference::Load(executor, scope, dirname);
+ }
+ return inference_program;
+}
+
+std::vector> GetFeedTargetShapes(
+ const std::string& dirname, const bool is_combined = false) {
+ auto place = paddle::platform::CPUPlace();
+ auto executor = paddle::framework::Executor(place);
+ auto* scope = new paddle::framework::Scope();
+
+ auto inference_program = InitProgram(&executor, scope, dirname, is_combined);
+ auto& global_block = inference_program->Block(0);
+
+ const std::vector& feed_target_names =
+ inference_program->GetFeedTargetNames();
+ std::vector> feed_target_shapes;
+ for (size_t i = 0; i < feed_target_names.size(); ++i) {
+ auto* var = global_block.FindVar(feed_target_names[i]);
+ std::vector var_shape = var->GetShape();
+ feed_target_shapes.push_back(var_shape);
+ }
+
+ delete scope;
+ return feed_target_shapes;
+}
+
template
void TestInference(const std::string& dirname,
const std::vector& cpu_feeds,
@@ -124,22 +168,7 @@ void TestInference(const std::string& dirname,
paddle::platform::RecordEvent record_event(
"init_program",
paddle::platform::DeviceContextPool::Instance().Get(place));
-
- if (is_combined) {
- // All parameters are saved in a single file.
- // Hard-coding the file names of program and parameters in unittest.
- // The file names should be consistent with that used in Python API
- // `fluid.io.save_inference_model`.
- std::string prog_filename = "__model_combined__";
- std::string param_filename = "__params_combined__";
- inference_program = paddle::inference::Load(
- &executor, scope, dirname + "/" + prog_filename,
- dirname + "/" + param_filename);
- } else {
- // Parameters are saved in separate files sited in the specified
- // `dirname`.
- inference_program = paddle::inference::Load(&executor, scope, dirname);
- }
+ inference_program = InitProgram(&executor, scope, dirname, is_combined);
}
// Disable the profiler and print the timing information
paddle::platform::DisableProfiler(paddle::platform::EventSortingKey::kDefault,
diff --git a/paddle/fluid/operators/CMakeLists.txt b/paddle/fluid/operators/CMakeLists.txt
index c14a2b7786f9f7c06d59479d3bbce9c5d542e495..d38a9ce58726a1d045d6905354b0b592166c0110 100644
--- a/paddle/fluid/operators/CMakeLists.txt
+++ b/paddle/fluid/operators/CMakeLists.txt
@@ -186,6 +186,11 @@ endif()
add_subdirectory(detail)
if(WITH_DISTRIBUTE)
+ if(WITH_GPU)
+ op_library(gen_nccl_id_op DEPS nccl_common)
+ else()
+ set(DEPS_OPS ${DEPS_OPS} gen_nccl_id_op)
+ endif()
set(DISTRIBUTE_DEPS sendrecvop_grpc grpc++_unsecure grpc_unsecure gpr cares zlib protobuf)
set(DISTRIBUTE_COMPILE_FLAGS "-Wno-non-virtual-dtor -Wno-error=non-virtual-dtor -Wno-error=delete-non-virtual-dtor")
op_library(send_op DEPS ${DISTRIBUTE_DEPS})
@@ -202,8 +207,9 @@ if(WITH_DISTRIBUTE)
set_source_files_properties(send_barrier_op.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
set_source_files_properties(send_recv_op_test.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
cc_test(test_send_recv SRCS send_recv_op_test.cc DEPS prefetch_op send_op listen_and_serv_op sum_op executor)
+ cc_test(test_send_nccl_id SRCS test_send_nccl_id.cc DEPS send_op listen_and_serv_op executor)
else()
- set(DEPS_OPS ${DEPS_OPS} send_op prefetch_op recv_op listen_and_serv_op send_vars_op send_barrier_op)
+ set(DEPS_OPS ${DEPS_OPS} send_op prefetch_op recv_op listen_and_serv_op send_vars_op send_barrier_op gen_nccl_id_op)
endif()
op_library(cross_entropy_op DEPS cross_entropy)
diff --git a/paddle/fluid/operators/detail/grpc_client.cc b/paddle/fluid/operators/detail/grpc_client.cc
index 661dfa69fe1580ff3890f12defcd124225be0c06..ae60ab15325ef101feb7270a4f5d840cb2112be0 100644
--- a/paddle/fluid/operators/detail/grpc_client.cc
+++ b/paddle/fluid/operators/detail/grpc_client.cc
@@ -52,7 +52,7 @@ bool RPCClient::AsyncSendVariable(const std::string& ep,
// stub context
SendProcessor* s = new SendProcessor(ch);
s->Prepare(var_h, time_out);
- s->response_call_back_ = NULL;
+ s->response_call_back_ = nullptr;
auto call = s->stub_g_.PrepareUnaryCall(
s->context_.get(), "/sendrecv.SendRecvService/SendVariable", req, &cq_);
diff --git a/paddle/fluid/operators/detail/grpc_client.h b/paddle/fluid/operators/detail/grpc_client.h
index f6229b71bc01a6de51f50f5fe880ada6e15e74dd..dabce7414d2f0dca74193f1cd10c341793c10ec9 100644
--- a/paddle/fluid/operators/detail/grpc_client.h
+++ b/paddle/fluid/operators/detail/grpc_client.h
@@ -57,7 +57,9 @@ void ProcGetResponse(const VarHandle& var_h, const grpc::ByteBuffer& msg);
class BaseProcessor {
public:
- explicit BaseProcessor(std::shared_ptr ch) { context_ = NULL; }
+ explicit BaseProcessor(std::shared_ptr ch) {
+ context_ = nullptr;
+ }
virtual ~BaseProcessor() {}
@@ -105,7 +107,7 @@ class SendProcessor : public BaseProcessor {
::grpc::GenericStub stub_g_;
::grpc::ByteBuffer reply_;
- RequestSendCallBack response_call_back_ = NULL;
+ RequestSendCallBack response_call_back_ = nullptr;
};
typedef std::function
diff --git a/paddle/fluid/operators/detail/grpc_server.cc b/paddle/fluid/operators/detail/grpc_server.cc
index e6ee28ea8d920ef80fead258a9bd0d5f6762c879..d09f8479b765ad26cc202bfdb2692828213c7956 100644
--- a/paddle/fluid/operators/detail/grpc_server.cc
+++ b/paddle/fluid/operators/detail/grpc_server.cc
@@ -306,7 +306,7 @@ void AsyncGRPCServer::TryToRegisterNewPrefetchOne() {
}
RequestPrefetch* prefetch =
new RequestPrefetch(&service_, cq_prefetch_.get(), sync_mode_, scope_,
- dev_ctx_, executor_, program_, prefetch_ctx_);
+ dev_ctx_, executor_, program_, prefetch_ctx_.get());
VLOG(4) << "Create RequestPrefetch status:" << prefetch->Status();
}
diff --git a/paddle/fluid/operators/detail/grpc_server.h b/paddle/fluid/operators/detail/grpc_server.h
index 7f9cae21ccca8dd51f9fbe98148d01a51ac6eb84..238aaa29634a7eff65429c27aa3538a185723eb2 100644
--- a/paddle/fluid/operators/detail/grpc_server.h
+++ b/paddle/fluid/operators/detail/grpc_server.h
@@ -47,6 +47,7 @@ class AsyncGRPCServer final {
explicit AsyncGRPCServer(const std::string &address, bool sync_mode)
: address_(address), sync_mode_(sync_mode), ready_(0) {}
+ ~AsyncGRPCServer() {}
void WaitServerReady();
void RunSyncUpdate();
@@ -63,8 +64,9 @@ class AsyncGRPCServer final {
void SetExecutor(framework::Executor *executor) { executor_ = executor; }
- void SetPrefetchPreparedCtx(framework::ExecutorPrepareContext *prepared) {
- prefetch_ctx_ = prepared;
+ void SetPrefetchPreparedCtx(
+ std::unique_ptr prepared) {
+ prefetch_ctx_.reset(prepared.release());
}
int GetSelectedPort() const { return selected_port_; }
@@ -115,7 +117,7 @@ class AsyncGRPCServer final {
std::unique_ptr t_get_;
std::unique_ptr t_prefetch_;
- framework::ExecutorPrepareContext *prefetch_ctx_;
+ std::unique_ptr prefetch_ctx_;
framework::ProgramDesc *program_;
framework::Executor *executor_;
int selected_port_;
diff --git a/paddle/fluid/operators/detail/grpc_server_test.cc b/paddle/fluid/operators/detail/grpc_server_test.cc
index 25b95d608d10d6e456d5f563ce9fbe35d812cb0f..b8db0ad987cdfaec1fc9236c3f26e88891376dce 100644
--- a/paddle/fluid/operators/detail/grpc_server_test.cc
+++ b/paddle/fluid/operators/detail/grpc_server_test.cc
@@ -100,7 +100,7 @@ void StartServer(const std::string& endpoint) {
InitTensorsOnServer(&scope, &place, 10);
rpc_service_->SetProgram(&program);
- rpc_service_->SetPrefetchPreparedCtx(prepared.get());
+ rpc_service_->SetPrefetchPreparedCtx(std::move(prepared));
rpc_service_->SetDevCtx(&ctx);
rpc_service_->SetScope(&scope);
rpc_service_->SetExecutor(&exe);
diff --git a/paddle/fluid/operators/detail/send_recv.proto b/paddle/fluid/operators/detail/send_recv.proto
index fffa9ae7a43ea5cd7b2bda6fbbf6ef9f7d23009d..9478c5702bcbf99fc88207b8c4843dbccf8a5925 100644
--- a/paddle/fluid/operators/detail/send_recv.proto
+++ b/paddle/fluid/operators/detail/send_recv.proto
@@ -32,6 +32,7 @@ service SendRecvService {
enum VarType {
LOD_TENSOR = 0;
SELECTED_ROWS = 1;
+ NCCL_ID = 2;
}
// NOTICE(gongwb):don't modify this proto if you are not
diff --git a/paddle/fluid/operators/detail/sendrecvop_utils.cc b/paddle/fluid/operators/detail/sendrecvop_utils.cc
index 1a8a1af20fa446dbd537944409ef0ca1e3e9116f..07c43554bc6a0d71d688a5a5772d0ab3d2de319a 100644
--- a/paddle/fluid/operators/detail/sendrecvop_utils.cc
+++ b/paddle/fluid/operators/detail/sendrecvop_utils.cc
@@ -14,6 +14,9 @@ limitations under the License. */
#include "paddle/fluid/operators/detail/sendrecvop_utils.h"
+#ifdef PADDLE_WITH_CUDA
+#include
+#endif
#include
#include // NOLINT
@@ -129,6 +132,10 @@ void SerializeToByteBuffer(const std::string& name, framework::Variable* var,
} else if (var->IsType()) {
request.set_type(::sendrecv::SELECTED_ROWS);
GetSelectedRowsPayload(var, ctx, &request, &payload, &payload_size);
+#ifdef PADDLE_WITH_CUDA
+ } else if (var->IsType()) {
+ request.set_type(::sendrecv::NCCL_ID);
+#endif
} else {
PADDLE_THROW("Serialize does not support type: %s",
typeid(var->Type()).name());
@@ -149,6 +156,24 @@ void SerializeToByteBuffer(const std::string& name, framework::Variable* var,
void* buf = buffer.get();
ProtoEncodeHelper e(static_cast(buf), 1024);
e.WriteRawBytes(std::string(header.data(), header.size()));
+// NCCLID is copied directly to the message, return bytebuffer
+// with only one slice if serializing NCCLID.
+#ifdef PADDLE_WITH_CUDA
+ if (var->IsType()) {
+ e.WriteVarlengthBeginning(VarMsg::kSerializedFieldNumber,
+ NCCL_UNIQUE_ID_BYTES);
+ const ncclUniqueId& uid = var->Get();
+ e.WriteRawBytes(std::string(uid.internal, NCCL_UNIQUE_ID_BYTES));
+
+ // for serialize NCCL_ID
+ ::grpc::Slice slices(e.size());
+ memcpy(const_cast(slices.begin()), e.data(), e.size());
+ ::grpc::ByteBuffer tmp(&slices, 1);
+ msg->Swap(&tmp);
+ return;
+ }
+#endif
+
e.WriteVarlengthBeginning(VarMsg::kSerializedFieldNumber, payload_size);
// steal reference of tensor data
::grpc::Slice slices[4]; // metadata, tensor, rows meta, rows
diff --git a/paddle/fluid/operators/detail/variable_response.cc b/paddle/fluid/operators/detail/variable_response.cc
index 99602a05d023f30c2eed8df25e7534fdc9ef2ced..462e303096e609c6797ca8cc16266ec3621623fc 100644
--- a/paddle/fluid/operators/detail/variable_response.cc
+++ b/paddle/fluid/operators/detail/variable_response.cc
@@ -17,6 +17,9 @@
#include
#include
#include
+#ifdef PADDLE_WITH_CUDA
+#include
+#endif
#include "paddle/fluid/platform/profiler.h"
#include "paddle/fluid/operators/detail/send_recv.pb.h"
@@ -368,7 +371,8 @@ int VariableResponse::Parse(Source* source) {
}
case sendrecv::VariableMessage::kSerializedFieldNumber: {
PADDLE_ENFORCE((meta_.type() == sendrecv::SELECTED_ROWS ||
- meta_.type() == sendrecv::LOD_TENSOR) &&
+ meta_.type() == sendrecv::LOD_TENSOR ||
+ meta_.type() == sendrecv::NCCL_ID) &&
meta_.varname() != "",
"meta info should be got first!");
@@ -378,6 +382,22 @@ int VariableResponse::Parse(Source* source) {
return tag;
}
+ if (meta_.type() == sendrecv::NCCL_ID) {
+#ifdef PADDLE_WITH_CUDA
+ auto* var = scope_->FindVar(meta_.varname());
+ if (var != nullptr) {
+ ncclUniqueId* id = var->GetMutable();
+ if (!ReadRaw(&input, *dev_ctx_, platform::CPUPlace(), id->internal,
+ num_bytes)) {
+ return tag;
+ }
+ }
+ break;
+#else
+ PADDLE_THROW("Not compiled with CUDA!");
+#endif
+ }
+
framework::DDim dims = GetDims(meta_.dims());
if (meta_.type() == sendrecv::LOD_TENSOR) {
PADDLE_ENFORCE(meta_.lod_size() >= 0,
diff --git a/paddle/fluid/operators/gen_nccl_id_op.cc b/paddle/fluid/operators/gen_nccl_id_op.cc
new file mode 100644
index 0000000000000000000000000000000000000000..a5678f63466d368b3dd59380c18f9625cabd368b
--- /dev/null
+++ b/paddle/fluid/operators/gen_nccl_id_op.cc
@@ -0,0 +1,128 @@
+/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#include
+#include
+#include
+#include
+
+#include "paddle/fluid/framework/executor.h"
+#include "paddle/fluid/framework/lod_tensor.h"
+#include "paddle/fluid/framework/op_registry.h"
+#include "paddle/fluid/framework/threadpool.h"
+#include "paddle/fluid/operators/detail/grpc_client.h"
+#include "paddle/fluid/operators/detail/grpc_server.h"
+#include "paddle/fluid/platform/nccl_helper.h"
+
+namespace paddle {
+namespace operators {
+
+class GenNCCLIdOp : public framework::OperatorBase {
+ public:
+ GenNCCLIdOp(const std::string& type, const framework::VariableNameMap& inputs,
+ const framework::VariableNameMap& outputs,
+ const framework::AttributeMap& attrs)
+ : OperatorBase(type, inputs, outputs, attrs) {}
+
+ void RunImpl(const framework::Scope& scope,
+ const platform::Place& dev_place) const override {
+ platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance();
+ // put nccl id in CPUPlace
+ auto& dev_ctx = *pool.Get(platform::CPUPlace());
+ int trainer_id = Attr("trainer_id");
+ framework::Scope& local_scope = scope.NewScope();
+
+ if (trainer_id == 0) {
+ GenerateAndSend(&local_scope, dev_ctx);
+ } else {
+ GetIdByServer(&local_scope, dev_ctx);
+ }
+ }
+
+ private:
+ void GenerateAndSend(framework::Scope* scope,
+ const platform::DeviceContext& dev_ctx) const {
+ auto var = scope->FindVar(NCCL_ID_VARNAME);
+ PADDLE_ENFORCE_NOT_NULL(var);
+ auto id = var->GetMutable();
+ PADDLE_ENFORCE(platform::dynload::ncclGetUniqueId(id));
+
+ std::vector endpoint_list =
+ Attr>("endpoint_list");
+ detail::RPCClient client;
+ for (auto& ep : endpoint_list) {
+ VLOG(3) << "sending nccl id to " << ep;
+ client.AsyncSendVariable(ep, dev_ctx, *scope, NCCL_ID_VARNAME);
+ }
+ client.Wait();
+ VLOG(3) << "sending completed...";
+ }
+
+ void GetIdByServer(framework::Scope* scope,
+ const platform::DeviceContext& dev_ctx) const {
+ std::string endpoint = Attr("endpoint");
+ // NOTE: Can not use unique_ptr here because the default
+ // deleter will call GRPC Server's base class's dtor and
+ // that will cause a wired crash.
+ detail::AsyncGRPCServer rpc_service(endpoint, true);
+ framework::ProgramDesc empty_program;
+ framework::Executor executor(dev_ctx.GetPlace());
+ rpc_service.SetScope(scope);
+ rpc_service.SetDevCtx(&dev_ctx);
+ rpc_service.SetProgram(&empty_program);
+ rpc_service.SetExecutor(&executor);
+
+ std::thread server_thread(
+ std::bind(&detail::AsyncGRPCServer::RunSyncUpdate, &rpc_service));
+ rpc_service.SetCond(0);
+ VLOG(3) << "start getting nccl id from trainer 0...";
+ auto recv = rpc_service.Get();
+ VLOG(3) << "got nccl id and stop server...";
+ rpc_service.ShutDown();
+ VLOG(3) << "rpc server stopped";
+ server_thread.join();
+ }
+};
+
+class GenNCCLIdOpMaker : public framework::OpProtoAndCheckerMaker {
+ public:
+ void Make() override {
+ AddOutput("NCCLID", "Raw variable contains a NCCL UniqueId instaces.");
+ AddComment(R"DOC(
+GenNCCLId operator
+
+For trainer 0: generate a new UniqueId and send it to all the other trainers.
+For trainer 1~n: start a gRPC server to get the UniqueId, once got, stop the server.
+)DOC");
+ AddAttr("endpoint",
+ "(string), e.g. 127.0.0.1:6175 "
+ "current listen endpoint");
+ AddAttr>(
+ "endpoint_list",
+ "['trainer1_ip:port', 'trainer2_ip:port', ...] "
+ "list of trainer endpoints start from trainer 1")
+ .SetDefault({});
+ AddAttr("trainer_id",
+ "(int default 0) "
+ "The index of the trainer in distributed training.")
+ .SetDefault(0);
+ }
+};
+
+} // namespace operators
+} // namespace paddle
+
+namespace ops = paddle::operators;
+
+REGISTER_OPERATOR(gen_nccl_id, ops::GenNCCLIdOp, ops::GenNCCLIdOpMaker);
diff --git a/paddle/fluid/operators/listen_and_serv_op.cc b/paddle/fluid/operators/listen_and_serv_op.cc
index a29e0cd52cfccf242a6490822234045e6eb66c0f..abc88d3eb1514e159f4a880f44ecc0c0960a73d9 100644
--- a/paddle/fluid/operators/listen_and_serv_op.cc
+++ b/paddle/fluid/operators/listen_and_serv_op.cc
@@ -322,8 +322,7 @@ void ListenAndServOp::RunImpl(const framework::Scope &scope,
// prepare for prefetch
VLOG(3) << "prefetch block id is " << prefetch_block->ID();
auto prefetch_prepared = executor.Prepare(*program, prefetch_block->ID());
- rpc_service_->SetPrefetchPreparedCtx(prefetch_prepared.get());
- prefetch_prepared.release();
+ rpc_service_->SetPrefetchPreparedCtx(std::move(prefetch_prepared));
// start the server listening after all member initialized.
server_thread_.reset(new std::thread(RunServer, rpc_service_));
diff --git a/paddle/fluid/operators/load_combine_op.cc b/paddle/fluid/operators/load_combine_op.cc
index b5522dd246f250f02d69c0ba749ae6043eb810d6..0522a94195786c767194ec727d982a60451e7c62 100644
--- a/paddle/fluid/operators/load_combine_op.cc
+++ b/paddle/fluid/operators/load_combine_op.cc
@@ -12,7 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include
-
+#include "paddle/fluid/framework/data_type_transform.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/platform/device_context.h"
@@ -31,6 +31,7 @@ class LoadCombineOp : public framework::OperatorBase {
void RunImpl(const framework::Scope &scope,
const platform::Place &place) const override {
auto filename = Attr("file_path");
+ auto load_as_fp16 = Attr("load_as_fp16");
std::ifstream fin(filename);
PADDLE_ENFORCE(static_cast(fin),
@@ -59,17 +60,25 @@ class LoadCombineOp : public framework::OperatorBase {
// Get data from fin to tensor
DeserializeFromStream(fin, tensor, dev_ctx);
- if (platform::is_gpu_place(place)) {
- // copy CPU to GPU
- framework::LoDTensor cpu_tensor;
- cpu_tensor.ShareDataWith(*tensor);
- cpu_tensor.set_lod(tensor->lod());
-
- // reset tensor
+ auto in_dtype = framework::ToDataType(tensor->type());
+ auto out_dtype =
+ load_as_fp16 ? framework::proto::VarType::FP16 : in_dtype;
+
+ if (in_dtype != out_dtype) {
+ // convert to float16 tensor
+ auto in_kernel_type = framework::OpKernelType(in_dtype, place);
+ auto out_kernel_type = framework::OpKernelType(out_dtype, place);
+ framework::LoDTensor fp16_tensor;
+ // copy LoD info to the new tensor
+ fp16_tensor.set_lod(tensor->lod());
+ framework::TransDataType(in_kernel_type, out_kernel_type, *tensor,
+ &fp16_tensor);
+
+ // reset output tensor
out_var->Clear();
tensor = out_var->GetMutable();
- tensor->set_lod(cpu_tensor.lod());
- TensorCopy(cpu_tensor, place, dev_ctx, tensor);
+ tensor->set_lod(fp16_tensor.lod());
+ tensor->ShareDataWith(fp16_tensor);
}
}
}
@@ -82,6 +91,13 @@ class LoadCombineOpProtoMaker : public framework::OpProtoAndCheckerMaker {
"Out",
"(vector) The output LoDTensors that will be read from the input file.")
.AsDuplicable();
+ AddAttr(
+ "load_as_fp16",
+ "(boolean, default false)"
+ "If true, the tensor will be first loaded and then "
+ "converted to float16 data type. Otherwise, the tensor will be "
+ "directly loaded without data type conversion.")
+ .SetDefault(false);
AddAttr("file_path",
"(string) "
"LoDTensors will be loaded from \"file_path\".")
diff --git a/paddle/fluid/operators/math/sequence2batch.h b/paddle/fluid/operators/math/sequence2batch.h
index 0abda999a52bcbb94e6503692bd11aff26e849ba..62e6307ae9f4236a38c49daaf09fc05c54268159 100644
--- a/paddle/fluid/operators/math/sequence2batch.h
+++ b/paddle/fluid/operators/math/sequence2batch.h
@@ -64,18 +64,22 @@ class LoDTensor2BatchFunctor {
bool is_reverse = false) const {
if (!is_cal_batch_lod) {
auto lods = batch->lod();
- PADDLE_ENFORCE_GT(lods.size(), 2UL);
- PADDLE_ENFORCE_EQ(lods[1].size(),
- static_cast(lod_tensor.dims()[0]));
+ PADDLE_ENFORCE_GT(lods.size(), 2UL,
+ "The LoD of LoDTensor should inlcude at least 2-level "
+ "sequence information.");
+ PADDLE_ENFORCE_EQ(
+ lods[1].size(), static_cast(lod_tensor.dims()[0]),
+ "The LoD information should be consistent with the dims.");
CopyMatrixRowsFunctor to_batch;
to_batch(context, lod_tensor, lods[1], batch, true);
return;
}
auto lods = lod_tensor.lod();
- auto lod = lods[0];
PADDLE_ENFORCE_EQ(lods.size(), 1UL, "Only support one level sequence now.");
+ auto lod = lods[0];
+
std::vector seq_info;
for (size_t seq_id = 0; seq_id < lod.size() - 1; ++seq_id) {
int length = lod[seq_id + 1] - lod[seq_id];
@@ -157,9 +161,12 @@ class Batch2LoDTensorFunctor {
const framework::LoDTensor& batch,
framework::LoDTensor* lod_tensor) const {
auto in_lod = batch.lod();
- PADDLE_ENFORCE_GT(in_lod.size(), 2UL);
- PADDLE_ENFORCE_EQ(in_lod[1].size(),
- static_cast(lod_tensor->dims()[0]));
+ PADDLE_ENFORCE_GT(in_lod.size(), 2UL,
+ "The LoD of LoDTensor should inlcude at least 2-level "
+ "sequence information.");
+ PADDLE_ENFORCE_EQ(
+ in_lod[1].size(), static_cast(lod_tensor->dims()[0]),
+ "The LoD information should be consistent with the dims.");
CopyMatrixRowsFunctor to_seq;
to_seq(context, batch, in_lod[1], lod_tensor, false);
}
diff --git a/paddle/fluid/operators/reshape_op.h b/paddle/fluid/operators/reshape_op.h
index ccd7063fe69e0f21b4d2a821bb70902b39c9b9de..3dd8c7c11eca241e747bfa129962032d882ce44c 100644
--- a/paddle/fluid/operators/reshape_op.h
+++ b/paddle/fluid/operators/reshape_op.h
@@ -92,14 +92,16 @@ class ReshapeOp : public framework::OperatorWithKernel {
}
if (unk_dim_idx != -1) {
- output_shape[unk_dim_idx] = -in_size / capacity;
- // in_size < 0 and is un-determinate in compile time, skip the check,
- // for example, in_dims = [-1, 8, 1, 1], shape = [-1, 3, 8],
- // capacity = -24, in_size = -8, output_shape[0] = 0
- // the following check will fail.
if (in_size > 0) {
+ // in_size < 0 and is un-determinate in compile time, skip the check,
+ // for example, in_dims = [-1, 8, 1, 1], shape = [-1, 3, 8],
+ // capacity = -24, in_size = -8, output_shape[0] = 0
+ // the following check will fail.
+ output_shape[unk_dim_idx] = -in_size / capacity;
PADDLE_ENFORCE_EQ(output_shape[unk_dim_idx] * capacity, -in_size,
"Invalid shape is given.");
+ } else {
+ output_shape[unk_dim_idx] = -1;
}
} else {
PADDLE_ENFORCE_EQ(capacity, in_size, "Invalid shape is given.");
@@ -122,7 +124,10 @@ class ReshapeKernel : public framework::OpKernel {
void Compute(const framework::ExecutionContext &ctx) const {
auto *out = ctx.Output("Out");
auto *in = ctx.Input("X");
- auto *shape_tensor = ctx.Input("Shape");
+
+ auto *shape_tensor = ctx.HasInput("Shape")
+ ? ctx.Input("Shape")
+ : nullptr;
framework::DDim out_dims = out->dims();
diff --git a/paddle/fluid/operators/save_load_combine_op_test.cc b/paddle/fluid/operators/save_load_combine_op_test.cc
index 47618c51d98eb9f58988f82c0aee0083565d81a6..4743e0d9499b111d8baa921dbb245431713fd7a8 100644
--- a/paddle/fluid/operators/save_load_combine_op_test.cc
+++ b/paddle/fluid/operators/save_load_combine_op_test.cc
@@ -139,8 +139,9 @@ TEST(SaveLoadCombineOp, CPU) {
CheckValues(expect4, actual4, expect_lod4, actual_lod4, numel4);
}
-// FP16 version of SaveLoadCombineOp Test
-TEST(SaveLoadCombineFP16Op, CPU) {
+// FP16 version of SaveLoadCombineOp Test, only altering the saving aspect
+// to save as FP16.
+TEST(SaveCombineFP16Op, CPU) {
paddle::framework::Scope scope;
paddle::platform::CPUPlace place;
@@ -169,7 +170,7 @@ TEST(SaveLoadCombineFP16Op, CPU) {
20, 50, lod4, "test_var4", place, &scope, &expect_lod4);
// Set attributes
- std::string filename = "check_tensor_fp16.ls";
+ std::string filename = "check_tensor_fp16_save.ls";
paddle::framework::AttributeMap attrs;
attrs.insert({"file_path", std::string(filename)});
attrs.insert({"save_as_fp16", true});
@@ -216,6 +217,89 @@ TEST(SaveLoadCombineFP16Op, CPU) {
actual_lod4, numel4);
}
+// FP16 version of SaveLoadCombineOp Test, only altering the loading aspect
+// to load tensors with FP16 precision.
+TEST(LoadCombineFP16Op, CPU) {
+ paddle::framework::Scope scope;
+ paddle::platform::CPUPlace place;
+
+ std::vector lod1 = {0, 1, 2, 3, 10};
+ int numel1 = 100;
+ paddle::framework::LoD expect_lod1;
+ float* expect1 = CreateForSaveCombineOp(
+ 10, 10, lod1, "test_var1", place, &scope, &expect_lod1);
+
+ std::vector lod2 = {0, 2, 5, 10};
+ int numel2 = 200;
+ paddle::framework::LoD expect_lod2;
+ float* expect2 = CreateForSaveCombineOp(
+ 10, 20, lod2, "test_var2", place, &scope, &expect_lod2);
+
+ std::vector lod3 = {0, 20};
+ int numel3 = 4000;
+ paddle::framework::LoD expect_lod3;
+ float* expect3 = CreateForSaveCombineOp(
+ 20, 200, lod3, "test_var3", place, &scope, &expect_lod3);
+
+ std::vector lod4 = {0, 1, 20};
+ int numel4 = 1000;
+ paddle::framework::LoD expect_lod4;
+ float* expect4 = CreateForSaveCombineOp(
+ 20, 50, lod4, "test_var4", place, &scope, &expect_lod4);
+
+ // Set attributes
+ std::string filename = "check_tensor_fp16_load.ls";
+ paddle::framework::AttributeMap attrs;
+ attrs.insert({"file_path", std::string(filename)});
+
+ // Run the save_combine_op
+ auto save_combine_op = paddle::framework::OpRegistry::CreateOp(
+ "save_combine",
+ {{"X", {"test_var1", "test_var2", "test_var3", "test_var4"}}}, {}, attrs);
+ save_combine_op->Run(scope, place);
+
+ // Set up output vars
+ auto load_var1 = scope.Var("out_var1");
+ auto load_var2 = scope.Var("out_var2");
+ auto load_var3 = scope.Var("out_var3");
+ auto load_var4 = scope.Var("out_var4");
+
+ attrs.insert({"load_as_fp16", true});
+ // Run the load_combine_op
+ auto load_combine_op = paddle::framework::OpRegistry::CreateOp(
+ "load_combine", {},
+ {{"Out", {"out_var1", "out_var2", "out_var3", "out_var4"}}}, attrs);
+ load_combine_op->Run(scope, place);
+
+ auto* target1 = load_var1->GetMutable();
+ auto* target2 = load_var2->GetMutable();
+ auto* target3 = load_var3->GetMutable();
+ auto* target4 = load_var4->GetMutable();
+
+ paddle::framework::LoD actual_lod1, actual_lod2, actual_lod3, actual_lod4;
+ paddle::platform::float16* actual1 =
+ GetValuesAfterLoadCombineOp(target1, scope,
+ &actual_lod1);
+ paddle::platform::float16* actual2 =
+ GetValuesAfterLoadCombineOp(target2, scope,
+ &actual_lod2);
+ paddle::platform::float16* actual3 =
+ GetValuesAfterLoadCombineOp(target3, scope,
+ &actual_lod3);
+ paddle::platform::float16* actual4 =
+ GetValuesAfterLoadCombineOp(target4, scope,
+ &actual_lod4);
+
+ CheckValues(expect1, actual1, expect_lod1,
+ actual_lod1, numel1);
+ CheckValues(expect2, actual2, expect_lod2,
+ actual_lod2, numel2);
+ CheckValues(expect3, actual3, expect_lod3,
+ actual_lod3, numel3);
+ CheckValues(expect4, actual4, expect_lod4,
+ actual_lod4, numel4);
+}
+
// Test with original SaveLoadTest
TEST(SaveLoadTestWithCombineOp, CPU) {
paddle::framework::Scope scope;
diff --git a/paddle/fluid/operators/test_send_nccl_id.cc b/paddle/fluid/operators/test_send_nccl_id.cc
new file mode 100644
index 0000000000000000000000000000000000000000..bbae1d54aa3524fd45cb8ab13c86df8d54b8e643
--- /dev/null
+++ b/paddle/fluid/operators/test_send_nccl_id.cc
@@ -0,0 +1,94 @@
+/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#include
+#include
+#include // NOLINT
+
+#include "gtest/gtest.h"
+#include "paddle/fluid/framework/op_registry.h"
+#include "paddle/fluid/framework/operator.h"
+#include "paddle/fluid/framework/program_desc.h"
+#include "paddle/fluid/operators/detail/grpc_client.h"
+#include "paddle/fluid/operators/listen_and_serv_op.h"
+#include "paddle/fluid/operators/math/math_function.h"
+#include "paddle/fluid/operators/math/selected_rows_functor.h"
+#include "paddle/fluid/platform/nccl_helper.h"
+#include "paddle/fluid/string/printf.h"
+
+USE_NO_KERNEL_OP(listen_and_serv);
+
+namespace f = paddle::framework;
+namespace p = paddle::platform;
+namespace m = paddle::operators::math;
+namespace detail = paddle::operators::detail;
+namespace string = paddle::string;
+
+std::unique_ptr rpc_service;
+
+void StartServer(std::atomic* initialized) {
+ f::Scope scope;
+ p::CPUPlace place;
+ scope.Var(NCCL_ID_VARNAME);
+ p::DeviceContextPool& pool = p::DeviceContextPool::Instance();
+ auto& dev_ctx = *pool.Get(p::CPUPlace());
+
+ rpc_service.reset(new detail::AsyncGRPCServer("127.0.0.1:0", true));
+
+ f::ProgramDesc empty_program;
+ f::Executor executor(dev_ctx.GetPlace());
+ rpc_service->SetScope(&scope);
+ rpc_service->SetDevCtx(&dev_ctx);
+ rpc_service->SetProgram(&empty_program);
+ rpc_service->SetExecutor(&executor);
+
+ std::thread server_thread(
+ std::bind(&detail::AsyncGRPCServer::RunSyncUpdate, rpc_service.get()));
+ *initialized = true;
+ rpc_service->SetCond(0);
+ auto recv = rpc_service->Get();
+ LOG(INFO) << "got nccl id and stop server...";
+ rpc_service->ShutDown();
+ server_thread.join();
+}
+
+TEST(SendNcclId, Normal) {
+ std::atomic initialized{false};
+ std::thread server_thread(StartServer, &initialized);
+ while (!initialized) {
+ }
+ // wait server to start
+ // sleep(2);
+ rpc_service->WaitServerReady();
+
+ f::Scope scope;
+ p::CPUPlace place;
+ p::DeviceContextPool& pool = p::DeviceContextPool::Instance();
+ auto& dev_ctx = *pool.Get(p::CPUPlace());
+
+ auto var = scope.Var(NCCL_ID_VARNAME);
+ // var->SetType(f::proto::VarType_Type_RAW);
+ auto id = var->GetMutable();
+ p::dynload::ncclGetUniqueId(id);
+
+ int port = rpc_service->GetSelectedPort();
+ std::string ep = string::Sprintf("127.0.0.1:%d", port);
+ detail::RPCClient client;
+
+ client.AsyncSendVariable(ep, dev_ctx, scope, NCCL_ID_VARNAME);
+ client.Wait();
+ server_thread.join();
+ auto* ptr = rpc_service.release();
+ delete ptr;
+}
diff --git a/paddle/fluid/platform/nccl_helper.h b/paddle/fluid/platform/nccl_helper.h
index 0013597fd516d15c7d502370eec77e1a6a5dca88..e30c1a9ebf08365a9856fb32b1ce5790869e2b33 100644
--- a/paddle/fluid/platform/nccl_helper.h
+++ b/paddle/fluid/platform/nccl_helper.h
@@ -14,12 +14,15 @@
#pragma once
+#include
#include // NOLINT
#include
#include
#include "paddle/fluid/platform/dynload/nccl.h"
#include "paddle/fluid/platform/enforce.h"
+#define NCCL_ID_VARNAME "NCCLID"
+
namespace paddle {
namespace platform {
@@ -73,7 +76,9 @@ struct NCCLContextMap {
std::unordered_map contexts_;
std::vector order_;
- explicit NCCLContextMap(const std::vector &places) {
+ explicit NCCLContextMap(const std::vector &places,
+ ncclUniqueId *nccl_id = nullptr,
+ size_t num_trainers = 1, size_t trainer_id = 0) {
PADDLE_ENFORCE(!places.empty());
order_.reserve(places.size());
for (auto &p : places) {
@@ -85,18 +90,34 @@ struct NCCLContextMap {
order_.size(), contexts_.size(),
"NCCL Context Map does not support contain two or more same device");
- if (places.size() > 1) {
- std::unique_ptr comms(new ncclComm_t[order_.size()]);
+ if (places.size() <= 1) {
+ return;
+ }
+ std::unique_ptr comms(new ncclComm_t[order_.size()]);
+ // if pass nccl_id here, can assume we are doing multi node training
+ if (nccl_id == nullptr) {
+ std::lock_guard guard(NCCLGroupGuard::NCCLMutex());
+ PADDLE_ENFORCE(platform::dynload::ncclCommInitAll(
+ comms.get(), static_cast(order_.size()), order_.data()));
+ } else {
+ PADDLE_ENFORCE_GT(num_trainers, 1);
+ // TODO(wuyi): need to ensure each node have same number of GPUs
{
- std::lock_guard guard(NCCLGroupGuard::NCCLMutex());
- PADDLE_ENFORCE(platform::dynload::ncclCommInitAll(
- comms.get(), static_cast(order_.size()), order_.data()));
- }
- int i = 0;
- for (auto &dev_id : order_) {
- contexts_.at(dev_id).comm_ = comms[i++];
+ int nranks = num_trainers * order_.size();
+ NCCLGroupGuard gurad;
+ for (auto &gpu_id : order_) {
+ int rank = trainer_id * order_.size() + gpu_id;
+ VLOG(3) << "init nccl rank: " << rank << " nranks: " << nranks;
+ PADDLE_ENFORCE(cudaSetDevice(gpu_id));
+ PADDLE_ENFORCE(platform::dynload::ncclCommInitRank(
+ comms.get() + gpu_id, nranks, *nccl_id, rank));
+ }
}
}
+ int i = 0;
+ for (auto &dev_id : order_) {
+ contexts_.at(dev_id).comm_ = comms[i++];
+ }
}
NCCLContextMap(const NCCLContextMap &other) = delete;
diff --git a/paddle/fluid/pybind/pybind.cc b/paddle/fluid/pybind/pybind.cc
index 3e2eed31b446b83843fba943e4f2bc9e3787d7f6..b62291a99d34457dd17bf2bcafc1fc611419f086 100644
--- a/paddle/fluid/pybind/pybind.cc
+++ b/paddle/fluid/pybind/pybind.cc
@@ -503,12 +503,13 @@ All parameter, weight, gradient are variables in Paddle.
const ProgramDesc &main_program, const std::string &loss_var_name,
Scope *scope, std::vector &local_scopes,
bool allow_op_delay, bool use_default_grad_scale,
- bool balance_parameter_opt_between_cards) {
+ bool balance_parameter_opt_between_cards, size_t num_trainers,
+ size_t trainer_id) {
new (&self) ParallelExecutor(
num_threads, use_event, places, params, bcast_vars,
main_program, loss_var_name, scope, local_scopes,
allow_op_delay, use_default_grad_scale,
- balance_parameter_opt_between_cards);
+ balance_parameter_opt_between_cards, num_trainers, trainer_id);
})
.def("bcast_params", &ParallelExecutor::BCastParamsToGPUs)
// NOTE: even we return a vec* to Python use reference policy.
diff --git a/python/paddle/fluid/backward.py b/python/paddle/fluid/backward.py
index 7af6ed1463ab737e871da487f2a687301652ef2d..32b1b65bd97ef1e512a5880843509611b606f52d 100644
--- a/python/paddle/fluid/backward.py
+++ b/python/paddle/fluid/backward.py
@@ -480,6 +480,8 @@ def append_backward(loss, parameter_list=None, no_grad_set=None,
program.current_block_idx = current_block_idx
program.sync_with_cpp()
+ # FIXME(zcd): prevent loss.grad optimized by mem_opt.
+ loss.block.var(_append_grad_suffix_(loss.name)).persistable = True
if parameter_list is not None:
parameters = parameter_list
diff --git a/python/paddle/fluid/framework.py b/python/paddle/fluid/framework.py
index 28e54f5492e7b04a1406e319cecf977d4a55725e..38c765938fe9d7b2103bfdd926874c485d0ff4dc 100644
--- a/python/paddle/fluid/framework.py
+++ b/python/paddle/fluid/framework.py
@@ -489,7 +489,7 @@ class Operator(object):
'rnn_memory_helper_grad', 'conditional_block', 'while', 'send',
'recv', 'listen_and_serv', 'parallel_do', 'save_combine',
'load_combine', 'ncclInit', 'channel_create', 'channel_close',
- 'channel_send', 'channel_recv', 'select'
+ 'channel_send', 'channel_recv', 'select', 'gen_nccl_id'
}
if type not in no_kernel_op_set:
self.desc.infer_var_type(self.block.desc)
diff --git a/python/paddle/fluid/parallel_executor.py b/python/paddle/fluid/parallel_executor.py
index 5b43f860e7075745bbf6e76c2f9d0e9a87a86db0..7358c4b60e87893b9c04e3da2221dfb69d3ba0c7 100644
--- a/python/paddle/fluid/parallel_executor.py
+++ b/python/paddle/fluid/parallel_executor.py
@@ -31,7 +31,9 @@ class ParallelExecutor(object):
allow_op_delay=False,
share_vars_from=None,
use_default_grad_scale=True,
- balance_parameter_opt_between_cards=False):
+ balance_parameter_opt_between_cards=False,
+ num_trainers=1,
+ trainer_id=0):
"""
ParallelExecutor can run program in parallel.
@@ -55,6 +57,11 @@ class ParallelExecutor(object):
balance_parameter_opt_between_cards(bool, default True): Whether
updating different gradients on different cards. Currently, it
is not recommended.
+ num_trainers(int, default 1): If greater than 1, NCCL will be
+ initialized with multpile rank of nodes, each node should have
+ same number of GPUs. Distributed training will be enabled then.
+ trainer_id(int, default 0): Must use together with num_trainers.
+ trainer_id is the "rank" of current node starts from 0.
Returns:
A ParallelExecutor object.
@@ -134,8 +141,9 @@ class ParallelExecutor(object):
local_scopes,
allow_op_delay,
use_default_grad_scale,
- balance_parameter_opt_between_cards)
-
+ balance_parameter_opt_between_cards,
+ num_trainers,
+ trainer_id)
self.scope = scope
def run(self, fetch_list, feed=None, feed_dict=None):
diff --git a/python/paddle/fluid/tests/book/high-level-api/CMakeLists.txt b/python/paddle/fluid/tests/book/high-level-api/CMakeLists.txt
index 9ab00325a2eef3bbc79757ad1a3e6f8511c49552..c2a15bdb3b17b65fe861dd429f548074c13e2f09 100644
--- a/python/paddle/fluid/tests/book/high-level-api/CMakeLists.txt
+++ b/python/paddle/fluid/tests/book/high-level-api/CMakeLists.txt
@@ -6,4 +6,5 @@ foreach(src ${TEST_OPS})
py_test(${src} SRCS ${src}.py)
endforeach()
+add_subdirectory(fit_a_line)
add_subdirectory(recognize_digits)
diff --git a/python/paddle/fluid/tests/book/high-level-api/fit_a_line/CMakeLists.txt b/python/paddle/fluid/tests/book/high-level-api/fit_a_line/CMakeLists.txt
new file mode 100644
index 0000000000000000000000000000000000000000..673c965b662a022739f8d489c331f4de9455a926
--- /dev/null
+++ b/python/paddle/fluid/tests/book/high-level-api/fit_a_line/CMakeLists.txt
@@ -0,0 +1,7 @@
+file(GLOB TEST_OPS RELATIVE "${CMAKE_CURRENT_SOURCE_DIR}" "test_*.py")
+string(REPLACE ".py" "" TEST_OPS "${TEST_OPS}")
+
+# default test
+foreach(src ${TEST_OPS})
+ py_test(${src} SRCS ${src}.py)
+endforeach()
diff --git a/python/paddle/fluid/tests/book/high-level-api/fit_a_line/test_fit_a_line.py b/python/paddle/fluid/tests/book/high-level-api/fit_a_line/test_fit_a_line.py
new file mode 100644
index 0000000000000000000000000000000000000000..8c9bbb52d769282460c571ebc51d5eff18de3114
--- /dev/null
+++ b/python/paddle/fluid/tests/book/high-level-api/fit_a_line/test_fit_a_line.py
@@ -0,0 +1,137 @@
+# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import paddle
+import paddle.fluid as fluid
+import contextlib
+import numpy
+import unittest
+
+# train reader
+BATCH_SIZE = 20
+
+train_reader = paddle.batch(
+ paddle.reader.shuffle(
+ paddle.dataset.uci_housing.train(), buf_size=500),
+ batch_size=BATCH_SIZE)
+
+test_reader = paddle.batch(
+ paddle.reader.shuffle(
+ paddle.dataset.uci_housing.test(), buf_size=500),
+ batch_size=BATCH_SIZE)
+
+
+def inference_program():
+ x = fluid.layers.data(name='x', shape=[13], dtype='float32')
+ y_predict = fluid.layers.fc(input=x, size=1, act=None)
+ return y_predict
+
+
+def linear():
+ y = fluid.layers.data(name='y', shape=[1], dtype='float32')
+ y_predict = inference_program()
+
+ loss = fluid.layers.square_error_cost(input=y_predict, label=y)
+ avg_loss = fluid.layers.mean(loss)
+
+ return avg_loss
+
+
+def train(use_cuda, save_dirname):
+ place = fluid.CUDAPlace(0) if use_cuda else fluid.CPUPlace()
+
+ trainer = fluid.Trainer(
+ train_func=linear,
+ infer_func=inference_program,
+ place=place,
+ optimizer=fluid.optimizer.SGD(learning_rate=0.001))
+
+ def event_handler(event):
+ if isinstance(event, fluid.EndEpochEvent):
+ test_metrics = trainer.test(
+ reader=test_reader, feed_order=['x', 'y'])
+ print test_metrics
+ '''
+
+ ...
+ ['25.768919467926025']
+ ['15.343549569447836']
+ ...
+
+ '''
+ if float(test_metrics[0]) < 20.0:
+ if save_dirname is not None:
+ # NOT clear yet
+ # fluid.io.save_inference_model(save_dirname, ['x'], [y_predict])
+ # trainer.save_params(save_dirname)
+ # https://github.com/PaddlePaddle/Paddle/pull/10445
+ trainer.save_inference_model(save_dirname)
+ return
+
+ trainer.train(
+ reader=train_reader,
+ num_epochs=100,
+ event_handler=event_handler,
+ feed_order=['x', 'y'])
+
+
+# infer
+def infer(use_cuda, save_dirname=None):
+ if save_dirname is None:
+ return
+
+ place = fluid.CUDAPlace(0) if use_cuda else fluid.CPUPlace()
+ inferencer = fluid.Inferencer(param_path=save_dirname, place=place)
+
+ batch_size = 10
+ tensor_x = numpy.random.uniform(0, 10, [batch_size, 13]).astype("float32")
+
+ results = inferencer.infer({'x': tensor_x})
+ print("infer results: ", results[0])
+
+
+def main(use_cuda):
+ if use_cuda and not fluid.core.is_compiled_with_cuda():
+ return
+
+ # Directory for saving the trained model
+ save_dirname = "fit_a_line.inference.model"
+
+ train(use_cuda, save_dirname)
+ infer(use_cuda, save_dirname)
+
+
+class TestFitALine(unittest.TestCase):
+ def test_cpu(self):
+ with self.program_scope_guard():
+ with fluid.unique_name.guard():
+ main(use_cuda=False)
+
+ def test_cuda(self):
+ with self.program_scope_guard():
+ with fluid.unique_name.guard():
+ main(use_cuda=True)
+
+ @contextlib.contextmanager
+ def program_scope_guard(self):
+ prog = fluid.Program()
+ startup_prog = fluid.Program()
+ scope = fluid.core.Scope()
+ with fluid.scope_guard(scope):
+ with fluid.program_guard(prog, startup_prog):
+ yield
+
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/python/paddle/fluid/tests/book/notest_understand_sentiment.py b/python/paddle/fluid/tests/book/notest_understand_sentiment.py
index 241778e303036d068dc0a40e4574a02eb97ad134..792ed7368d646cd9dff9255eb402b6a9b84f69a6 100644
--- a/python/paddle/fluid/tests/book/notest_understand_sentiment.py
+++ b/python/paddle/fluid/tests/book/notest_understand_sentiment.py
@@ -170,7 +170,7 @@ def train(word_dict,
assert save_dirname is None
adagrad = fluid.optimizer.Adagrad(learning_rate=0.002)
- optimize_ops, params_grads = adagrad.minimize(cost)
+ adagrad.minimize(cost)
train_data = paddle.batch(
paddle.reader.shuffle(
diff --git a/python/paddle/fluid/tests/book/test_fit_a_line.py b/python/paddle/fluid/tests/book/test_fit_a_line.py
index ecb34699af0dc14782601702ab8afedbca7e1bfd..b1a6b524d33cae97c8982ffb8f780b1b07761a09 100644
--- a/python/paddle/fluid/tests/book/test_fit_a_line.py
+++ b/python/paddle/fluid/tests/book/test_fit_a_line.py
@@ -33,7 +33,7 @@ def train(use_cuda, save_dirname, is_local):
avg_cost = fluid.layers.mean(cost)
sgd_optimizer = fluid.optimizer.SGD(learning_rate=0.001)
- optimize_ops, params_grads = sgd_optimizer.minimize(avg_cost)
+ sgd_optimizer.minimize(avg_cost)
BATCH_SIZE = 20
diff --git a/python/paddle/fluid/tests/book/test_image_classification.py b/python/paddle/fluid/tests/book/test_image_classification.py
index dbcdb5766e7d20efdb12da0ea4c6f005d903849b..0f3a4c9242a81a3c1fb90268245715a8e59a207a 100644
--- a/python/paddle/fluid/tests/book/test_image_classification.py
+++ b/python/paddle/fluid/tests/book/test_image_classification.py
@@ -125,7 +125,7 @@ def train(net_type, use_cuda, save_dirname, is_local):
test_program = fluid.default_main_program().clone(for_test=True)
optimizer = fluid.optimizer.Adam(learning_rate=0.001)
- optimize_ops, params_grads = optimizer.minimize(avg_cost)
+ optimizer.minimize(avg_cost)
BATCH_SIZE = 128
PASS_NUM = 1
diff --git a/python/paddle/fluid/tests/book/test_label_semantic_roles.py b/python/paddle/fluid/tests/book/test_label_semantic_roles.py
index 0faba33032d5dfc0b751a5191e7b2ae0c1f172bf..09793760e5504c04ad4b0bfac5c5d7b7047cf85d 100644
--- a/python/paddle/fluid/tests/book/test_label_semantic_roles.py
+++ b/python/paddle/fluid/tests/book/test_label_semantic_roles.py
@@ -175,7 +175,7 @@ def train(use_cuda, save_dirname=None, is_local=True):
decay_steps=100000,
decay_rate=0.5,
staircase=True))
- optimize_ops, params_grads = sgd_optimizer.minimize(avg_cost)
+ sgd_optimizer.minimize(avg_cost)
# TODO(qiao)
# add dependency track and move this config before optimizer
diff --git a/python/paddle/fluid/tests/book/test_machine_translation.py b/python/paddle/fluid/tests/book/test_machine_translation.py
index 46c6b9c29a265741a99655d5ac29244798f6fec2..e8a75f473f62df528b7f39bf5f9085076e005c25 100644
--- a/python/paddle/fluid/tests/book/test_machine_translation.py
+++ b/python/paddle/fluid/tests/book/test_machine_translation.py
@@ -185,7 +185,7 @@ def train_main(use_cuda, is_sparse, is_local=True):
learning_rate=1e-4,
regularization=fluid.regularizer.L2DecayRegularizer(
regularization_coeff=0.1))
- optimize_ops, params_grads = optimizer.minimize(avg_cost)
+ optimizer.minimize(avg_cost)
train_data = paddle.batch(
paddle.reader.shuffle(
diff --git a/python/paddle/fluid/tests/book/test_recognize_digits.py b/python/paddle/fluid/tests/book/test_recognize_digits.py
index c115aa4d7d6b514f9207543730e5e76cb0d2040c..578b1162fbd7e3a1b1c0cc934406818f2e07e019 100644
--- a/python/paddle/fluid/tests/book/test_recognize_digits.py
+++ b/python/paddle/fluid/tests/book/test_recognize_digits.py
@@ -95,7 +95,7 @@ def train(nn_type,
test_program = fluid.default_main_program().clone(for_test=True)
optimizer = fluid.optimizer.Adam(learning_rate=0.001)
- optimize_ops, params_grads = optimizer.minimize(avg_loss)
+ optimizer.minimize(avg_loss)
place = fluid.CUDAPlace(0) if use_cuda else fluid.CPUPlace()
diff --git a/python/paddle/fluid/tests/book/test_recommender_system.py b/python/paddle/fluid/tests/book/test_recommender_system.py
index d022dedbff805d597b68b5a47f7931f2dd946615..7be924f762ddeb045dda890dbfdcd96a65449553 100644
--- a/python/paddle/fluid/tests/book/test_recommender_system.py
+++ b/python/paddle/fluid/tests/book/test_recommender_system.py
@@ -160,7 +160,7 @@ def train(use_cuda, save_dirname, is_local=True):
test_program = fluid.default_main_program().clone(for_test=True)
sgd_optimizer = SGDOptimizer(learning_rate=0.2)
- optimize_ops, params_grads = sgd_optimizer.minimize(avg_cost)
+ sgd_optimizer.minimize(avg_cost)
place = fluid.CUDAPlace(0) if use_cuda else fluid.CPUPlace()
diff --git a/python/paddle/fluid/tests/book/test_word2vec.py b/python/paddle/fluid/tests/book/test_word2vec.py
index 6dec0f6857e86b4b9c1c67af934aa9bfdb1c3df7..30e1a5040cc92b02bbbf90dac97001812ec90134 100644
--- a/python/paddle/fluid/tests/book/test_word2vec.py
+++ b/python/paddle/fluid/tests/book/test_word2vec.py
@@ -101,7 +101,7 @@ def train(use_cuda, is_sparse, is_parallel, save_dirname, is_local=True):
avg_cost = fluid.layers.mean(pd())
sgd_optimizer = fluid.optimizer.SGD(learning_rate=0.001)
- optimize_ops, params_grads = sgd_optimizer.minimize(avg_cost)
+ sgd_optimizer.minimize(avg_cost)
train_reader = paddle.batch(
paddle.dataset.imikolov.train(word_dict, N), BATCH_SIZE)
diff --git a/python/paddle/fluid/tests/unittests/test_dist_transpiler.py b/python/paddle/fluid/tests/unittests/test_dist_transpiler.py
new file mode 100644
index 0000000000000000000000000000000000000000..10f8c4f3f0167632bb4a3d454ab026ba73a8f305
--- /dev/null
+++ b/python/paddle/fluid/tests/unittests/test_dist_transpiler.py
@@ -0,0 +1,113 @@
+# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import unittest
+
+import paddle.fluid as fluid
+import paddle.fluid.core as core
+import paddle.fluid.layers as layers
+from paddle.fluid.transpiler.distribute_transpiler import delete_ops
+import numpy
+
+
+class TestDistTranspiler(unittest.TestCase):
+ def setUp(self):
+ self.trainer_id = 0
+ self.trainers = 2
+ self.pservers = 2
+ self.pserver_eps = "127.0.0.1:6174,127.0.0.1:6175"
+ self.current_pserver_ep = "127.0.0.1:6174"
+
+ def net_conf(self):
+ x = fluid.layers.data(name='x', shape=[1000], dtype='float32')
+
+ y_predict = fluid.layers.fc(input=x,
+ size=1000,
+ act=None,
+ param_attr=fluid.ParamAttr(name='fc_w'))
+
+ y = fluid.layers.data(name='y', shape=[1], dtype='float32')
+
+ cost = fluid.layers.square_error_cost(input=y_predict, label=y)
+ avg_cost = fluid.layers.mean(cost)
+ sgd_optimizer = fluid.optimizer.SGD(learning_rate=0.1)
+
+ optimize_ops, params_grads = sgd_optimizer.minimize(avg_cost)
+ return optimize_ops, params_grads
+
+ def test_transpiler(self):
+ trainer = self.get_trainer()
+ pserver, startup = self.get_pserver(self.current_pserver_ep)
+
+ self.assertEqual([op.type for op in trainer.global_block().ops],
+ self.get_expect_trainer_ops())
+
+ self.assertEqual(len(pserver.blocks), 3)
+ # block0: listen_and_serv
+ self.assertEqual([op.type for op in pserver.blocks[0].ops],
+ ["listen_and_serv"])
+ # block2: optimize pass
+ self.assertEqual([op.type for op in pserver.blocks[1].ops],
+ ["sum", "scale", "sgd"])
+
+ # confirm startup program
+
+ self.assertEqual([op.type for op in startup.global_block().ops], [
+ "fill_constant", "fill_constant", "uniform_random", "uniform_random"
+ ])
+
+ # the variable #fc_w will be split into two blocks
+ fc_w_var = startup.global_block().var("fc_w.block1")
+ self.assertEqual(fc_w_var.shape, (500, 1000))
+
+ def get_main_program(self):
+ main = fluid.Program()
+
+ with fluid.program_guard(main):
+ self.net_conf()
+
+ return main
+
+ def get_expect_trainer_ops(self):
+ trainer = fluid.Program()
+
+ with fluid.program_guard(trainer):
+ optimize_ops, params_grads = self.net_conf()
+
+ delete_ops(trainer.global_block(), optimize_ops)
+ return [op.type for op in trainer.global_block().ops
+ ] + ["split_byref", "send", "concat"]
+
+ def get_trainer(self):
+ return self._transpiler_instance().get_trainer_program()
+
+ def get_pserver(self, ep):
+ t = self._transpiler_instance()
+ pserver = t.get_pserver_program(ep)
+ startup = t.get_startup_program(ep, pserver)
+ return pserver, startup
+
+ def _transpiler_instance(self):
+ main = self.get_main_program()
+ t = fluid.DistributeTranspiler()
+ t.transpile(
+ self.trainer_id,
+ program=main,
+ pservers=self.pserver_eps,
+ trainers=self.trainers)
+ return t
+
+
+if __name__ == "__main__":
+ unittest.main()
diff --git a/python/paddle/fluid/tests/unittests/test_split_var.py b/python/paddle/fluid/tests/unittests/test_split_var.py
index 79d387f0066672058d1640f4e5fd28ed8913fe4c..0c5e8901b903375c7d4de32943e657b205d8fae9 100644
--- a/python/paddle/fluid/tests/unittests/test_split_var.py
+++ b/python/paddle/fluid/tests/unittests/test_split_var.py
@@ -21,15 +21,7 @@ import random
class TestSplitVar(unittest.TestCase):
- def test_check_output(self):
- # split below shapes to 10 servers
- shapes = [[3, 5], [1024], [28, 784], [8, 1020], [800, 10]]
- expected_sizes = [
- [15], [1024],
- [2352, 2352, 2352, 2352, 2352, 2352, 2352, 2352, 2352, 784],
- [2040, 2040, 2040, 2040],
- [1150, 1150, 1150, 1150, 1150, 1150, 1100]
- ]
+ def check_split_output(self, shapes, expected_sizes, min_size):
var_list = []
program = fluid.Program()
for shape in shapes:
@@ -39,7 +31,7 @@ class TestSplitVar(unittest.TestCase):
# dtype=core.VarDesc.VarType.LOD_TENSOR,
shape=shape)
var_list.append(var)
- blocks = split_dense_variable(var_list, 10)
+ blocks = split_dense_variable(var_list, 10, min_size)
all_sizes = []
for s in expected_sizes:
for s2 in s:
@@ -48,6 +40,25 @@ class TestSplitVar(unittest.TestCase):
varname, block_id, size = block_str.split(":")
self.assertEqual(int(size), all_sizes[i])
+ def test_1k(self):
+ shapes = [[3, 5], [1024], [28, 784], [8, 1020], [800, 10]]
+ expected_sizes = [
+ [15], [1024],
+ [2352, 2352, 2352, 2352, 2352, 2352, 2352, 2352, 2352, 784],
+ [2040, 2040, 2040, 2040],
+ [1150, 1150, 1150, 1150, 1150, 1150, 1100]
+ ]
+
+ self.check_split_output(shapes, expected_sizes, 1024)
+
+ def test_check_output_8k(self):
+ shapes = [[3, 5], [1024], [28, 784], [8, 1020], [800, 10],
+ [6, 33, 33, 33]]
+ expected_sizes = [[15], [1024], [10976, 10976], [8160], [8000],
+ [35937, 35937, 35937, 35937, 35937, 35937]]
+
+ self.check_split_output(shapes, expected_sizes, 8192)
+
if __name__ == '__main__':
unittest.main()
diff --git a/python/paddle/fluid/transpiler/__init__.py b/python/paddle/fluid/transpiler/__init__.py
index 6d3c1b947f4acb1335b25e6eb0099d5d532c895a..413c36c5c41bbe0169f1c050ccdac040202d66df 100644
--- a/python/paddle/fluid/transpiler/__init__.py
+++ b/python/paddle/fluid/transpiler/__init__.py
@@ -11,6 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+
from distribute_transpiler import DistributeTranspiler
from inference_transpiler import InferenceTranspiler
from memory_optimization_transpiler import memory_optimize, release_memory
diff --git a/python/paddle/fluid/transpiler/distribute_transpiler.py b/python/paddle/fluid/transpiler/distribute_transpiler.py
index b45cb987d896bd189531e97eb62bddbbee16069d..42ff0a9eb1112ed5709749e3867794c80be8f1d1 100644
--- a/python/paddle/fluid/transpiler/distribute_transpiler.py
+++ b/python/paddle/fluid/transpiler/distribute_transpiler.py
@@ -17,7 +17,7 @@ from __future__ import print_function
import math
import distributed_splitter as splitter
-from .. import core
+from .. import core, framework
from ..framework import Program, default_main_program, \
default_startup_program, \
Variable, Parameter, grad_var_name
@@ -93,30 +93,33 @@ def same_or_split_var(p_name, var_name):
return p_name == var_name or p_name.startswith(var_name + ".block")
-def split_dense_variable(var_list,
- pserver_count,
- min_block_size=1024,
- max_block_size=1048576):
+def split_dense_variable(var_list, service_count, min_block_size=8192):
"""
- We may need to split dense tensor to one or more blocks and put
- them equally onto parameter server. One block is a sub-tensor
- aligned by dim[0] of the tensor.
-
- We need to have a minimal block size so that the calculations in
- the parameter server side can gain better performance. By default
- minimum block size is 1024. The max block size is used to prevent
- very large blocks that may cause send error.
- :return: A list of VarBlocks. Each VarBlock specifies a shard of
- the var.
+ We may need to split dense tensor to one or more blocks and put
+ them equally onto parameter server. One block is a sub-tensor
+ aligned by dim[0] of the tensor.
+
+ We need to have a minimal block size so that the calculations in
+ the parameter server side can gain better performance. By default
+ minimum block size 8K elements (maybe 16bit or 32bit or 64bit).
+
+ Args:
+ var_list (list): List of variables.
+ service_count (int): Numel of pserver services. A pserver may have two
+ or more listening ports.
+ min_block_size (int): Minimum splitted block size.
+ Returns:
+ blocks (list[(varname, block_id, current_block_size)]): A list
+ of VarBlocks. Each VarBlock specifies a shard of the var.
"""
blocks = []
for var in var_list:
- split_count = pserver_count
+ split_count = service_count
var_numel = reduce(lambda x, y: x * y, var.shape)
max_pserver_count = int(math.floor(var_numel / float(min_block_size)))
if max_pserver_count == 0:
max_pserver_count = 1
- if max_pserver_count < pserver_count:
+ if max_pserver_count < service_count:
split_count = max_pserver_count
block_size = int(math.ceil(var_numel / float(split_count)))
@@ -270,6 +273,7 @@ class DistributeTranspiler:
grad_var_mapping = self._append_split_op(program, grad_blocks)
param_var_mapping = self._create_vars_from_blocklist(program,
param_blocks)
+
# step3: Add gradients as send op inputs and parameters as send
# op outputs.
send_inputs = []
@@ -277,9 +281,11 @@ class DistributeTranspiler:
for b in grad_blocks: # append by order
varname, block_id, _ = b.split(":")
send_inputs.append(grad_var_mapping[varname][int(block_id)])
+
for b in param_blocks:
varname, block_id, _ = b.split(":")
send_outputs.append(param_var_mapping[varname][int(block_id)])
+
# let send_op know which endpoint to send which var to, eplist has the same
# order as send_inputs.
eplist = split_method(send_inputs, pserver_endpoints)
@@ -417,7 +423,7 @@ class DistributeTranspiler:
def __append_optimize_op__(op, block, grad_to_block_id):
if self._is_opt_op(op):
self._append_pserver_ops(block, op, endpoint, grad_to_block_id,
- default_main_program())
+ self.origin_program)
else:
self._append_pserver_non_opt_ops(block, op)
@@ -751,9 +757,18 @@ class DistributeTranspiler:
Create vars for each split.
NOTE: only grads need to be named for different trainers, use
add_trainer_suffix to rename the grad vars.
- :return: A dict mapping from original var name to each var split.
+ Args:
+ program (ProgramDesc): ProgramDesc which gradients blong.
+ block_list (list[(varname, block_id, block_size)]): List of gradient blocks.
+ add_trainer_suffix (Bool): Add trainer suffix to new variable's name if set True.
+ Returns:
+ var_mapping (dict(varname->[new_varname_variable])):A dict mapping
+ from original var name to each var split.
"""
+
+ # varname->[(block_id, current_block_size)]
block_map = dict()
+
var_mapping = dict()
for block_str in block_list:
varname, offset, size = block_str.split(":")
@@ -824,7 +839,16 @@ class DistributeTranspiler:
persistable=persistable)
def _append_split_op(self, program, gradblocks):
- # Split variables that need to be split and append respective ops
+ """
+ Split variables that need to be split and append respective ops
+ Args:
+ program (ProgramDesc): ProgramDesc that gradients blong.
+ gradblocks (list[(varname, block_id, block_size)]): List of gradient blocks.
+ Returns:
+ var_mapping (dict(varname->[new_splitted_variable])):A dict mapping
+ from original var name to each var split.
+ """
+
add_suffix = False
if self.trainer_num > 1:
add_suffix = True
@@ -1148,6 +1172,12 @@ class DistributeTranspiler:
return lr_ops
def _get_optimize_pass(self):
+ """
+ Get optimizer operators, paramters and gradients from origin_program
+ Returns:
+ opt_ops (list): optimize operators.
+ params_grads (dict): paramter->gradient.
+ """
block = self.origin_program.global_block()
opt_ops = []
params_grads = []
diff --git a/tools/manylinux1/README.md b/tools/manylinux1/README.md
index 898e00bd37c7b7bcbcb4a56476ff10c87381e47a..0e5905040175047f5b79939d97a3efcf38992944 100644
--- a/tools/manylinux1/README.md
+++ b/tools/manylinux1/README.md
@@ -28,3 +28,38 @@ git clone https://github.com/paddlepaddle/paddle
cd paddle/tools/manylinux1
REPO=[yourrepo] ./build_all.sh
```
+
+## Build PaddlePaddle for the different Python ABIs
+
+Choose one of the following Python ABI and set the correct environment variables.
+
+- cp27-cp27m
+
+ ```bash
+ export LD_LIBRARY_PATH=/opt/_internal/cpython-2.7.11-ucs2/lib:${LD_LIBRARY_PATH#/opt/_internal/cpython-2.7.11-ucs4/lib:}
+ export PATH=/opt/python/cp27-cp27m/bin/:${PATH}
+ export PYTHON_FLAGS="-DPYTHON_EXECUTABLE:FILEPATH=/opt/python/cp27-cp27m/bin/python
+ -DPYTHON_INCLUDE_DIR:PATH=/opt/python/cp27-cp27m/include/python2.7
+ -DPYTHON_LIBRARIES:FILEPATH=/opt/_internal/cpython-2.7.11-ucs2/lib/libpython2.7.so"
+ ```
+
+- cp27-cp27mu
+
+ ```bash
+ export LD_LIBRARY_PATH=/opt/_internal/cpython-2.7.11-ucs4/lib:${LD_LIBRARY_PATH#/opt/_internal/cpython-2.7.11-ucs2/lib:}
+ export PATH=/opt/python/cp27-cp27mu/bin/:${PATH}
+ export PYTHON_FLAGS="-DPYTHON_EXECUTABLE:FILEPATH=/opt/python/cp27-cp27mu/bin/python
+ -DPYTHON_INCLUDE_DIR:PATH=/opt/python/cp27-cp27mu/include/python2.7
+ -DPYTHON_LIBRARIES:FILEPATH=/opt/_internal/cpython-2.7.11-ucs4/lib/libpython2.7.so"
+ ```
+
+And then add the `PYTHON_FLAGS` as your cmake flags:
+
+```bash
+cmake ..
+ ${PYTHON_FLAGS} \
+ -DWITH_GPU=OFF \
+ ...
+```
+
+You can find more details about cmake flags at [here](http://www.paddlepaddle.org/docs/develop/documentation/fluid/en/build_and_install/build_from_source_en.html#appendix-build-options)