diff --git a/CMakeLists.txt b/CMakeLists.txt
index 23bbe829ac16180088bfa37df66e23f19b021ea3..030bd19b3fd2f561a847bbc4613e5d2030812a92 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -25,7 +25,6 @@ message(STATUS "CXX compiler: ${CMAKE_CXX_COMPILER}, version: "
message(STATUS "C compiler: ${CMAKE_C_COMPILER}, version: "
"${CMAKE_C_COMPILER_ID} ${CMAKE_C_COMPILER_VERSION}")
-find_package(Sphinx)
if(NOT CMAKE_CROSSCOMPILING)
find_package(CUDA QUIET)
endif(NOT CMAKE_CROSSCOMPILING)
@@ -226,5 +225,7 @@ if(WITH_PYTHON)
endif()
if(WITH_DOC)
+ find_package(Sphinx REQUIRED)
+ find_python_module(recommonmark REQUIRED)
add_subdirectory(doc)
endif()
diff --git a/doc/v2/build_and_install/build_from_source_cn.rst b/doc/v2/build_and_install/build_from_source_cn.rst
index 115b92a33888abf1e1be400e1abbb58b632a2976..f846928954dd3a05e11054ce2ff2ff839fbefd4b 100644
--- a/doc/v2/build_and_install/build_from_source_cn.rst
+++ b/doc/v2/build_and_install/build_from_source_cn.rst
@@ -19,8 +19,9 @@
----------------
PaddlePaddle需要使用Docker环境完成编译,这样可以免去单独安装编译依赖的步骤,可选的不同编译环境Docker镜像
-可以在 `这里 `_ 找到。或者
-参考下述可选步骤,从源码中构建用于编译PaddlePaddle的Docker镜像。
+可以在 `这里 `_ 找到,您也可以
+在 `这里 `_ 找到 paddle_manylinux_devel
+镜像的编译以及使用方法。或者参考下述可选步骤,从源码中构建用于编译PaddlePaddle的Docker镜像。
如果您选择不使用Docker镜像,则需要在本机安装下面章节列出的 `编译依赖`_ 之后才能开始编译的步骤。
diff --git a/doc/v2/build_and_install/build_from_source_en.rst b/doc/v2/build_and_install/build_from_source_en.rst
index 8fef9e7347e8d924026999bfda985381750c6b51..d1b5b88dff81d4c5cee3dd13a7dccbc333ab6a17 100644
--- a/doc/v2/build_and_install/build_from_source_en.rst
+++ b/doc/v2/build_and_install/build_from_source_en.rst
@@ -22,6 +22,8 @@ How To Build
You need to use Docker to build PaddlePaddle
to avoid installing dependencies by yourself. We have several pre-built
Docker images `here `_ ,
+and you can also find out how to build and use the paddle_manylinux_devel Docker
+image `here `_ .
Or you can build your own image from source as the optional step below:
.. code-block:: bash
diff --git a/paddle/fluid/framework/operator.h b/paddle/fluid/framework/operator.h
index d373c48b1a75c5f75c7520b56f230bc2c146b174..a4eb6f706edab9479cbce436311eb96da8845646 100644
--- a/paddle/fluid/framework/operator.h
+++ b/paddle/fluid/framework/operator.h
@@ -192,6 +192,10 @@ class ExecutionContext {
return op_.Attr(name);
}
+ bool HasInput(const std::string& name) const { return op_.HasInputs(name); }
+
+ bool HasOutput(const std::string& name) const { return op_.HasOutputs(name); }
+
size_t InputSize(const std::string& name) const {
return op_.Inputs(name).size();
}
diff --git a/paddle/fluid/framework/parallel_executor.cc b/paddle/fluid/framework/parallel_executor.cc
index 392b13d3dc75bf8cf3e593f7bcd36ee006b5cbb9..50c3468d556bfe05d6c41906cf35cb671f711b1e 100644
--- a/paddle/fluid/framework/parallel_executor.cc
+++ b/paddle/fluid/framework/parallel_executor.cc
@@ -57,7 +57,8 @@ ParallelExecutor::ParallelExecutor(
const std::unordered_set &bcast_vars,
const ProgramDesc &main_program, const std::string &loss_var_name,
Scope *scope, const std::vector &local_scopes,
- const ExecutionStrategy &exec_strategy, const BuildStrategy &build_strategy)
+ const ExecutionStrategy &exec_strategy, const BuildStrategy &build_strategy,
+ size_t num_trainers, size_t trainer_id)
: member_(new ParallelExecutorPrivate(places)) {
member_->global_scope_ = scope;
@@ -79,7 +80,13 @@ ParallelExecutor::ParallelExecutor(
// Bcast Parameters to all GPUs
#ifdef PADDLE_WITH_CUDA
- member_->nccl_ctxs_.reset(new platform::NCCLContextMap(member_->places_));
+ auto *nccl_id_var = scope->FindVar(NCCL_ID_VARNAME);
+ ncclUniqueId *nccl_id = nullptr;
+ if (nccl_id_var != nullptr) {
+ nccl_id = nccl_id_var->GetMutable();
+ }
+ member_->nccl_ctxs_.reset(new platform::NCCLContextMap(
+ member_->places_, nccl_id, num_trainers, trainer_id));
#endif
if (platform::is_gpu_place(places[0]) && member_->local_scopes_.size() != 1 &&
local_scopes.empty()) { // Is CUDA
diff --git a/paddle/fluid/framework/parallel_executor.h b/paddle/fluid/framework/parallel_executor.h
index 121e74293c5874402a4184d758f408c747ee0fc5..5247e790649e76567f4527d54499d6e95dac5c27 100644
--- a/paddle/fluid/framework/parallel_executor.h
+++ b/paddle/fluid/framework/parallel_executor.h
@@ -44,7 +44,8 @@ class ParallelExecutor {
const std::string &loss_var_name, Scope *scope,
const std::vector &local_scopes,
const ExecutionStrategy &exec_strategy,
- const BuildStrategy &build_strategy);
+ const BuildStrategy &build_strategy,
+ size_t num_trainers = 1, size_t trainer_id = 0);
~ParallelExecutor();
diff --git a/paddle/fluid/inference/analysis/CMakeLists.txt b/paddle/fluid/inference/analysis/CMakeLists.txt
index de7becae4d25d48111fea8d2123bc85aef70230a..47929ef7490e5edb246625cb0b3ba507039df27a 100644
--- a/paddle/fluid/inference/analysis/CMakeLists.txt
+++ b/paddle/fluid/inference/analysis/CMakeLists.txt
@@ -1 +1,2 @@
-cc_library(dot SRCS dot.cc)
+cc_library(analysis SRCS dot.cc node.cc node.h)
+cc_test(test_node SRCS node_tester.cc DEPS analysis)
diff --git a/paddle/fluid/inference/analysis/device.h b/paddle/fluid/inference/analysis/device.h
new file mode 100644
index 0000000000000000000000000000000000000000..585c9923291e5f9cb6e50dbc4bcd28c374191048
--- /dev/null
+++ b/paddle/fluid/inference/analysis/device.h
@@ -0,0 +1,24 @@
+/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+#pragma once
+
+namespace paddle {
+namespace inference {
+namespace analysis {
+
+enum class Device { CPU, GPU };
+
+} // namespace analysis
+} // namespace inference
+} // namespace paddle
diff --git a/paddle/fluid/inference/analysis/dot.h b/paddle/fluid/inference/analysis/dot.h
index 3359987874f2d74d7e4646baa38790431c4b28fd..4bf1840fdda8508b52d7274a338c5b1c95baf354 100644
--- a/paddle/fluid/inference/analysis/dot.h
+++ b/paddle/fluid/inference/analysis/dot.h
@@ -21,6 +21,7 @@
#include
#include
+#include
#include
#include
diff --git a/paddle/fluid/inference/analysis/dot_tester.cc b/paddle/fluid/inference/analysis/dot_tester.cc
new file mode 100644
index 0000000000000000000000000000000000000000..56ceb9bd5d6f41a601d66f6124fb7b4099c9337e
--- /dev/null
+++ b/paddle/fluid/inference/analysis/dot_tester.cc
@@ -0,0 +1,62 @@
+// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "paddle/fluid/inference/analysis/dot.h"
+
+#include
+#include
+#include "paddle/fluid/inference/analysis/data_flow_graph.h"
+
+namespace paddle {
+namespace inference {
+namespace analysis {
+
+class DotTester : public ::testing::Test {  // Fixture holding a small Dot graph.
+ protected:
+ void SetUp() override {
+ std::vector attrs({{"title", "hello"}});
+ dot.reset(new Dot(attrs));
+ dot->AddNode("a", {Dot::Attr{"shape", "box"}, Dot::Attr("color", "blue")});
+ dot->AddNode("b", {});
+ dot->AddNode("c", {});
+ dot->AddEdge("a", "b", {});
+ dot->AddEdge("b", "c", {});
+ dot->AddEdge("a", "c", {});
+ }
+ // Graph built by SetUp(): nodes a, b, c and edges a->b, b->c, a->c.
+ std::unique_ptr dot;
+};
+
+TEST_F(DotTester, Build) {
+ auto codes = dot->Build();
+ // Only log the generated DOT code; it is too long to compare as a string,
+ // so this test asserts nothing about the output.
+ //
+ // The output is
+ //
+ // digraph G {
+ // title="hello"
+ // node_1
+ // node_2
+ // node_0[label="a" shape="box" color="blue"]
+ // node_0->node_1
+ // node_1->node_2
+ // node_0->node_2
+ // } // end G
+ LOG(INFO) << '\n' << codes;
+}
+
+} // namespace analysis
+} // namespace inference
+} // namespace paddle
diff --git a/paddle/fluid/inference/analysis/helper.h b/paddle/fluid/inference/analysis/helper.h
new file mode 100644
index 0000000000000000000000000000000000000000..b2d06c5d63ff139186710cd963e07b4ba245f9f3
--- /dev/null
+++ b/paddle/fluid/inference/analysis/helper.h
@@ -0,0 +1,74 @@
+/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License. */
+
+#pragma once
+
+#include
+#include
+#include
+
+#include "paddle/fluid/platform/enforce.h"
+
+namespace paddle {
+namespace inference {
+namespace analysis {
+
+template
+class iterator_range {  // Holds a begin/end iterator pair.
+ IteratorT begin_, end_;
+
+ public:
+ template
+ explicit iterator_range(Container &&c) : begin_(c.begin()), end_(c.end()) {}
+ // Build the range from an explicit pair of iterators.
+ iterator_range(const IteratorT &begin, const IteratorT &end)
+ : begin_(begin), end_(end) {}
+ // Accessors for the two stored iterators.
+ const IteratorT &begin() const { return begin_; }
+ const IteratorT &end() const { return end_; }
+};
+
+/*
+ * A registry helper class whose records keep the order they were registered in.
+ */
+template
+class OrderedRegistry {
+ public:
+ T *Register(const std::string &name, T *x) {
+ PADDLE_ENFORCE(!dic_.count(name));  // name must not be registered yet
+ dic_[name] = data_.size();  // position x is about to take in data_
+ data_.emplace_back(std::unique_ptr(x));  // x is now held by data_
+ return data_.back().get();
+ }
+ // Returns the record registered under name, or nullptr if there is none.
+ T *Lookup(const std::string &name) {
+ auto it = dic_.find(name);
+ if (it == dic_.end()) return nullptr;
+ return data_[it->second].get();
+ }
+
+ protected:
+ std::unordered_map dic_;  // name -> position in data_
+ std::vector> data_;  // records, in registration order
+};
+
+} // namespace analysis
+} // namespace inference
+} // namespace paddle
+
+#define PADDLE_DISALLOW_COPY_AND_ASSIGN(type__) \
+ /* Deletes the copy constructor and copy assignment of type__. */ \
+ type__(const type__ &) = delete; \
+ \
+ void operator=(const type__ &) = delete;
diff --git a/paddle/fluid/inference/analysis/node.cc b/paddle/fluid/inference/analysis/node.cc
new file mode 100644
index 0000000000000000000000000000000000000000..fe060526080b1ee01aa98f2ff06fb2191eddf9da
--- /dev/null
+++ b/paddle/fluid/inference/analysis/node.cc
@@ -0,0 +1,67 @@
+/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#include "paddle/fluid/inference/analysis/node.h"
+#include "glog/logging.h"
+#include "paddle/fluid/platform/enforce.h"
+
+namespace paddle {
+namespace inference {
+namespace analysis {
+
+std::vector Value::dot_attrs() const {  // DOT style: filled rounded red box
+ return std::vector({Dot::Attr("style", "filled,rounded"),
+ Dot::Attr("shape", "box"),
+ Dot::Attr("fillcolor", "red")});
+}
+
+std::vector Function::dot_attrs() const {  // DOT style: filled yellow diamond
+ return std::vector({Dot::Attr("style", "filled,rounded"),
+ Dot::Attr("shape", "diamond"),
+ Dot::Attr("fillcolor", "yellow")});
+}
+
+Node *NodeMap::Create(Node::Type type) {
+  // Allocate a node of the requested kind and append it to nodes_.
+  if (type == Node::Type::kFunction) {
+    nodes_.emplace_back(new Function);
+  } else if (type == Node::Type::kValue) {
+    nodes_.emplace_back(new Value);
+  } else {
+    PADDLE_THROW("Not supported node type.");
+  }
+  // A node's id is its index in nodes_.
+  auto &created = nodes_.back();
+  created->id_ = size() - 1;
+  return created.get();
+}
+
+Node *NodeMap::GetMutable(size_t id) {
+ PADDLE_ENFORCE_GT(size(), id);  // id must index an existing node
+ return nodes_[id].get();
+}
+
+const Node &NodeMap::Get(size_t id) const {
+  PADDLE_ENFORCE_GT(size(), id);  // id must index an existing node
+  return *nodes_[id];
+}
+
+void NodeMap::Delete(size_t id) {
+ PADDLE_ENFORCE_LT(id, size());
+ nodes_[id]->SetDeleted();  // only flags the node; it stays in nodes_
+}
+
+} // namespace analysis
+} // namespace inference
+} // namespace paddle
diff --git a/paddle/fluid/inference/analysis/node.h b/paddle/fluid/inference/analysis/node.h
new file mode 100644
index 0000000000000000000000000000000000000000..07cb7669f98237399c4165947a03c67ce2a86aa8
--- /dev/null
+++ b/paddle/fluid/inference/analysis/node.h
@@ -0,0 +1,235 @@
+/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+/*
+ * This file defines the Node class and its subclasses. A Node is the basis
+ * analysis element in a computation graph.
+ * There are basically two kinds of nodes, the function node and value node.
+ */
+#pragma once
+
+#include
+#include
+#include
+#include
+#include
+
+#include "paddle/fluid/inference/analysis/device.h"
+#include "paddle/fluid/inference/analysis/dot.h"
+#include "paddle/fluid/inference/analysis/helper.h"
+
+namespace paddle {
+namespace inference {
+namespace analysis {
+
+class NodeMap;
+
+/*
+ * Node Representation.
+ *
+ * This is a very important class for analysis. It is the base class of all
+ * nodes computed by a program that may be used as operands to other nodes.
+ * Node is the super class of other important classes such as Function and
+ * Value, some nodes can have a name.
+ */
+class Node {
+ public:
+ // Node type. NOTE the new node types should add here.
+ enum class Type { kNone = -1, kFunction, kValue, kFunctionBlock };
+
+ Node() = default;
+
+ struct Attr;
+
+ // Cast to a subclass type (e.g. Function); the result is not null-checked.
+ template
+ Subclass &As() {
+ return *dynamic_cast(this);
+ }
+
+ // Formatted representation of this Node.
+ virtual std::string repr() const {
+ return name() + "(" + std::to_string(id()) + ")";
+ }
+
+ // DOT node representation. One Node type can customize its own node
+ // representation.
+ virtual std::vector dot_attrs() const {
+ return std::vector({Dot::Attr("style", "filled")});
+ }
+
+ // Get the additional attribute called `name`; read or write its value through
+ // Attr's typed accessors. NOTE this silently creates it if it does not exist.
+ Attr &attr(const std::string &name) { return attrs_[name]; }
+
+ int id() const { return id_; }
+
+ bool deleted() const { return deleted_; }
+ void SetDeleted() { deleted_ = true; }
+
+ void SetName(const std::string &name) { name_ = name; }
+ const std::string &name() const { return name_; }
+
+ void SetType(Type type) { type_ = type; }
+ Type type() const { return type_; }
+
+ void *extra_info() const { return extra_info_; }
+ void SetExtraInfo(void *extra_info) { extra_info_ = extra_info; }
+
+ // Input links.
+ std::vector inlinks;
+ // Output links.
+ std::vector outlinks;
+
+ // A helper class to maintain the status from Pass.
+ // TODO(superjomn) add a checker here to ensure the T is primitive.
+ struct Attr {
+ // NOTE T should be a primitive type or a struct composed of several
+ // primitive types: the value lives as raw bytes inside a std::string.
+ // NOTE STL containers should not be used as T.
+ // Usage:
+ //   Node *node = ...;
+ //   node->attr("v0").Int32() = 2008;  // first access fixes the type
+ //   int32_t v = node->attr("v0").Int32();
+
+ bool &Bool() { return As(); }
+ float &Float() { return As(); }
+ int32_t &Int32() { return As(); }
+ int64_t &Int64() { return As(); }
+
+ private:
+ template
+ T &As() {
+ // The first access sizes the storage and records the requested type.
+ if (data_.empty()) {
+ VLOG(4) << "resize data to " << sizeof(T);
+ type_hash_ = typeid(T).hash_code();
+ data_.resize(sizeof(T));
+ }
+ PADDLE_ENFORCE(type_hash_ == typeid(T).hash_code(), "type not matched");
+ PADDLE_ENFORCE_EQ(data_.size(), sizeof(T), "Node attr type recast error");
+ return *reinterpret_cast(&data_[0]);
+ }
+
+ private:
+ std::string data_;  // raw bytes of the stored value
+ size_t type_hash_{std::numeric_limits::max()};  // typeid hash of stored type
+ };
+
+ virtual ~Node() {}
+
+ friend class NodeMap;
+
+ PADDLE_DISALLOW_COPY_AND_ASSIGN(Node);
+
+ protected:
+ // The id number not the name is a node's unique identifier in the computation
+ // graph.
+ int id_{-1};
+ std::string name_;
+ Type type_{Type::kNone};
+ // Mark this node is deleted by some pass.
+ bool deleted_{false};
+
+ void *extra_info_{nullptr};  // Opaque payload; null until SetExtraInfo().
+
+ mutable std::unordered_map attrs_;
+};
+
+class Function;
+/*
+ * Value represents a value node; it carries a data type, dims and the device
+ * it is assigned to.
+ */
+class Value : public Node {
+ public:
+ enum class DataType { kInt32, kInt64, kFloat32, kFloat64 };
+ using Dims = std::vector;
+
+ void SetDataType(DataType data_type) { data_type_ = data_type; }
+ DataType data_type() const { return data_type_; }
+
+ void SetDims(const Dims &dims) { dims_ = dims; }
+ const Dims &dims() const { return dims_; }
+
+ Device device() const { return device_; }
+ void SetDevice(Device device) { device_ = device; }
+
+ std::vector dot_attrs() const override;
+
+ PADDLE_DISALLOW_COPY_AND_ASSIGN(Value);
+
+ protected:
+ Value() { SetType(Node::Type::kValue); }  // reachable via friend NodeMap
+ friend class NodeMap;
+
+ private:
+ DataType data_type_;  // NOTE(review): no default; unset until SetDataType()
+ Dims dims_;
+ Device device_;  // NOTE(review): no default; unset until SetDevice()
+};
+
+/*
+ * Function represents any kind of executable concepts that takes several Values
+ * as input, and outputs several Values.
+ */
+class Function : public Node {
+ public:
+ std::vector dot_attrs() const override;
+
+ // Get the operator's type from Desc.
+ const std::string &func_type() const { return func_type_; }
+ // Set the operator's type.
+ void SetFuncType(const std::string &func_type) { func_type_ = func_type; }
+
+ PADDLE_DISALLOW_COPY_AND_ASSIGN(Function);
+
+ protected:
+ std::string func_type_;
+ Function() { SetType(Node::Type::kFunction); }
+ friend class NodeMap;
+};
+
+/*
+ * FunctionBlock is a Node that contains a sub-graph of multiple Nodes.
+ */
+struct FunctionBlock : public Node {  // NOTE(review): type() stays kNone.
+ std::string repr() const override { return "block-" + std::to_string(id()); }
+ std::vector subgraph;
+};
+
+class NodeMap {  // Container of all nodes; a node's id is its index.
+ public:
+ // Create a new node with type.
+ Node *Create(Node::Type type);
+
+ // Get a node by its id.
+ Node *GetMutable(size_t id);
+ // Read-only access to the node with the given id.
+ const Node &Get(size_t id) const;
+ // Mark the node as deleted; it is kept in the container.
+ void Delete(size_t id);
+
+ const std::vector> &nodes() { return nodes_; }
+
+ size_t size() const { return nodes_.size(); }
+
+ private:
+ std::vector> nodes_;
+ std::unordered_map map_;  // NOTE(review): unused by the methods in node.cc
+};
+
+} // namespace analysis
+} // namespace inference
+} // namespace paddle
diff --git a/paddle/fluid/inference/analysis/node_tester.cc b/paddle/fluid/inference/analysis/node_tester.cc
new file mode 100644
index 0000000000000000000000000000000000000000..47fea0fdff808c930ca73edb25f5b16fef397e9a
--- /dev/null
+++ b/paddle/fluid/inference/analysis/node_tester.cc
@@ -0,0 +1,34 @@
+/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License. */
+
+#include "paddle/fluid/inference/analysis/node.h"
+
+#include
+
+namespace paddle {
+namespace inference {
+namespace analysis {
+
+TEST(Node, Attr) {
+ // Value's constructor is protected, so create the node through NodeMap. All
+ // node types share the same Attr logic.
+ NodeMap nodes;
+ auto* node = nodes.Create(Node::Type::kValue);
+ node->attr("v0").Int32() = 2008;
+ ASSERT_EQ(node->attr("v0").Int32(), 2008);
+}
+
+} // namespace analysis
+} // namespace inference
+} // namespace paddle
diff --git a/paddle/fluid/inference/engine.h b/paddle/fluid/inference/engine.h
index de0375551e16ec53b90414c7446234fda98bf706..ce2b8161715a3fa2278ce950dbac82c6d0042bef 100644
--- a/paddle/fluid/inference/engine.h
+++ b/paddle/fluid/inference/engine.h
@@ -14,6 +14,7 @@ limitations under the License. */
#pragma once
+#include
#include "paddle/fluid/framework/framework.pb.h"
namespace paddle {
@@ -58,8 +59,8 @@ class EngineBase {
struct Buffer {
void* buffer{nullptr}; // buffer should be allocated only once.
- int max_size; // buffer allocated space.
- int size; // data size.
+ size_t max_size; // buffer allocated space.
+ size_t size; // data size.
DeviceType device{DeviceType::UNK}; // tells which device this buffer is on.
};
diff --git a/paddle/fluid/inference/tensorrt/CMakeLists.txt b/paddle/fluid/inference/tensorrt/CMakeLists.txt
index 677b3e04af8e7f5662a15fb32e3b03f45d262733..b52d083f280e5e7713600a7b748dedd37aca0a1e 100644
--- a/paddle/fluid/inference/tensorrt/CMakeLists.txt
+++ b/paddle/fluid/inference/tensorrt/CMakeLists.txt
@@ -1,5 +1,4 @@
nv_library(tensorrt_engine SRCS engine.cc DEPS framework_proto)
nv_test(test_tensorrt SRCS test_tensorrt.cc DEPS dynload_cuda device_context dynamic_loader)
nv_test(test_tensorrt_engine SRCS test_engine.cc DEPS dynload_cuda tensorrt_engine)
-
add_subdirectory(convert)
diff --git a/paddle/fluid/inference/tensorrt/convert/CMakeLists.txt b/paddle/fluid/inference/tensorrt/convert/CMakeLists.txt
index 5178c54c08400125d190078dac6c52d021f8488b..4fb4511d99179e4ea14cde66feb13bc9e114581a 100644
--- a/paddle/fluid/inference/tensorrt/convert/CMakeLists.txt
+++ b/paddle/fluid/inference/tensorrt/convert/CMakeLists.txt
@@ -1,4 +1,4 @@
nv_test(test_op_converter SRCS test_op_converter.cc mul_op.cc conv2d_op.cc DEPS ${FLUID_CORE_MODULES})
-nv_test(test_trt_activation_op SRCS test_activation_op.cc activation_op.cc
+nv_test(test_trt_activation_op SRCS test_activation_op.cc activation_op.cc io_converter.cc
DEPS ${FLUID_CORE_MODULES} activation_op tensorrt_engine)
nv_test(test_io_converter SRCS test_io_converter.cc io_converter.cc DEPS dynload_cuda dynamic_loader lod_tensor)
diff --git a/paddle/fluid/inference/tensorrt/convert/activation_op.cc b/paddle/fluid/inference/tensorrt/convert/activation_op.cc
index 543784289cfc51048057e467d36fdd1f334eb903..6297051e5a30f1daa512d25d5aa3ab3b2f79f1d1 100644
--- a/paddle/fluid/inference/tensorrt/convert/activation_op.cc
+++ b/paddle/fluid/inference/tensorrt/convert/activation_op.cc
@@ -21,15 +21,18 @@ namespace tensorrt {
class ReluOpConverter : public OpConverter {
public:
ReluOpConverter() {}
- void operator()(const framework::OpDesc& op) override {
+ void operator()(const framework::proto::OpDesc& op) override {
+ // Here the two nullptr looks strange, that's because the
+ // framework::OpDesc's constructor is strange.
+ framework::OpDesc op_desc(op, nullptr, nullptr);
LOG(INFO) << "convert a fluid relu op to tensorrt activation layer whose "
"type is Relu";
const nvinfer1::ITensor* input_tensor =
- engine_->GetITensor(op.Input("X")[0]);
+ engine_->GetITensor(op_desc.Input("X")[0]);
nvinfer1::IActivationLayer* layer = TRT_ENGINE_ADD_LAYER(
engine_, Activation, *const_cast(input_tensor),
nvinfer1::ActivationType::kRELU);
- engine_->SetITensor(op.Output("Out")[0], layer->getOutput(0));
+ engine_->SetITensor(op_desc.Output("Out")[0], layer->getOutput(0));
}
};
diff --git a/paddle/fluid/inference/tensorrt/convert/conv2d_op.cc b/paddle/fluid/inference/tensorrt/convert/conv2d_op.cc
index 431500b90e144e2a30fe705b72e93452f806ca65..209936c3bafb0d31546856dc36c1b48053a0634b 100644
--- a/paddle/fluid/inference/tensorrt/convert/conv2d_op.cc
+++ b/paddle/fluid/inference/tensorrt/convert/conv2d_op.cc
@@ -21,7 +21,7 @@ namespace tensorrt {
class Conv2dOpConverter : public OpConverter {
public:
Conv2dOpConverter() {}
- void operator()(const framework::OpDesc& op) override {
+ void operator()(const framework::proto::OpDesc& op) override {
LOG(INFO)
<< "convert a fluid conv2d op to tensorrt conv layer without bias";
}
diff --git a/paddle/fluid/inference/tensorrt/convert/io_converter.cc b/paddle/fluid/inference/tensorrt/convert/io_converter.cc
index 32e8631fde3f748669d2008b4a060455a37e154e..854f434d93e81237dc85c5df62debcf3b3824b78 100644
--- a/paddle/fluid/inference/tensorrt/convert/io_converter.cc
+++ b/paddle/fluid/inference/tensorrt/convert/io_converter.cc
@@ -23,26 +23,42 @@ namespace tensorrt {
using platform::is_gpu_place;
using platform::is_cpu_place;
-class DefaultInputConverter : public EngineInputConverter {
+class DefaultIOConverter : public EngineIOConverter {
public:
- DefaultInputConverter() {}
+ DefaultIOConverter() {}
// NOTE out is GPU memory.
virtual void operator()(const LoDTensor& in, void* out,
size_t max_size) override {
PADDLE_ENFORCE(out != nullptr);
- PADDLE_ENFORCE_LE(in.memory_size(), max_size);
+ PADDLE_ENFORCE(stream_ != nullptr);
const auto& place = in.place();
+ size_t size = in.memory_size();
+ PADDLE_ENFORCE_LE(size, max_size);
if (is_cpu_place(place)) {
- PADDLE_ENFORCE(stream_ != nullptr);
- PADDLE_ENFORCE_EQ(0,
- cudaMemcpyAsync(out, in.data(), in.memory_size(),
- cudaMemcpyHostToDevice, *stream_));
-
+ PADDLE_ENFORCE_EQ(0, cudaMemcpyAsync(out, in.data(), size,
+ cudaMemcpyHostToDevice, *stream_));
} else if (is_gpu_place(place)) {
- PADDLE_ENFORCE_EQ(0,
- cudaMemcpyAsync(out, in.data(), in.memory_size(),
- cudaMemcpyHostToHost, *stream_));
-
+ PADDLE_ENFORCE_EQ(0, cudaMemcpyAsync(out, in.data(), size,
+ cudaMemcpyDeviceToDevice, *stream_));
+ } else {
+ PADDLE_THROW("Unknown device for converter");
+ }
+ cudaStreamSynchronize(*stream_);
+ }
+ // NOTE in is GPU memory.
+ virtual void operator()(const void* in, LoDTensor* out,
+ size_t max_size) override {
+ PADDLE_ENFORCE(in != nullptr);
+ PADDLE_ENFORCE(stream_ != nullptr);
+ const auto& place = out->place();
+ size_t size = out->memory_size();
+ PADDLE_ENFORCE_LE(size, max_size);
+ if (is_cpu_place(place)) {
+ PADDLE_ENFORCE_EQ(0, cudaMemcpyAsync(out->data(), in, size,
+ cudaMemcpyDeviceToHost, *stream_));
+ } else if (is_gpu_place(place)) {
+ PADDLE_ENFORCE_EQ(0, cudaMemcpyAsync(out->data(), in, size,
+ cudaMemcpyDeviceToDevice, *stream_));
} else {
PADDLE_THROW("Unknown device for converter");
}
@@ -50,7 +66,8 @@ class DefaultInputConverter : public EngineInputConverter {
}
};
-REGISTER_TENSORRT_INPUT_CONVERTER(default, DefaultInputConverter);
+// fluid LodTensor <-> tensorrt ITensor
+REGISTER_TENSORRT_IO_CONVERTER(default, DefaultIOConverter);
} // namespace tensorrt
} // namespace inference
diff --git a/paddle/fluid/inference/tensorrt/convert/io_converter.h b/paddle/fluid/inference/tensorrt/convert/io_converter.h
index 8972dae92be2c2d261a13c48d98e675f64e51d31..71c48e085d25d2bc6720d93735f661f9e3af7b40 100644
--- a/paddle/fluid/inference/tensorrt/convert/io_converter.h
+++ b/paddle/fluid/inference/tensorrt/convert/io_converter.h
@@ -14,6 +14,7 @@ limitations under the License. */
#pragma once
+#include
#include
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/inference/utils/singleton.h"
@@ -25,43 +26,57 @@ namespace tensorrt {
using framework::LoDTensor;
/*
- * Convert Input from Fluid to an Engine.
- * TensorRT's ITensor follows row major, NCHW. Fluid is also row major, so in
- * most cases just need to copy the data.
+ * Convert Input from Fluid to TensorRT Engine.
+ * Convert Output from TensorRT Engine to Fluid.
+ *
+ * Note that TensorRT's ITensor follows row major, NCHW. Fluid is also row
+ * major, so in the default case the converter just needs to copy the
+ * data.
*/
-class EngineInputConverter {
+class EngineIOConverter {
public:
- EngineInputConverter() {}
+ EngineIOConverter() {}
virtual void operator()(const LoDTensor& in, void* out, size_t max_size) {}
+ virtual void operator()(const void* in, LoDTensor* out, size_t max_size) {}
void SetStream(cudaStream_t* stream) { stream_ = stream; }
- static void Run(const std::string& in_op_type, const LoDTensor& in, void* out,
- size_t max_size, cudaStream_t* stream) {
+ static void ConvertInput(const std::string& op_type, const LoDTensor& in,
+ void* out, size_t max_size, cudaStream_t* stream) {
PADDLE_ENFORCE(stream != nullptr);
- auto* converter = Registry::Lookup(
- in_op_type, "default" /* default_type */);
+ auto* converter = Registry::Lookup(
+ op_type, "default" /* default_type */);
PADDLE_ENFORCE_NOT_NULL(converter);
converter->SetStream(stream);
(*converter)(in, out, max_size);
}
- virtual ~EngineInputConverter() {}
+ static void ConvertOutput(const std::string& op_type, const void* in,
+ LoDTensor* out, size_t max_size,
+ cudaStream_t* stream) {
+ PADDLE_ENFORCE(stream != nullptr);
+ auto* converter = Registry::Lookup(
+ op_type, "default" /* default_type */);
+ PADDLE_ENFORCE_NOT_NULL(converter);
+ converter->SetStream(stream);
+ (*converter)(in, out, max_size);
+ }
+
+ virtual ~EngineIOConverter() {}
protected:
cudaStream_t* stream_{nullptr};
};
+#define REGISTER_TENSORRT_IO_CONVERTER(op_type__, Converter__) \
+ struct trt_io_##op_type__##_converter { \
+ trt_io_##op_type__##_converter() { \
+ Registry::Register(#op_type__); \
+ } \
+ }; \
+ trt_io_##op_type__##_converter trt_io_##op_type__##_converter__;
+
} // namespace tensorrt
} // namespace inference
} // namespace paddle
-
-#define REGISTER_TENSORRT_INPUT_CONVERTER(in_op_type__, Converter__) \
- struct trt_input_##in_op_type__##_converter { \
- trt_input_##in_op_type__##_converter() { \
- ::paddle::inference::Registry::Register< \
- Converter__>(#in_op_type__); \
- } \
- }; \
- trt_input_##in_op_type__##_converter trt_input_##in_op_type__##_converter__;
diff --git a/paddle/fluid/inference/tensorrt/convert/mul_op.cc b/paddle/fluid/inference/tensorrt/convert/mul_op.cc
index f9834ab156c9dcc11f4e89075b7bf5457cf00268..3ca58b139bd3af1947ae7f063060e11d2ea7d577 100644
--- a/paddle/fluid/inference/tensorrt/convert/mul_op.cc
+++ b/paddle/fluid/inference/tensorrt/convert/mul_op.cc
@@ -21,7 +21,7 @@ namespace tensorrt {
class MulOpConverter : public OpConverter {
public:
MulOpConverter() {}
- void operator()(const framework::OpDesc& op) override {
+ void operator()(const framework::proto::OpDesc& op) override {
LOG(INFO) << "convert a fluid mul op to tensorrt fc layer without bias";
}
};
diff --git a/paddle/fluid/inference/tensorrt/convert/op_converter.h b/paddle/fluid/inference/tensorrt/convert/op_converter.h
index 77c788550b2c7df1f483b926661789b2a54d8fff..abc9ebf472498f6653d5bb1113ae2f3ce7e5a923 100644
--- a/paddle/fluid/inference/tensorrt/convert/op_converter.h
+++ b/paddle/fluid/inference/tensorrt/convert/op_converter.h
@@ -31,10 +31,10 @@ namespace tensorrt {
class OpConverter {
public:
OpConverter() {}
- virtual void operator()(const framework::OpDesc& op) {}
+ virtual void operator()(const framework::proto::OpDesc& op) {}
- void Run(const framework::OpDesc& op, TensorRTEngine* engine) {
- std::string type = op.Type();
+ void Run(const framework::proto::OpDesc& op, TensorRTEngine* engine) {
+ std::string type = op.type();
auto* it = Registry::Lookup(type);
PADDLE_ENFORCE_NOT_NULL(it, "no OpConverter for optype [%s]", type);
it->SetEngine(engine);
@@ -42,14 +42,16 @@ class OpConverter {
}
// convert fluid op to tensorrt layer
- void ConvertOp(const framework::OpDesc& op, TensorRTEngine* engine) {
+ void ConvertOp(const framework::proto::OpDesc& op, TensorRTEngine* engine) {
OpConverter::Run(op, engine);
}
// convert fluid block to tensorrt network
- void ConvertBlock(const framework::BlockDesc& block, TensorRTEngine* engine) {
- for (auto op : block.AllOps()) {
- OpConverter::Run(*op, engine);
+ void ConvertBlock(const framework::proto::BlockDesc& block,
+ TensorRTEngine* engine) {
+ for (int i = 0; i < block.ops_size(); i++) {  // ops_size() returns int
+ const auto& op = block.ops(i);
+ OpConverter::Run(op, engine);
}
}
diff --git a/paddle/fluid/inference/tensorrt/convert/test_activation_op.cc b/paddle/fluid/inference/tensorrt/convert/test_activation_op.cc
index 23e3435c21725328d3765fae0d158a83ac21478b..ec33f97c8240dfc09a203d68599bffe78a4abb12 100644
--- a/paddle/fluid/inference/tensorrt/convert/test_activation_op.cc
+++ b/paddle/fluid/inference/tensorrt/convert/test_activation_op.cc
@@ -16,6 +16,7 @@ limitations under the License. */
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/program_desc.h"
+#include "paddle/fluid/inference/tensorrt/convert/io_converter.h"
#include "paddle/fluid/inference/tensorrt/convert/op_converter.h"
#include "paddle/fluid/platform/device_context.h"
#include "paddle/fluid/platform/place.h"
@@ -26,7 +27,7 @@ namespace paddle {
namespace inference {
namespace tensorrt {
-void Compare(float input, float expect) {
+void Compare(const std::string op_type, float input, float expect) {
framework::Scope scope;
platform::CUDAPlace place;
platform::CUDADeviceContext ctx(place);
@@ -35,6 +36,7 @@ void Compare(float input, float expect) {
auto x_var = scope.Var("X");
auto x_tensor = x_var->GetMutable();
x_tensor->Resize({1, 1});
+ x_tensor->mutable_data(place);
std::vector init;
init.push_back(input);
framework::TensorFromVector(init, ctx, x_tensor);
@@ -45,14 +47,15 @@ void Compare(float input, float expect) {
out_tensor->mutable_data(place);
framework::OpDesc op_desc;
- op_desc.SetType("relu");
+ op_desc.SetType(op_type);
op_desc.SetInput("X", {"X"});
op_desc.SetOutput("Out", {"Out"});
- auto relu_op = framework::OpRegistry::CreateOp(op_desc);
+ auto op = framework::OpRegistry::CreateOp(*op_desc.Proto());
// run fluid op
- relu_op->Run(scope, place);
+ op->Run(scope, place);
+ // get fluid output
std::vector out1;
framework::TensorToVector(*out_tensor, ctx, &out1);
@@ -63,21 +66,28 @@ void Compare(float input, float expect) {
engine->InitNetwork();
engine->DeclareInput("X", nvinfer1::DataType::kFLOAT,
nvinfer1::DimsCHW{1, 1, 1});
-
+ // convert op
OpConverter op_converter;
- op_converter.ConvertOp(op_desc, engine);
+ op_converter.ConvertOp(*op_desc.Proto(), engine);
engine->DeclareOutput("Out");
engine->FreezeNetwork();
- engine->SetInputFromCPU("X", &input, 1 * sizeof(float));
- // run tensorrt op
+ // convert LoDTensor to ITensor
+ size_t size = x_tensor->memory_size();
+ EngineIOConverter::ConvertInput(op_type, *x_tensor,
+ engine->buffer("X").buffer, size, &stream);
+ // run tensorrt op
engine->Execute(1);
-
- float out2;
- engine->GetOutputInCPU("Out", &out2, 1 * sizeof(float));
-
- ASSERT_EQ(out1[0], out2);
+ // convert ITensor to LoDTensor
+ EngineIOConverter::ConvertOutput(op_type, engine->buffer("Out").buffer,
+ out_tensor, size, &stream);
+ // get tensorrt output
+ std::vector out2;
+ framework::TensorToVector(*out_tensor, ctx, &out2);
+
+ // compare
+ ASSERT_EQ(out1[0], out2[0]);
ASSERT_EQ(out1[0], expect);
delete engine;
@@ -85,8 +95,8 @@ void Compare(float input, float expect) {
}
TEST(OpConverter, ConvertRelu) {
- Compare(1, 1); // relu(1) = 1
- Compare(-5, 0); // relu(-5) = 0
+ Compare("relu", 1, 1); // relu(1) = 1
+ Compare("relu", -5, 0); // relu(-5) = 0
}
} // namespace tensorrt
diff --git a/paddle/fluid/inference/tensorrt/convert/test_io_converter.cc b/paddle/fluid/inference/tensorrt/convert/test_io_converter.cc
index afcc516e6b76d58e37ce0e60746704cf3933fac7..8f91309a0a00d5131268f026c319e25ba3cb964a 100644
--- a/paddle/fluid/inference/tensorrt/convert/test_io_converter.cc
+++ b/paddle/fluid/inference/tensorrt/convert/test_io_converter.cc
@@ -12,40 +12,63 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
+#include
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/inference/tensorrt/convert/io_converter.h"
-#include
-
namespace paddle {
namespace inference {
namespace tensorrt {
-class EngineInputConverterTester : public ::testing::Test {
- public:
- void SetUp() override { tensor.Resize({10, 10}); }
+void IOConverterTester(const platform::DeviceContext& ctx) {
+ cudaStream_t stream;
+ ASSERT_EQ(0, cudaStreamCreate(&stream));
- framework::LoDTensor tensor;
-};
+ // init fluid in_tensor
+ framework::LoDTensor in_tensor;
+ in_tensor.Resize({10, 10});
+ auto place = ctx.GetPlace();
+ in_tensor.mutable_data(place);
+ std::vector init;
+ for (int64_t i = 0; i < 10 * 10; ++i) {
+ init.push_back(i);
+ }
+ framework::TensorFromVector(init, ctx, &in_tensor);
-TEST_F(EngineInputConverterTester, DefaultCPU) {
+ // init tensorrt buffer
void* buffer;
- tensor.mutable_data(platform::CPUPlace());
- ASSERT_EQ(cudaMalloc(&buffer, tensor.memory_size()), 0);
+ size_t size = in_tensor.memory_size();
+ ASSERT_EQ(cudaMalloc(&buffer, size), 0);
- cudaStream_t stream;
- EngineInputConverter::Run("test", tensor, buffer, tensor.memory_size(),
- &stream);
+ // convert fluid in_tensor to tensorrt buffer
+ EngineIOConverter::ConvertInput("test", in_tensor, buffer, size, &stream);
+
+ // convert tensorrt buffer to fluid out_tensor
+ framework::LoDTensor out_tensor;
+ out_tensor.Resize({10, 10});
+ out_tensor.mutable_data(place);
+ EngineIOConverter::ConvertOutput("test", buffer, &out_tensor, size, &stream);
+
+ // compare in_tensor and out_tensor
+ std::vector result;
+ framework::TensorToVector(out_tensor, ctx, &result);
+ EXPECT_EQ(init.size(), result.size());
+ for (size_t i = 0; i < init.size(); i++) {
+ EXPECT_EQ(init[i], result[i]);
+ }
+ cudaStreamDestroy(stream);
}
-TEST_F(EngineInputConverterTester, DefaultGPU) {
- void* buffer;
- tensor.mutable_data(platform::CUDAPlace());
- ASSERT_EQ(cudaMalloc(&buffer, tensor.memory_size()), 0);
+TEST(EngineIOConverterTester, DefaultCPU) {
+ platform::CPUPlace place;
+ platform::CPUDeviceContext ctx(place);
+ IOConverterTester(ctx);
+}
- cudaStream_t stream;
- EngineInputConverter::Run("test", tensor, buffer, tensor.memory_size(),
- &stream);
+TEST(EngineIOConverterTester, DefaultGPU) {
+ platform::CUDAPlace place;
+ platform::CUDADeviceContext ctx(place);
+ IOConverterTester(ctx);
}
} // namespace tensorrt
diff --git a/paddle/fluid/inference/tensorrt/convert/test_op_converter.cc b/paddle/fluid/inference/tensorrt/convert/test_op_converter.cc
index aa5fb726f1129eda65a6f39791330b795aad660d..8d66543eb7637c5a8ae670b89ef5996954ba2e7b 100644
--- a/paddle/fluid/inference/tensorrt/convert/test_op_converter.cc
+++ b/paddle/fluid/inference/tensorrt/convert/test_op_converter.cc
@@ -29,7 +29,7 @@ TEST(OpConverter, ConvertBlock) {
conv2d_op->SetType("conv2d");
OpConverter converter;
- converter.ConvertBlock(*block, nullptr /*TensorRTEngine*/);
+ converter.ConvertBlock(*block->Proto(), nullptr /*TensorRTEngine*/);
}
} // namespace tensorrt
diff --git a/paddle/fluid/inference/tests/book/test_inference_image_classification.cc b/paddle/fluid/inference/tests/book/test_inference_image_classification.cc
index c4fd1e298b0daea85db2a407d04ad2d7bcdee0f0..60c761c5281e2f535aab0200c93fb738addcdb87 100644
--- a/paddle/fluid/inference/tests/book/test_inference_image_classification.cc
+++ b/paddle/fluid/inference/tests/book/test_inference_image_classification.cc
@@ -16,7 +16,6 @@ limitations under the License. */
#include "gtest/gtest.h"
#include "paddle/fluid/inference/tests/test_helper.h"
-DEFINE_string(data_set, "cifar10", "Data set to test");
DEFINE_string(dirname, "", "Directory of the inference model.");
DEFINE_string(fp16_dirname, "", "Directory of the float16 inference model.");
DEFINE_int32(batch_size, 1, "Batch size of input data");
@@ -35,19 +34,19 @@ TEST(inference, image_classification) {
// 0. Call `paddle::framework::InitDevices()` initialize all the devices
// In unittests, this is done in paddle/testing/paddle_gtest_main.cc
+ const bool is_combined = false;
+ std::vector> feed_target_shapes =
+ GetFeedTargetShapes(dirname, is_combined);
+
paddle::framework::LoDTensor input;
// Use normilized image pixels as input data,
// which should be in the range [0.0, 1.0].
- if (FLAGS_data_set == "cifar10") {
- SetupTensor(&input, {FLAGS_batch_size, 3, 32, 32},
- static_cast(0), static_cast(1));
- } else if (FLAGS_data_set == "imagenet") {
- SetupTensor(&input, {FLAGS_batch_size, 3, 224, 224},
- static_cast(0), static_cast(1));
- } else {
- LOG(FATAL) << "Only cifar10 or imagenet is supported.";
- }
-
+ feed_target_shapes[0][0] = FLAGS_batch_size;
+ paddle::framework::DDim input_dims =
+ paddle::framework::make_ddim(feed_target_shapes[0]);
+ LOG(INFO) << input_dims;
+ SetupTensor(&input, input_dims, static_cast(0),
+ static_cast(1));
std::vector cpu_feeds;
cpu_feeds.push_back(&input);
@@ -60,7 +59,7 @@ TEST(inference, image_classification) {
LOG(INFO) << "--- CPU Runs: ---";
LOG(INFO) << "Batch size is " << FLAGS_batch_size;
TestInference(
- dirname, cpu_feeds, cpu_fetchs1, FLAGS_repeat);
+ dirname, cpu_feeds, cpu_fetchs1, FLAGS_repeat, is_combined);
LOG(INFO) << output1.dims();
}
@@ -73,7 +72,7 @@ TEST(inference, image_classification) {
LOG(INFO) << "--- GPU Runs: ---";
LOG(INFO) << "Batch size is " << FLAGS_batch_size;
TestInference(
- dirname, cpu_feeds, cpu_fetchs2, FLAGS_repeat);
+ dirname, cpu_feeds, cpu_fetchs2, FLAGS_repeat, is_combined);
LOG(INFO) << output2.dims();
if (!FLAGS_skip_cpu) {
diff --git a/paddle/fluid/inference/tests/test_helper.h b/paddle/fluid/inference/tests/test_helper.h
index af2a7a5620487a10c1df6152fc4e4bf67b150752..b02e5c99f00eaf03c3753e43575cbc67e834774e 100644
--- a/paddle/fluid/inference/tests/test_helper.h
+++ b/paddle/fluid/inference/tests/test_helper.h
@@ -89,6 +89,50 @@ void CheckError(const paddle::framework::LoDTensor& output1,
EXPECT_EQ(count, 0U) << "There are " << count << " different elements.";
}
+std::unique_ptr InitProgram(
+ paddle::framework::Executor* executor, paddle::framework::Scope* scope,
+ const std::string& dirname, const bool is_combined = false) {
+ std::unique_ptr inference_program;
+ if (is_combined) {
+ // All parameters are saved in a single file.
+ // Hard-coding the file names of program and parameters in unittest.
+ // The file names should be consistent with those used in the Python API
+ // `fluid.io.save_inference_model`.
+ std::string prog_filename = "__model_combined__";
+ std::string param_filename = "__params_combined__";
+ inference_program =
+ paddle::inference::Load(executor, scope, dirname + "/" + prog_filename,
+ dirname + "/" + param_filename);
+ } else {
+ // Parameters are saved in separate files located in the specified
+ // `dirname`.
+ inference_program = paddle::inference::Load(executor, scope, dirname);
+ }
+ return inference_program;
+}
+
+std::vector> GetFeedTargetShapes(
+ const std::string& dirname, const bool is_combined = false) {
+ auto place = paddle::platform::CPUPlace();
+ auto executor = paddle::framework::Executor(place);
+ auto* scope = new paddle::framework::Scope();
+
+ auto inference_program = InitProgram(&executor, scope, dirname, is_combined);
+ auto& global_block = inference_program->Block(0);
+
+ const std::vector& feed_target_names =
+ inference_program->GetFeedTargetNames();
+ std::vector> feed_target_shapes;
+ for (size_t i = 0; i < feed_target_names.size(); ++i) {
+ auto* var = global_block.FindVar(feed_target_names[i]);
+ std::vector var_shape = var->GetShape();
+ feed_target_shapes.push_back(var_shape);
+ }
+
+ delete scope;
+ return feed_target_shapes;
+}
+
template
void TestInference(const std::string& dirname,
const std::vector& cpu_feeds,
@@ -124,22 +168,7 @@ void TestInference(const std::string& dirname,
paddle::platform::RecordEvent record_event(
"init_program",
paddle::platform::DeviceContextPool::Instance().Get(place));
-
- if (is_combined) {
- // All parameters are saved in a single file.
- // Hard-coding the file names of program and parameters in unittest.
- // The file names should be consistent with that used in Python API
- // `fluid.io.save_inference_model`.
- std::string prog_filename = "__model_combined__";
- std::string param_filename = "__params_combined__";
- inference_program = paddle::inference::Load(
- &executor, scope, dirname + "/" + prog_filename,
- dirname + "/" + param_filename);
- } else {
- // Parameters are saved in separate files sited in the specified
- // `dirname`.
- inference_program = paddle::inference::Load(&executor, scope, dirname);
- }
+ inference_program = InitProgram(&executor, scope, dirname, is_combined);
}
// Disable the profiler and print the timing information
paddle::platform::DisableProfiler(paddle::platform::EventSortingKey::kDefault,
diff --git a/paddle/fluid/operators/CMakeLists.txt b/paddle/fluid/operators/CMakeLists.txt
index c14a2b7786f9f7c06d59479d3bbce9c5d542e495..d38a9ce58726a1d045d6905354b0b592166c0110 100644
--- a/paddle/fluid/operators/CMakeLists.txt
+++ b/paddle/fluid/operators/CMakeLists.txt
@@ -186,6 +186,11 @@ endif()
add_subdirectory(detail)
if(WITH_DISTRIBUTE)
+ if(WITH_GPU)
+ op_library(gen_nccl_id_op DEPS nccl_common)
+ else()
+ set(DEPS_OPS ${DEPS_OPS} gen_nccl_id_op)
+ endif()
set(DISTRIBUTE_DEPS sendrecvop_grpc grpc++_unsecure grpc_unsecure gpr cares zlib protobuf)
set(DISTRIBUTE_COMPILE_FLAGS "-Wno-non-virtual-dtor -Wno-error=non-virtual-dtor -Wno-error=delete-non-virtual-dtor")
op_library(send_op DEPS ${DISTRIBUTE_DEPS})
@@ -202,8 +207,9 @@ if(WITH_DISTRIBUTE)
set_source_files_properties(send_barrier_op.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
set_source_files_properties(send_recv_op_test.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
cc_test(test_send_recv SRCS send_recv_op_test.cc DEPS prefetch_op send_op listen_and_serv_op sum_op executor)
+ cc_test(test_send_nccl_id SRCS test_send_nccl_id.cc DEPS send_op listen_and_serv_op executor)
else()
- set(DEPS_OPS ${DEPS_OPS} send_op prefetch_op recv_op listen_and_serv_op send_vars_op send_barrier_op)
+ set(DEPS_OPS ${DEPS_OPS} send_op prefetch_op recv_op listen_and_serv_op send_vars_op send_barrier_op gen_nccl_id_op)
endif()
op_library(cross_entropy_op DEPS cross_entropy)
diff --git a/paddle/fluid/operators/detail/grpc_client.cc b/paddle/fluid/operators/detail/grpc_client.cc
index 661dfa69fe1580ff3890f12defcd124225be0c06..ae60ab15325ef101feb7270a4f5d840cb2112be0 100644
--- a/paddle/fluid/operators/detail/grpc_client.cc
+++ b/paddle/fluid/operators/detail/grpc_client.cc
@@ -52,7 +52,7 @@ bool RPCClient::AsyncSendVariable(const std::string& ep,
// stub context
SendProcessor* s = new SendProcessor(ch);
s->Prepare(var_h, time_out);
- s->response_call_back_ = NULL;
+ s->response_call_back_ = nullptr;
auto call = s->stub_g_.PrepareUnaryCall(
s->context_.get(), "/sendrecv.SendRecvService/SendVariable", req, &cq_);
diff --git a/paddle/fluid/operators/detail/grpc_client.h b/paddle/fluid/operators/detail/grpc_client.h
index f6229b71bc01a6de51f50f5fe880ada6e15e74dd..dabce7414d2f0dca74193f1cd10c341793c10ec9 100644
--- a/paddle/fluid/operators/detail/grpc_client.h
+++ b/paddle/fluid/operators/detail/grpc_client.h
@@ -57,7 +57,9 @@ void ProcGetResponse(const VarHandle& var_h, const grpc::ByteBuffer& msg);
class BaseProcessor {
public:
- explicit BaseProcessor(std::shared_ptr ch) { context_ = NULL; }
+ explicit BaseProcessor(std::shared_ptr ch) {
+ context_ = nullptr;
+ }
virtual ~BaseProcessor() {}
@@ -105,7 +107,7 @@ class SendProcessor : public BaseProcessor {
::grpc::GenericStub stub_g_;
::grpc::ByteBuffer reply_;
- RequestSendCallBack response_call_back_ = NULL;
+ RequestSendCallBack response_call_back_ = nullptr;
};
typedef std::function
diff --git a/paddle/fluid/operators/detail/grpc_server.h b/paddle/fluid/operators/detail/grpc_server.h
index 7f9cae21ccca8dd51f9fbe98148d01a51ac6eb84..18f1bc53d0f561f412a5bbbe018bc3d427ac9ef9 100644
--- a/paddle/fluid/operators/detail/grpc_server.h
+++ b/paddle/fluid/operators/detail/grpc_server.h
@@ -47,6 +47,7 @@ class AsyncGRPCServer final {
explicit AsyncGRPCServer(const std::string &address, bool sync_mode)
: address_(address), sync_mode_(sync_mode), ready_(0) {}
+ ~AsyncGRPCServer() {}
void WaitServerReady();
void RunSyncUpdate();
diff --git a/paddle/fluid/operators/detail/send_recv.proto b/paddle/fluid/operators/detail/send_recv.proto
index fffa9ae7a43ea5cd7b2bda6fbbf6ef9f7d23009d..9478c5702bcbf99fc88207b8c4843dbccf8a5925 100644
--- a/paddle/fluid/operators/detail/send_recv.proto
+++ b/paddle/fluid/operators/detail/send_recv.proto
@@ -32,6 +32,7 @@ service SendRecvService {
enum VarType {
LOD_TENSOR = 0;
SELECTED_ROWS = 1;
+ NCCL_ID = 2;
}
// NOTICE(gongwb):don't modify this proto if you are not
diff --git a/paddle/fluid/operators/detail/sendrecvop_utils.cc b/paddle/fluid/operators/detail/sendrecvop_utils.cc
index 1a8a1af20fa446dbd537944409ef0ca1e3e9116f..07c43554bc6a0d71d688a5a5772d0ab3d2de319a 100644
--- a/paddle/fluid/operators/detail/sendrecvop_utils.cc
+++ b/paddle/fluid/operators/detail/sendrecvop_utils.cc
@@ -14,6 +14,9 @@ limitations under the License. */
#include "paddle/fluid/operators/detail/sendrecvop_utils.h"
+#ifdef PADDLE_WITH_CUDA
+#include
+#endif
#include
#include // NOLINT
@@ -129,6 +132,10 @@ void SerializeToByteBuffer(const std::string& name, framework::Variable* var,
} else if (var->IsType()) {
request.set_type(::sendrecv::SELECTED_ROWS);
GetSelectedRowsPayload(var, ctx, &request, &payload, &payload_size);
+#ifdef PADDLE_WITH_CUDA
+ } else if (var->IsType()) {
+ request.set_type(::sendrecv::NCCL_ID);
+#endif
} else {
PADDLE_THROW("Serialize does not support type: %s",
typeid(var->Type()).name());
@@ -149,6 +156,24 @@ void SerializeToByteBuffer(const std::string& name, framework::Variable* var,
void* buf = buffer.get();
ProtoEncodeHelper e(static_cast(buf), 1024);
e.WriteRawBytes(std::string(header.data(), header.size()));
+// NCCLID is copied directly to the message, return bytebuffer
+// with only one slice if serializing NCCLID.
+#ifdef PADDLE_WITH_CUDA
+ if (var->IsType()) {
+ e.WriteVarlengthBeginning(VarMsg::kSerializedFieldNumber,
+ NCCL_UNIQUE_ID_BYTES);
+ const ncclUniqueId& uid = var->Get();
+ e.WriteRawBytes(std::string(uid.internal, NCCL_UNIQUE_ID_BYTES));
+
+ // pack the serialized NCCL_ID into a single-slice ByteBuffer
+ ::grpc::Slice slices(e.size());
+ memcpy(const_cast(slices.begin()), e.data(), e.size());
+ ::grpc::ByteBuffer tmp(&slices, 1);
+ msg->Swap(&tmp);
+ return;
+ }
+#endif
+
e.WriteVarlengthBeginning(VarMsg::kSerializedFieldNumber, payload_size);
// steal reference of tensor data
::grpc::Slice slices[4]; // metadata, tensor, rows meta, rows
diff --git a/paddle/fluid/operators/detail/variable_response.cc b/paddle/fluid/operators/detail/variable_response.cc
index 99602a05d023f30c2eed8df25e7534fdc9ef2ced..462e303096e609c6797ca8cc16266ec3621623fc 100644
--- a/paddle/fluid/operators/detail/variable_response.cc
+++ b/paddle/fluid/operators/detail/variable_response.cc
@@ -17,6 +17,9 @@
#include
#include
#include
+#ifdef PADDLE_WITH_CUDA
+#include
+#endif
#include "paddle/fluid/platform/profiler.h"
#include "paddle/fluid/operators/detail/send_recv.pb.h"
@@ -368,7 +371,8 @@ int VariableResponse::Parse(Source* source) {
}
case sendrecv::VariableMessage::kSerializedFieldNumber: {
PADDLE_ENFORCE((meta_.type() == sendrecv::SELECTED_ROWS ||
- meta_.type() == sendrecv::LOD_TENSOR) &&
+ meta_.type() == sendrecv::LOD_TENSOR ||
+ meta_.type() == sendrecv::NCCL_ID) &&
meta_.varname() != "",
"meta info should be got first!");
@@ -378,6 +382,22 @@ int VariableResponse::Parse(Source* source) {
return tag;
}
+ if (meta_.type() == sendrecv::NCCL_ID) {
+#ifdef PADDLE_WITH_CUDA
+ auto* var = scope_->FindVar(meta_.varname());
+ if (var != nullptr) {
+ ncclUniqueId* id = var->GetMutable();
+ if (!ReadRaw(&input, *dev_ctx_, platform::CPUPlace(), id->internal,
+ num_bytes)) {
+ return tag;
+ }
+ }
+ break;
+#else
+ PADDLE_THROW("Not compiled with CUDA!");
+#endif
+ }
+
framework::DDim dims = GetDims(meta_.dims());
if (meta_.type() == sendrecv::LOD_TENSOR) {
PADDLE_ENFORCE(meta_.lod_size() >= 0,
diff --git a/paddle/fluid/operators/gen_nccl_id_op.cc b/paddle/fluid/operators/gen_nccl_id_op.cc
new file mode 100644
index 0000000000000000000000000000000000000000..a5678f63466d368b3dd59380c18f9625cabd368b
--- /dev/null
+++ b/paddle/fluid/operators/gen_nccl_id_op.cc
@@ -0,0 +1,128 @@
+/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#include
+#include
+#include
+#include
+
+#include "paddle/fluid/framework/executor.h"
+#include "paddle/fluid/framework/lod_tensor.h"
+#include "paddle/fluid/framework/op_registry.h"
+#include "paddle/fluid/framework/threadpool.h"
+#include "paddle/fluid/operators/detail/grpc_client.h"
+#include "paddle/fluid/operators/detail/grpc_server.h"
+#include "paddle/fluid/platform/nccl_helper.h"
+
+namespace paddle {
+namespace operators {
+
+class GenNCCLIdOp : public framework::OperatorBase {
+ public:
+ GenNCCLIdOp(const std::string& type, const framework::VariableNameMap& inputs,
+ const framework::VariableNameMap& outputs,
+ const framework::AttributeMap& attrs)
+ : OperatorBase(type, inputs, outputs, attrs) {}
+
+ void RunImpl(const framework::Scope& scope,
+ const platform::Place& dev_place) const override {
+ platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance();
+ // put nccl id in CPUPlace
+ auto& dev_ctx = *pool.Get(platform::CPUPlace());
+ int trainer_id = Attr("trainer_id");
+ framework::Scope& local_scope = scope.NewScope();
+
+ if (trainer_id == 0) {
+ GenerateAndSend(&local_scope, dev_ctx);
+ } else {
+ GetIdByServer(&local_scope, dev_ctx);
+ }
+ }
+
+ private:
+ void GenerateAndSend(framework::Scope* scope,
+ const platform::DeviceContext& dev_ctx) const {
+ auto var = scope->FindVar(NCCL_ID_VARNAME);
+ PADDLE_ENFORCE_NOT_NULL(var);
+ auto id = var->GetMutable();
+ PADDLE_ENFORCE(platform::dynload::ncclGetUniqueId(id));
+
+ std::vector endpoint_list =
+ Attr>("endpoint_list");
+ detail::RPCClient client;
+ for (auto& ep : endpoint_list) {
+ VLOG(3) << "sending nccl id to " << ep;
+ client.AsyncSendVariable(ep, dev_ctx, *scope, NCCL_ID_VARNAME);
+ }
+ client.Wait();
+ VLOG(3) << "sending completed...";
+ }
+
+ void GetIdByServer(framework::Scope* scope,
+ const platform::DeviceContext& dev_ctx) const {
+ std::string endpoint = Attr("endpoint");
+ // NOTE: Cannot use unique_ptr here because the default
+ // deleter will call GRPC Server's base class's dtor and
+ // that will cause a weird crash.
+ detail::AsyncGRPCServer rpc_service(endpoint, true);
+ framework::ProgramDesc empty_program;
+ framework::Executor executor(dev_ctx.GetPlace());
+ rpc_service.SetScope(scope);
+ rpc_service.SetDevCtx(&dev_ctx);
+ rpc_service.SetProgram(&empty_program);
+ rpc_service.SetExecutor(&executor);
+
+ std::thread server_thread(
+ std::bind(&detail::AsyncGRPCServer::RunSyncUpdate, &rpc_service));
+ rpc_service.SetCond(0);
+ VLOG(3) << "start getting nccl id from trainer 0...";
+ auto recv = rpc_service.Get();
+ VLOG(3) << "got nccl id and stop server...";
+ rpc_service.ShutDown();
+ VLOG(3) << "rpc server stopped";
+ server_thread.join();
+ }
+};
+
+class GenNCCLIdOpMaker : public framework::OpProtoAndCheckerMaker {
+ public:
+ void Make() override {
+ AddOutput("NCCLID", "Raw variable contains a NCCL UniqueId instaces.");
+ AddComment(R"DOC(
+GenNCCLId operator
+
+For trainer 0: generate a new UniqueId and send it to all the other trainers.
+For trainer 1~n: start a gRPC server to get the UniqueId, once got, stop the server.
+)DOC");
+ AddAttr("endpoint",
+ "(string), e.g. 127.0.0.1:6175 "
+ "current listen endpoint");
+ AddAttr>(
+ "endpoint_list",
+ "['trainer1_ip:port', 'trainer2_ip:port', ...] "
+ "list of trainer endpoints start from trainer 1")
+ .SetDefault({});
+ AddAttr("trainer_id",
+ "(int default 0) "
+ "The index of the trainer in distributed training.")
+ .SetDefault(0);
+ }
+};
+
+} // namespace operators
+} // namespace paddle
+
+namespace ops = paddle::operators;
+
+REGISTER_OPERATOR(gen_nccl_id, ops::GenNCCLIdOp, ops::GenNCCLIdOpMaker);
diff --git a/paddle/fluid/operators/math/sequence2batch.h b/paddle/fluid/operators/math/sequence2batch.h
index 0abda999a52bcbb94e6503692bd11aff26e849ba..62e6307ae9f4236a38c49daaf09fc05c54268159 100644
--- a/paddle/fluid/operators/math/sequence2batch.h
+++ b/paddle/fluid/operators/math/sequence2batch.h
@@ -64,18 +64,22 @@ class LoDTensor2BatchFunctor {
bool is_reverse = false) const {
if (!is_cal_batch_lod) {
auto lods = batch->lod();
- PADDLE_ENFORCE_GT(lods.size(), 2UL);
- PADDLE_ENFORCE_EQ(lods[1].size(),
- static_cast(lod_tensor.dims()[0]));
+ PADDLE_ENFORCE_GT(lods.size(), 2UL,
+ "The LoD of LoDTensor should inlcude at least 2-level "
+ "sequence information.");
+ PADDLE_ENFORCE_EQ(
+ lods[1].size(), static_cast(lod_tensor.dims()[0]),
+ "The LoD information should be consistent with the dims.");
CopyMatrixRowsFunctor to_batch;
to_batch(context, lod_tensor, lods[1], batch, true);
return;
}
auto lods = lod_tensor.lod();
- auto lod = lods[0];
PADDLE_ENFORCE_EQ(lods.size(), 1UL, "Only support one level sequence now.");
+ auto lod = lods[0];
+
std::vector seq_info;
for (size_t seq_id = 0; seq_id < lod.size() - 1; ++seq_id) {
int length = lod[seq_id + 1] - lod[seq_id];
@@ -157,9 +161,12 @@ class Batch2LoDTensorFunctor {
const framework::LoDTensor& batch,
framework::LoDTensor* lod_tensor) const {
auto in_lod = batch.lod();
- PADDLE_ENFORCE_GT(in_lod.size(), 2UL);
- PADDLE_ENFORCE_EQ(in_lod[1].size(),
- static_cast(lod_tensor->dims()[0]));
+ PADDLE_ENFORCE_GT(in_lod.size(), 2UL,
+ "The LoD of LoDTensor should inlcude at least 2-level "
+ "sequence information.");
+ PADDLE_ENFORCE_EQ(
+ in_lod[1].size(), static_cast(lod_tensor->dims()[0]),
+ "The LoD information should be consistent with the dims.");
CopyMatrixRowsFunctor to_seq;
to_seq(context, batch, in_lod[1], lod_tensor, false);
}
diff --git a/paddle/fluid/operators/reshape_op.h b/paddle/fluid/operators/reshape_op.h
index ccd7063fe69e0f21b4d2a821bb70902b39c9b9de..3dd8c7c11eca241e747bfa129962032d882ce44c 100644
--- a/paddle/fluid/operators/reshape_op.h
+++ b/paddle/fluid/operators/reshape_op.h
@@ -92,14 +92,16 @@ class ReshapeOp : public framework::OperatorWithKernel {
}
if (unk_dim_idx != -1) {
- output_shape[unk_dim_idx] = -in_size / capacity;
- // in_size < 0 and is un-determinate in compile time, skip the check,
- // for example, in_dims = [-1, 8, 1, 1], shape = [-1, 3, 8],
- // capacity = -24, in_size = -8, output_shape[0] = 0
- // the following check will fail.
if (in_size > 0) {
+ // in_size > 0: the input size is determinate at compile time, so the
+ // unknown dimension can be inferred and checked here. When in_size < 0,
+ // for example in_dims = [-1, 8, 1, 1], shape = [-1, 3, 8], capacity = -24,
+ // in_size = -8, output_shape[0] would be 0 and the check would fail.
+ output_shape[unk_dim_idx] = -in_size / capacity;
PADDLE_ENFORCE_EQ(output_shape[unk_dim_idx] * capacity, -in_size,
"Invalid shape is given.");
+ } else {
+ output_shape[unk_dim_idx] = -1;
}
} else {
PADDLE_ENFORCE_EQ(capacity, in_size, "Invalid shape is given.");
@@ -122,7 +124,10 @@ class ReshapeKernel : public framework::OpKernel {
void Compute(const framework::ExecutionContext &ctx) const {
auto *out = ctx.Output("Out");
auto *in = ctx.Input("X");
- auto *shape_tensor = ctx.Input("Shape");
+
+ auto *shape_tensor = ctx.HasInput("Shape")
+ ? ctx.Input("Shape")
+ : nullptr;
framework::DDim out_dims = out->dims();
diff --git a/paddle/fluid/operators/test_send_nccl_id.cc b/paddle/fluid/operators/test_send_nccl_id.cc
new file mode 100644
index 0000000000000000000000000000000000000000..bbae1d54aa3524fd45cb8ab13c86df8d54b8e643
--- /dev/null
+++ b/paddle/fluid/operators/test_send_nccl_id.cc
@@ -0,0 +1,94 @@
+/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#include
+#include
+#include // NOLINT
+
+#include "gtest/gtest.h"
+#include "paddle/fluid/framework/op_registry.h"
+#include "paddle/fluid/framework/operator.h"
+#include "paddle/fluid/framework/program_desc.h"
+#include "paddle/fluid/operators/detail/grpc_client.h"
+#include "paddle/fluid/operators/listen_and_serv_op.h"
+#include "paddle/fluid/operators/math/math_function.h"
+#include "paddle/fluid/operators/math/selected_rows_functor.h"
+#include "paddle/fluid/platform/nccl_helper.h"
+#include "paddle/fluid/string/printf.h"
+
+USE_NO_KERNEL_OP(listen_and_serv);
+
+namespace f = paddle::framework;
+namespace p = paddle::platform;
+namespace m = paddle::operators::math;
+namespace detail = paddle::operators::detail;
+namespace string = paddle::string;
+
+std::unique_ptr rpc_service;
+
+void StartServer(std::atomic* initialized) {
+ f::Scope scope;
+ p::CPUPlace place;
+ scope.Var(NCCL_ID_VARNAME);
+ p::DeviceContextPool& pool = p::DeviceContextPool::Instance();
+ auto& dev_ctx = *pool.Get(p::CPUPlace());
+
+ rpc_service.reset(new detail::AsyncGRPCServer("127.0.0.1:0", true));
+
+ f::ProgramDesc empty_program;
+ f::Executor executor(dev_ctx.GetPlace());
+ rpc_service->SetScope(&scope);
+ rpc_service->SetDevCtx(&dev_ctx);
+ rpc_service->SetProgram(&empty_program);
+ rpc_service->SetExecutor(&executor);
+
+ std::thread server_thread(
+ std::bind(&detail::AsyncGRPCServer::RunSyncUpdate, rpc_service.get()));
+ *initialized = true;
+ rpc_service->SetCond(0);
+ auto recv = rpc_service->Get();
+ LOG(INFO) << "got nccl id and stop server...";
+ rpc_service->ShutDown();
+ server_thread.join();
+}
+
+TEST(SendNcclId, Normal) {
+ std::atomic initialized{false};
+ std::thread server_thread(StartServer, &initialized);
+ while (!initialized) {
+ }
+ // wait server to start
+ // sleep(2);
+ rpc_service->WaitServerReady();
+
+ f::Scope scope;
+ p::CPUPlace place;
+ p::DeviceContextPool& pool = p::DeviceContextPool::Instance();
+ auto& dev_ctx = *pool.Get(p::CPUPlace());
+
+ auto var = scope.Var(NCCL_ID_VARNAME);
+ // var->SetType(f::proto::VarType_Type_RAW);
+ auto id = var->GetMutable();
+ p::dynload::ncclGetUniqueId(id);
+
+ int port = rpc_service->GetSelectedPort();
+ std::string ep = string::Sprintf("127.0.0.1:%d", port);
+ detail::RPCClient client;
+
+ client.AsyncSendVariable(ep, dev_ctx, scope, NCCL_ID_VARNAME);
+ client.Wait();
+ server_thread.join();
+ auto* ptr = rpc_service.release();
+ delete ptr;
+}
diff --git a/paddle/fluid/platform/nccl_helper.h b/paddle/fluid/platform/nccl_helper.h
index ec1682a44e2d9ecdf5ff7e6969f84f79254f86a7..09367889a9517956ad01ad2847c31e2633cc643d 100644
--- a/paddle/fluid/platform/nccl_helper.h
+++ b/paddle/fluid/platform/nccl_helper.h
@@ -14,12 +14,15 @@
#pragma once
+#include
#include // NOLINT
#include
#include
#include "paddle/fluid/platform/dynload/nccl.h"
#include "paddle/fluid/platform/enforce.h"
+#define NCCL_ID_VARNAME "NCCLID"
+
namespace paddle {
namespace platform {
@@ -73,7 +76,9 @@ struct NCCLContextMap {
std::unordered_map contexts_;
std::vector order_;
- explicit NCCLContextMap(const std::vector &places) {
+ explicit NCCLContextMap(const std::vector &places,
+ ncclUniqueId *nccl_id = nullptr,
+ size_t num_trainers = 1, size_t trainer_id = 0) {
PADDLE_ENFORCE(!places.empty());
order_.reserve(places.size());
for (auto &p : places) {
@@ -85,18 +90,34 @@ struct NCCLContextMap {
order_.size(), contexts_.size(),
"NCCL Context Map does not support contain two or more same device");
- if (places.size() > 1) {
- std::unique_ptr comms(new ncclComm_t[order_.size()]);
+ if (places.size() <= 1) {
+ return;
+ }
+ std::unique_ptr comms(new ncclComm_t[order_.size()]);
+ // A non-null nccl_id means we are doing multi-node training.
+ if (nccl_id == nullptr) {
+ std::lock_guard guard(NCCLGroupGuard::NCCLMutex());
+ PADDLE_ENFORCE(platform::dynload::ncclCommInitAll(
+ comms.get(), static_cast(order_.size()), order_.data()));
+ } else {
+ PADDLE_ENFORCE_GT(num_trainers, 1);
+ // TODO(wuyi): need to ensure each node has the same number of GPUs
{
- std::lock_guard guard(NCCLGroupGuard::NCCLMutex());
- PADDLE_ENFORCE(platform::dynload::ncclCommInitAll(
- comms.get(), static_cast(order_.size()), order_.data()));
- }
- int i = 0;
- for (auto &dev_id : order_) {
- contexts_.at(dev_id).comm_ = comms[i++];
+ int nranks = num_trainers * order_.size();
+ NCCLGroupGuard gurad;
+ for (auto &gpu_id : order_) {
+ int rank = trainer_id * order_.size() + gpu_id;
+ VLOG(3) << "init nccl rank: " << rank << " nranks: " << nranks;
+ PADDLE_ENFORCE(cudaSetDevice(gpu_id));
+ PADDLE_ENFORCE(platform::dynload::ncclCommInitRank(
+ comms.get() + gpu_id, nranks, *nccl_id, rank));
+ }
}
}
+ int i = 0;
+ for (auto &dev_id : order_) {
+ contexts_.at(dev_id).comm_ = comms[i++];
+ }
}
NCCLContextMap(const NCCLContextMap &other) = delete;
diff --git a/paddle/fluid/pybind/pybind.cc b/paddle/fluid/pybind/pybind.cc
index ee2c5b904499ae3bd94ffe96da7100cdf53865ca..50a1c07251b5bc4e7cc27de63f5457d3f94daef5 100644
--- a/paddle/fluid/pybind/pybind.cc
+++ b/paddle/fluid/pybind/pybind.cc
@@ -547,7 +547,8 @@ All parameter, weight, gradient are variables in Paddle.
const std::unordered_set &,
const std::unordered_set &, const ProgramDesc &,
const std::string &, Scope *, std::vector &,
- const ExecutionStrategy &, const BuildStrategy &>())
+ const ExecutionStrategy &, const BuildStrategy &, size_t,
+ size_t>())
.def("bcast_params", &ParallelExecutor::BCastParamsToGPUs)
// NOTE: even we return a vec* to Python use reference policy.
// We still cannot get local_scope from this vector, since the element
diff --git a/python/paddle/fluid/framework.py b/python/paddle/fluid/framework.py
index 28e54f5492e7b04a1406e319cecf977d4a55725e..38c765938fe9d7b2103bfdd926874c485d0ff4dc 100644
--- a/python/paddle/fluid/framework.py
+++ b/python/paddle/fluid/framework.py
@@ -489,7 +489,7 @@ class Operator(object):
'rnn_memory_helper_grad', 'conditional_block', 'while', 'send',
'recv', 'listen_and_serv', 'parallel_do', 'save_combine',
'load_combine', 'ncclInit', 'channel_create', 'channel_close',
- 'channel_send', 'channel_recv', 'select'
+ 'channel_send', 'channel_recv', 'select', 'gen_nccl_id'
}
if type not in no_kernel_op_set:
self.desc.infer_var_type(self.block.desc)
diff --git a/python/paddle/fluid/parallel_executor.py b/python/paddle/fluid/parallel_executor.py
index deab761f72a3ff77ea978d8d81f2105b4463b4e0..3117dfe00c7a3df1035c439dc31b81e67781d0cc 100644
--- a/python/paddle/fluid/parallel_executor.py
+++ b/python/paddle/fluid/parallel_executor.py
@@ -33,6 +33,8 @@ class ParallelExecutor(object):
share_vars_from=None,
exec_strategy=None,
build_strategy=None,
+ num_trainers=1,
+ trainer_id=0,
**kwargs):
"""
ParallelExecutor can run program in parallel.
@@ -44,14 +46,11 @@ class ParallelExecutor(object):
if not provided, then default_main_program will be used.
share_vars_from(ParallelExecutor, default None): If provied,
it will share variables from the specified ParallelExecutor.
- use_default_grad_scale(bool, default True): If set True, a default
- scale value equal to `1./device_count` would be multiplied to
- gradients of each device and scaled gradients would be
- aggregated. Otherwise, a customized scale value should be fed
- to the network.
- balance_parameter_opt_between_cards(bool, default True): Whether
- updating different gradients on different cards. Currently, it
- is not recommended.
+ num_trainers(int, default 1): If greater than 1, NCCL will be
+ initialized with multiple ranks across nodes; each node should
+ have the same number of GPUs. Distributed training is then enabled.
+ trainer_id(int, default 0): Must be used together with num_trainers.
+ trainer_id is the "rank" of the current node, starting from 0.
Returns:
A ParallelExecutor object.
@@ -151,9 +150,9 @@ class ParallelExecutor(object):
p.name for p in main.global_block().iter_parameters()
if not p.stop_gradient
]),
- set(self.persistable_vars), main.desc, loss_name if loss_name else
- '', scope, local_scopes, exec_strategy, build_strategy)
-
+ set(self.persistable_vars), main.desc, loss_name
+ if loss_name else '', scope, local_scopes, exec_strategy,
+ build_strategy, num_trainers, trainer_id)
self.scope = scope
def run(self, fetch_list, feed=None, feed_dict=None):
diff --git a/python/paddle/fluid/tests/book/high-level-api/CMakeLists.txt b/python/paddle/fluid/tests/book/high-level-api/CMakeLists.txt
index 9ab00325a2eef3bbc79757ad1a3e6f8511c49552..c2a15bdb3b17b65fe861dd429f548074c13e2f09 100644
--- a/python/paddle/fluid/tests/book/high-level-api/CMakeLists.txt
+++ b/python/paddle/fluid/tests/book/high-level-api/CMakeLists.txt
@@ -6,4 +6,5 @@ foreach(src ${TEST_OPS})
py_test(${src} SRCS ${src}.py)
endforeach()
+add_subdirectory(fit_a_line)
add_subdirectory(recognize_digits)
diff --git a/python/paddle/fluid/tests/book/high-level-api/fit_a_line/CMakeLists.txt b/python/paddle/fluid/tests/book/high-level-api/fit_a_line/CMakeLists.txt
new file mode 100644
index 0000000000000000000000000000000000000000..673c965b662a022739f8d489c331f4de9455a926
--- /dev/null
+++ b/python/paddle/fluid/tests/book/high-level-api/fit_a_line/CMakeLists.txt
@@ -0,0 +1,7 @@
+file(GLOB TEST_OPS RELATIVE "${CMAKE_CURRENT_SOURCE_DIR}" "test_*.py")
+string(REPLACE ".py" "" TEST_OPS "${TEST_OPS}")
+
+# Register every test_*.py in this directory as a py_test target named after the file.
+foreach(src ${TEST_OPS})
+ py_test(${src} SRCS ${src}.py)
+endforeach()
diff --git a/python/paddle/fluid/tests/book/high-level-api/fit_a_line/test_fit_a_line.py b/python/paddle/fluid/tests/book/high-level-api/fit_a_line/test_fit_a_line.py
new file mode 100644
index 0000000000000000000000000000000000000000..8c9bbb52d769282460c571ebc51d5eff18de3114
--- /dev/null
+++ b/python/paddle/fluid/tests/book/high-level-api/fit_a_line/test_fit_a_line.py
@@ -0,0 +1,137 @@
+# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import paddle
+import paddle.fluid as fluid
+import contextlib
+import numpy
+import unittest
+
+# Shuffled, batched readers over the UCI housing dataset (train split first).
+BATCH_SIZE = 20
+
+train_reader = paddle.batch(
+ paddle.reader.shuffle(
+ paddle.dataset.uci_housing.train(), buf_size=500),
+ batch_size=BATCH_SIZE)
+
+test_reader = paddle.batch(  # same pipeline over the test split
+ paddle.reader.shuffle(
+ paddle.dataset.uci_housing.test(), buf_size=500),
+ batch_size=BATCH_SIZE)
+
+
+def inference_program():  # builds input 'x' (shape [13]) -> fc(size=1) and returns the prediction
+ x = fluid.layers.data(name='x', shape=[13], dtype='float32')
+ y_predict = fluid.layers.fc(input=x, size=1, act=None)  # act=None: no activation on the output
+ return y_predict
+
+
+def linear():  # training program: mean of the squared error between the prediction and label 'y'
+ y = fluid.layers.data(name='y', shape=[1], dtype='float32')
+ y_predict = inference_program()  # reuses the inference network so train and infer share one definition
+
+ loss = fluid.layers.square_error_cost(input=y_predict, label=y)
+ avg_loss = fluid.layers.mean(loss)
+
+ return avg_loss
+
+
+def train(use_cuda, save_dirname):  # trains `linear` with SGD through fluid.Trainer on train_reader
+ place = fluid.CUDAPlace(0) if use_cuda else fluid.CPUPlace()
+
+ trainer = fluid.Trainer(
+ train_func=linear,
+ infer_func=inference_program,
+ place=place,
+ optimizer=fluid.optimizer.SGD(learning_rate=0.001))
+
+ def event_handler(event):  # callback handed to trainer.train() below
+ if isinstance(event, fluid.EndEpochEvent):  # only EndEpochEvent is handled
+ test_metrics = trainer.test(
+ reader=test_reader, feed_order=['x', 'y'])
+ print test_metrics  # Python 2 print statement; sample output is shown in the string below
+ '''
+
+ ...
+ ['25.768919467926025']
+ ['15.343549569447836']
+ ...
+
+ '''
+ if float(test_metrics[0]) < 20.0:  # first test metric below 20.0: save the model (if a directory was given)
+ if save_dirname is not None:
+ # NOT clear yet
+ # fluid.io.save_inference_model(save_dirname, ['x'], [y_predict])
+ # trainer.save_params(save_dirname)
+ # https://github.com/PaddlePaddle/Paddle/pull/10445
+ trainer.save_inference_model(save_dirname)
+ return  # NOTE(review): leaves the handler only; no visible call stops trainer.train() early -- confirm
+
+ trainer.train(
+ reader=train_reader,
+ num_epochs=100,
+ event_handler=event_handler,
+ feed_order=['x', 'y'])  # feed_order lists the data layer names 'x' and 'y'
+
+
+# Load the model saved under save_dirname and run it on one random batch.
+def infer(use_cuda, save_dirname=None):
+ if save_dirname is None:  # nothing was saved, so there is nothing to load
+ return
+
+ place = fluid.CUDAPlace(0) if use_cuda else fluid.CPUPlace()
+ inferencer = fluid.Inferencer(param_path=save_dirname, place=place)
+
+ batch_size = 10
+ tensor_x = numpy.random.uniform(0, 10, [batch_size, 13]).astype("float32")  # 10 random samples x 13 features, matching input 'x'
+
+ results = inferencer.infer({'x': tensor_x})
+ print("infer results: ", results[0])
+
+
+def main(use_cuda):  # train, then run inference from the same model directory
+ if use_cuda and not fluid.core.is_compiled_with_cuda():  # skip the CUDA run on builds without CUDA
+ return
+
+ # Directory for saving the trained model
+ save_dirname = "fit_a_line.inference.model"
+
+ train(use_cuda, save_dirname)
+ infer(use_cuda, save_dirname)
+
+
+class TestFitALine(unittest.TestCase):  # runs main() on CPU and on CUDA, each under fresh programs and scope
+ def test_cpu(self):
+ with self.program_scope_guard():
+ with fluid.unique_name.guard():
+ main(use_cuda=False)
+
+ def test_cuda(self):  # main() returns early when Paddle is not compiled with CUDA
+ with self.program_scope_guard():
+ with fluid.unique_name.guard():
+ main(use_cuda=True)
+
+ @contextlib.contextmanager
+ def program_scope_guard(self):  # yields inside a new Scope and a new main/startup Program pair
+ prog = fluid.Program()
+ startup_prog = fluid.Program()
+ scope = fluid.core.Scope()
+ with fluid.scope_guard(scope):
+ with fluid.program_guard(prog, startup_prog):
+ yield
+
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/python/paddle/fluid/tests/book/notest_understand_sentiment.py b/python/paddle/fluid/tests/book/notest_understand_sentiment.py
index 241778e303036d068dc0a40e4574a02eb97ad134..792ed7368d646cd9dff9255eb402b6a9b84f69a6 100644
--- a/python/paddle/fluid/tests/book/notest_understand_sentiment.py
+++ b/python/paddle/fluid/tests/book/notest_understand_sentiment.py
@@ -170,7 +170,7 @@ def train(word_dict,
assert save_dirname is None
adagrad = fluid.optimizer.Adagrad(learning_rate=0.002)
- optimize_ops, params_grads = adagrad.minimize(cost)
+ adagrad.minimize(cost)
train_data = paddle.batch(
paddle.reader.shuffle(
diff --git a/python/paddle/fluid/tests/book/test_fit_a_line.py b/python/paddle/fluid/tests/book/test_fit_a_line.py
index ecb34699af0dc14782601702ab8afedbca7e1bfd..b1a6b524d33cae97c8982ffb8f780b1b07761a09 100644
--- a/python/paddle/fluid/tests/book/test_fit_a_line.py
+++ b/python/paddle/fluid/tests/book/test_fit_a_line.py
@@ -33,7 +33,7 @@ def train(use_cuda, save_dirname, is_local):
avg_cost = fluid.layers.mean(cost)
sgd_optimizer = fluid.optimizer.SGD(learning_rate=0.001)
- optimize_ops, params_grads = sgd_optimizer.minimize(avg_cost)
+ sgd_optimizer.minimize(avg_cost)
BATCH_SIZE = 20
diff --git a/python/paddle/fluid/tests/book/test_image_classification.py b/python/paddle/fluid/tests/book/test_image_classification.py
index dbcdb5766e7d20efdb12da0ea4c6f005d903849b..0f3a4c9242a81a3c1fb90268245715a8e59a207a 100644
--- a/python/paddle/fluid/tests/book/test_image_classification.py
+++ b/python/paddle/fluid/tests/book/test_image_classification.py
@@ -125,7 +125,7 @@ def train(net_type, use_cuda, save_dirname, is_local):
test_program = fluid.default_main_program().clone(for_test=True)
optimizer = fluid.optimizer.Adam(learning_rate=0.001)
- optimize_ops, params_grads = optimizer.minimize(avg_cost)
+ optimizer.minimize(avg_cost)
BATCH_SIZE = 128
PASS_NUM = 1
diff --git a/python/paddle/fluid/tests/book/test_label_semantic_roles.py b/python/paddle/fluid/tests/book/test_label_semantic_roles.py
index 0faba33032d5dfc0b751a5191e7b2ae0c1f172bf..09793760e5504c04ad4b0bfac5c5d7b7047cf85d 100644
--- a/python/paddle/fluid/tests/book/test_label_semantic_roles.py
+++ b/python/paddle/fluid/tests/book/test_label_semantic_roles.py
@@ -175,7 +175,7 @@ def train(use_cuda, save_dirname=None, is_local=True):
decay_steps=100000,
decay_rate=0.5,
staircase=True))
- optimize_ops, params_grads = sgd_optimizer.minimize(avg_cost)
+ sgd_optimizer.minimize(avg_cost)
# TODO(qiao)
# add dependency track and move this config before optimizer
diff --git a/python/paddle/fluid/tests/book/test_machine_translation.py b/python/paddle/fluid/tests/book/test_machine_translation.py
index 46c6b9c29a265741a99655d5ac29244798f6fec2..e8a75f473f62df528b7f39bf5f9085076e005c25 100644
--- a/python/paddle/fluid/tests/book/test_machine_translation.py
+++ b/python/paddle/fluid/tests/book/test_machine_translation.py
@@ -185,7 +185,7 @@ def train_main(use_cuda, is_sparse, is_local=True):
learning_rate=1e-4,
regularization=fluid.regularizer.L2DecayRegularizer(
regularization_coeff=0.1))
- optimize_ops, params_grads = optimizer.minimize(avg_cost)
+ optimizer.minimize(avg_cost)
train_data = paddle.batch(
paddle.reader.shuffle(
diff --git a/python/paddle/fluid/tests/book/test_recognize_digits.py b/python/paddle/fluid/tests/book/test_recognize_digits.py
index c115aa4d7d6b514f9207543730e5e76cb0d2040c..578b1162fbd7e3a1b1c0cc934406818f2e07e019 100644
--- a/python/paddle/fluid/tests/book/test_recognize_digits.py
+++ b/python/paddle/fluid/tests/book/test_recognize_digits.py
@@ -95,7 +95,7 @@ def train(nn_type,
test_program = fluid.default_main_program().clone(for_test=True)
optimizer = fluid.optimizer.Adam(learning_rate=0.001)
- optimize_ops, params_grads = optimizer.minimize(avg_loss)
+ optimizer.minimize(avg_loss)
place = fluid.CUDAPlace(0) if use_cuda else fluid.CPUPlace()
diff --git a/python/paddle/fluid/tests/book/test_recommender_system.py b/python/paddle/fluid/tests/book/test_recommender_system.py
index d022dedbff805d597b68b5a47f7931f2dd946615..7be924f762ddeb045dda890dbfdcd96a65449553 100644
--- a/python/paddle/fluid/tests/book/test_recommender_system.py
+++ b/python/paddle/fluid/tests/book/test_recommender_system.py
@@ -160,7 +160,7 @@ def train(use_cuda, save_dirname, is_local=True):
test_program = fluid.default_main_program().clone(for_test=True)
sgd_optimizer = SGDOptimizer(learning_rate=0.2)
- optimize_ops, params_grads = sgd_optimizer.minimize(avg_cost)
+ sgd_optimizer.minimize(avg_cost)
place = fluid.CUDAPlace(0) if use_cuda else fluid.CPUPlace()
diff --git a/python/paddle/fluid/tests/book/test_word2vec.py b/python/paddle/fluid/tests/book/test_word2vec.py
index 6dec0f6857e86b4b9c1c67af934aa9bfdb1c3df7..30e1a5040cc92b02bbbf90dac97001812ec90134 100644
--- a/python/paddle/fluid/tests/book/test_word2vec.py
+++ b/python/paddle/fluid/tests/book/test_word2vec.py
@@ -101,7 +101,7 @@ def train(use_cuda, is_sparse, is_parallel, save_dirname, is_local=True):
avg_cost = fluid.layers.mean(pd())
sgd_optimizer = fluid.optimizer.SGD(learning_rate=0.001)
- optimize_ops, params_grads = sgd_optimizer.minimize(avg_cost)
+ sgd_optimizer.minimize(avg_cost)
train_reader = paddle.batch(
paddle.dataset.imikolov.train(word_dict, N), BATCH_SIZE)
diff --git a/python/paddle/fluid/tests/unittests/test_dist_transpiler.py b/python/paddle/fluid/tests/unittests/test_dist_transpiler.py
new file mode 100644
index 0000000000000000000000000000000000000000..10f8c4f3f0167632bb4a3d454ab026ba73a8f305
--- /dev/null
+++ b/python/paddle/fluid/tests/unittests/test_dist_transpiler.py
@@ -0,0 +1,113 @@
+# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import unittest
+
+import paddle.fluid as fluid
+import paddle.fluid.core as core
+import paddle.fluid.layers as layers
+from paddle.fluid.transpiler.distribute_transpiler import delete_ops
+import numpy
+
+
+class TestDistTranspiler(unittest.TestCase):  # checks the op lists DistributeTranspiler produces for a small fc network
+ def setUp(self):
+ self.trainer_id = 0
+ self.trainers = 2
+ self.pservers = 2
+ self.pserver_eps = "127.0.0.1:6174,127.0.0.1:6175"
+ self.current_pserver_ep = "127.0.0.1:6174"  # the first endpoint listed in pserver_eps
+
+ def net_conf(self):  # builds x -> fc(1000) -> squared error -> mean, and adds SGD to the current program
+ x = fluid.layers.data(name='x', shape=[1000], dtype='float32')
+
+ y_predict = fluid.layers.fc(input=x,
+ size=1000,
+ act=None,
+ param_attr=fluid.ParamAttr(name='fc_w'))  # fixed parameter name so the test can look up 'fc_w.block1'
+
+ y = fluid.layers.data(name='y', shape=[1], dtype='float32')
+
+ cost = fluid.layers.square_error_cost(input=y_predict, label=y)
+ avg_cost = fluid.layers.mean(cost)
+ sgd_optimizer = fluid.optimizer.SGD(learning_rate=0.1)
+
+ optimize_ops, params_grads = sgd_optimizer.minimize(avg_cost)
+ return optimize_ops, params_grads
+
+ def test_transpiler(self):
+ trainer = self.get_trainer()
+ pserver, startup = self.get_pserver(self.current_pserver_ep)
+
+ self.assertEqual([op.type for op in trainer.global_block().ops],
+ self.get_expect_trainer_ops())
+
+ self.assertEqual(len(pserver.blocks), 3)
+ # block0: listen_and_serv
+ self.assertEqual([op.type for op in pserver.blocks[0].ops],
+ ["listen_and_serv"])
+ # block1: optimize pass
+ self.assertEqual([op.type for op in pserver.blocks[1].ops],
+ ["sum", "scale", "sgd"])
+
+ # confirm startup program
+
+ self.assertEqual([op.type for op in startup.global_block().ops], [
+ "fill_constant", "fill_constant", "uniform_random", "uniform_random"
+ ])
+
+ # the variable #fc_w will be split into two blocks
+ fc_w_var = startup.global_block().var("fc_w.block1")
+ self.assertEqual(fc_w_var.shape, (500, 1000))  # expected shape of one block of the 1000-input, size-1000 fc weight
+
+ def get_main_program(self):  # a fresh Program containing the network from net_conf()
+ main = fluid.Program()
+
+ with fluid.program_guard(main):
+ self.net_conf()
+
+ return main
+
+ def get_expect_trainer_ops(self):  # op types of the net without its optimize ops, plus the three op types appended below
+ trainer = fluid.Program()
+
+ with fluid.program_guard(trainer):
+ optimize_ops, params_grads = self.net_conf()
+
+ delete_ops(trainer.global_block(), optimize_ops)
+ return [op.type for op in trainer.global_block().ops
+ ] + ["split_byref", "send", "concat"]
+
+ def get_trainer(self):
+ return self._transpiler_instance().get_trainer_program()
+
+ def get_pserver(self, ep):  # returns (pserver program, startup program) for endpoint `ep`
+ t = self._transpiler_instance()
+ pserver = t.get_pserver_program(ep)
+ startup = t.get_startup_program(ep, pserver)
+ return pserver, startup
+
+ def _transpiler_instance(self):  # each call builds a new main program and transpiles it
+ main = self.get_main_program()
+ t = fluid.DistributeTranspiler()
+ t.transpile(
+ self.trainer_id,
+ program=main,
+ pservers=self.pserver_eps,
+ trainers=self.trainers)
+ return t
+
+
+if __name__ == "__main__":
+ unittest.main()
diff --git a/python/paddle/fluid/transpiler/__init__.py b/python/paddle/fluid/transpiler/__init__.py
index 6d3c1b947f4acb1335b25e6eb0099d5d532c895a..413c36c5c41bbe0169f1c050ccdac040202d66df 100644
--- a/python/paddle/fluid/transpiler/__init__.py
+++ b/python/paddle/fluid/transpiler/__init__.py
@@ -11,6 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+
from distribute_transpiler import DistributeTranspiler
from inference_transpiler import InferenceTranspiler
from memory_optimization_transpiler import memory_optimize, release_memory
diff --git a/python/paddle/fluid/transpiler/distribute_transpiler.py b/python/paddle/fluid/transpiler/distribute_transpiler.py
index b45cb987d896bd189531e97eb62bddbbee16069d..a323f8d03613e7c4149812daab4ccb57fb940a7e 100644
--- a/python/paddle/fluid/transpiler/distribute_transpiler.py
+++ b/python/paddle/fluid/transpiler/distribute_transpiler.py
@@ -17,7 +17,7 @@ from __future__ import print_function
import math
import distributed_splitter as splitter
-from .. import core
+from .. import core, framework
from ..framework import Program, default_main_program, \
default_startup_program, \
Variable, Parameter, grad_var_name
@@ -417,7 +417,7 @@ class DistributeTranspiler:
def __append_optimize_op__(op, block, grad_to_block_id):
if self._is_opt_op(op):
self._append_pserver_ops(block, op, endpoint, grad_to_block_id,
- default_main_program())
+ self.origin_program)
else:
self._append_pserver_non_opt_ops(block, op)
diff --git a/tools/manylinux1/README.md b/tools/manylinux1/README.md
index 898e00bd37c7b7bcbcb4a56476ff10c87381e47a..0e5905040175047f5b79939d97a3efcf38992944 100644
--- a/tools/manylinux1/README.md
+++ b/tools/manylinux1/README.md
@@ -28,3 +28,38 @@ git clone https://github.com/paddlepaddle/paddle
cd paddle/tools/manylinux1
REPO=[yourrepo] ./build_all.sh
```
+
+## Build PaddlePaddle for the different Python ABIs
+
+Choose one of the following Python ABIs and set the corresponding environment variables.
+
+- cp27-cp27m
+
+ ```bash
+ export LD_LIBRARY_PATH=/opt/_internal/cpython-2.7.11-ucs2/lib:${LD_LIBRARY_PATH#/opt/_internal/cpython-2.7.11-ucs4/lib:}
+ export PATH=/opt/python/cp27-cp27m/bin/:${PATH}
+ export PYTHON_FLAGS="-DPYTHON_EXECUTABLE:FILEPATH=/opt/python/cp27-cp27m/bin/python
+ -DPYTHON_INCLUDE_DIR:PATH=/opt/python/cp27-cp27m/include/python2.7
+ -DPYTHON_LIBRARIES:FILEPATH=/opt/_internal/cpython-2.7.11-ucs2/lib/libpython2.7.so"
+ ```
+
+- cp27-cp27mu
+
+ ```bash
+ export LD_LIBRARY_PATH=/opt/_internal/cpython-2.7.11-ucs4/lib:${LD_LIBRARY_PATH#/opt/_internal/cpython-2.7.11-ucs2/lib:}
+ export PATH=/opt/python/cp27-cp27mu/bin/:${PATH}
+ export PYTHON_FLAGS="-DPYTHON_EXECUTABLE:FILEPATH=/opt/python/cp27-cp27mu/bin/python
+ -DPYTHON_INCLUDE_DIR:PATH=/opt/python/cp27-cp27mu/include/python2.7
+ -DPYTHON_LIBRARIES:FILEPATH=/opt/_internal/cpython-2.7.11-ucs4/lib/libpython2.7.so"
+ ```
+
+And then add the `PYTHON_FLAGS` as your cmake flags:
+
+```bash
+cmake .. \
+ ${PYTHON_FLAGS} \
+ -DWITH_GPU=OFF \
+ ...
+```
+
+You can find more details about the cmake flags [here](http://www.paddlepaddle.org/docs/develop/documentation/fluid/en/build_and_install/build_from_source_en.html#appendix-build-options).