diff --git a/CMakeLists.txt b/CMakeLists.txt
index 23bbe829ac16180088bfa37df66e23f19b021ea3..030bd19b3fd2f561a847bbc4613e5d2030812a92 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -25,7 +25,6 @@ message(STATUS "CXX compiler: ${CMAKE_CXX_COMPILER}, version: "
message(STATUS "C compiler: ${CMAKE_C_COMPILER}, version: "
"${CMAKE_C_COMPILER_ID} ${CMAKE_C_COMPILER_VERSION}")
-find_package(Sphinx)
if(NOT CMAKE_CROSSCOMPILING)
find_package(CUDA QUIET)
endif(NOT CMAKE_CROSSCOMPILING)
@@ -226,5 +225,7 @@ if(WITH_PYTHON)
endif()
if(WITH_DOC)
+ find_package(Sphinx REQUIRED)
+ find_python_module(recommonmark REQUIRED)
add_subdirectory(doc)
endif()
diff --git a/cmake/external/mkldnn.cmake b/cmake/external/mkldnn.cmake
index 966c0bafd3742862e66d7ff36de86190507a6936..0332e39d14200da1c1af52675f0ccad2c07de405 100644
--- a/cmake/external/mkldnn.cmake
+++ b/cmake/external/mkldnn.cmake
@@ -56,6 +56,8 @@ ExternalProject_Add(
GIT_TAG "v0.14"
PREFIX ${MKLDNN_SOURCES_DIR}
UPDATE_COMMAND ""
+ # Patch MKLDNN to compile with gcc 4.8, the related issue is in intel/mkl-dnn#237.
+ PATCH_COMMAND ${CMAKE_COMMAND} -E copy_if_different ${CMAKE_CURRENT_SOURCE_DIR}/patches/mkldnn.hpp ${MKLDNN_SOURCES_DIR}/src/extern_mkldnn/include/mkldnn.hpp
CMAKE_ARGS -DCMAKE_INSTALL_PREFIX=${MKLDNN_INSTALL_DIR}
CMAKE_ARGS -DCMAKE_BUILD_TYPE=${CMAKE_BUILD_TYPE}
CMAKE_ARGS -DMKLROOT=${MKLML_ROOT}
diff --git a/doc/v2/build_and_install/build_from_source_cn.rst b/doc/v2/build_and_install/build_from_source_cn.rst
index 115b92a33888abf1e1be400e1abbb58b632a2976..f846928954dd3a05e11054ce2ff2ff839fbefd4b 100644
--- a/doc/v2/build_and_install/build_from_source_cn.rst
+++ b/doc/v2/build_and_install/build_from_source_cn.rst
@@ -19,8 +19,9 @@
----------------
PaddlePaddle需要使用Docker环境完成编译,这样可以免去单独安装编译依赖的步骤,可选的不同编译环境Docker镜像
-可以在 `这里 `_ 找到。或者
-参考下述可选步骤,从源码中构建用于编译PaddlePaddle的Docker镜像。
+可以在 `这里 `_ 找到,您也可以
+在 `这里 `_ 找到 paddle_manylinux_devel
+镜像的编译以及使用方法。或者参考下述可选步骤,从源码中构建用于编译PaddlePaddle的Docker镜像。
如果您选择不使用Docker镜像,则需要在本机安装下面章节列出的 `编译依赖`_ 之后才能开始编译的步骤。
diff --git a/doc/v2/build_and_install/build_from_source_en.rst b/doc/v2/build_and_install/build_from_source_en.rst
index 8fef9e7347e8d924026999bfda985381750c6b51..d1b5b88dff81d4c5cee3dd13a7dccbc333ab6a17 100644
--- a/doc/v2/build_and_install/build_from_source_en.rst
+++ b/doc/v2/build_and_install/build_from_source_en.rst
@@ -22,6 +22,8 @@ How To Build
You need to use Docker to build PaddlePaddle
to avoid installing dependencies by yourself. We have several pre-built
Docker images `here `_ ,
+you can also find how to build and use the paddle_manylinux_devel Docker image
+from `here `_ .
Or you can build your own image from source as the optional step below:
.. code-block:: bash
diff --git a/paddle/fluid/framework/operator.h b/paddle/fluid/framework/operator.h
index d373c48b1a75c5f75c7520b56f230bc2c146b174..a4eb6f706edab9479cbce436311eb96da8845646 100644
--- a/paddle/fluid/framework/operator.h
+++ b/paddle/fluid/framework/operator.h
@@ -192,6 +192,10 @@ class ExecutionContext {
return op_.Attr(name);
}
+ bool HasInput(const std::string& name) const { return op_.HasInputs(name); }
+
+ bool HasOutput(const std::string& name) const { return op_.HasOutputs(name); }
+
size_t InputSize(const std::string& name) const {
return op_.Inputs(name).size();
}
diff --git a/paddle/fluid/framework/parallel_executor.cc b/paddle/fluid/framework/parallel_executor.cc
index 20ef7e09f630140c44774147aa727780df6333fa..95e807c0afa45bc4f4feb84d450b2d0584bc3b28 100644
--- a/paddle/fluid/framework/parallel_executor.cc
+++ b/paddle/fluid/framework/parallel_executor.cc
@@ -58,7 +58,8 @@ ParallelExecutor::ParallelExecutor(
const std::unordered_set &bcast_vars,
const ProgramDesc &main_program, const std::string &loss_var_name,
Scope *scope, const std::vector &local_scopes, bool allow_op_delay,
- bool use_default_grad_scale, bool balance_parameter_opt_between_cards)
+ bool use_default_grad_scale, bool balance_parameter_opt_between_cards,
+ size_t num_trainers, size_t trainer_id)
: member_(new ParallelExecutorPrivate(places)) {
member_->global_scope_ = scope;
@@ -80,7 +81,13 @@ ParallelExecutor::ParallelExecutor(
// Bcast Parameters to all GPUs
#ifdef PADDLE_WITH_CUDA
- member_->nccl_ctxs_.reset(new platform::NCCLContextMap(member_->places_));
+ auto *nccl_id_var = scope->FindVar(NCCL_ID_VARNAME);
+ ncclUniqueId *nccl_id = nullptr;
+ if (nccl_id_var != nullptr) {
+ nccl_id = nccl_id_var->GetMutable();
+ }
+ member_->nccl_ctxs_.reset(new platform::NCCLContextMap(
+ member_->places_, nccl_id, num_trainers, trainer_id));
#endif
if (platform::is_gpu_place(places[0]) && member_->local_scopes_.size() != 1 &&
local_scopes.empty()) { // Is CUDA
diff --git a/paddle/fluid/framework/parallel_executor.h b/paddle/fluid/framework/parallel_executor.h
index b251fc91417a1c00e61e9c3c952460e6268d2819..9e279876cfeef20a1921f8bd1c27046a477b9f56 100644
--- a/paddle/fluid/framework/parallel_executor.h
+++ b/paddle/fluid/framework/parallel_executor.h
@@ -41,7 +41,8 @@ class ParallelExecutor {
const std::string& loss_var_name, Scope* scope,
const std::vector& local_scopes,
bool allow_op_delay, bool use_default_grad_scale,
- bool balance_parameter_opt_between_cards);
+ bool balance_parameter_opt_between_cards,
+ size_t num_trainers = 1, size_t trainer_id = 0);
~ParallelExecutor();
diff --git a/paddle/fluid/inference/analysis/CMakeLists.txt b/paddle/fluid/inference/analysis/CMakeLists.txt
index de7becae4d25d48111fea8d2123bc85aef70230a..47929ef7490e5edb246625cb0b3ba507039df27a 100644
--- a/paddle/fluid/inference/analysis/CMakeLists.txt
+++ b/paddle/fluid/inference/analysis/CMakeLists.txt
@@ -1 +1,2 @@
-cc_library(dot SRCS dot.cc)
+cc_library(analysis SRCS dot.cc node.cc node.h)
+cc_test(test_node SRCS node_tester.cc DEPS analysis)
diff --git a/paddle/fluid/inference/analysis/device.h b/paddle/fluid/inference/analysis/device.h
new file mode 100644
index 0000000000000000000000000000000000000000..4423af842d28566fea419b8099efc3bda33787f4
--- /dev/null
+++ b/paddle/fluid/inference/analysis/device.h
@@ -0,0 +1,23 @@
+/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+namespace paddle {
+namespace inference {
+namespace analysis {
+
+enum class Device { CPU, GPU };
+
+} // namespace analysis
+} // namespace inference
+} // namespace paddle
diff --git a/paddle/fluid/inference/analysis/dot.h b/paddle/fluid/inference/analysis/dot.h
index 3359987874f2d74d7e4646baa38790431c4b28fd..4bf1840fdda8508b52d7274a338c5b1c95baf354 100644
--- a/paddle/fluid/inference/analysis/dot.h
+++ b/paddle/fluid/inference/analysis/dot.h
@@ -21,6 +21,7 @@
#include
#include
+#include
#include
#include
diff --git a/paddle/fluid/inference/analysis/dot_tester.cc b/paddle/fluid/inference/analysis/dot_tester.cc
new file mode 100644
index 0000000000000000000000000000000000000000..56ceb9bd5d6f41a601d66f6124fb7b4099c9337e
--- /dev/null
+++ b/paddle/fluid/inference/analysis/dot_tester.cc
@@ -0,0 +1,62 @@
+// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "paddle/fluid/inference/analysis/dot.h"
+
+#include
+#include
+#include "paddle/fluid/inference/analysis/data_flow_graph.h"
+
+namespace paddle {
+namespace inference {
+namespace analysis {
+
+class DotTester : public ::testing::Test {
+ protected:
+ void SetUp() override {
+ std::vector attrs({{"title", "hello"}});
+ dot.reset(new Dot(attrs));
+ dot->AddNode("a", {Dot::Attr{"shape", "box"}, Dot::Attr("color", "blue")});
+ dot->AddNode("b", {});
+ dot->AddNode("c", {});
+ dot->AddEdge("a", "b", {});
+ dot->AddEdge("b", "c", {});
+ dot->AddEdge("a", "c", {});
+ }
+
+ std::unique_ptr dot;
+};
+
+TEST_F(DotTester, Build) {
+ auto codes = dot->Build();
+ // Output the DOT language code, the generated codes are too long to compare
+ // the string.
+ //
+ // The output is
+ //
+ // digraph G {
+ // title="hello"
+ // node_1
+ // node_2
+ // node_0[label="a" shape="box" color="blue"]
+ // node_0->node_1
+ // node_1->node_2
+ // node_0->node_2
+ // } // end G
+ LOG(INFO) << '\n' << codes;
+}
+
+} // namespace analysis
+} // namespace inference
+} // namespace paddle
diff --git a/paddle/fluid/inference/analysis/helper.h b/paddle/fluid/inference/analysis/helper.h
new file mode 100644
index 0000000000000000000000000000000000000000..b2d06c5d63ff139186710cd963e07b4ba245f9f3
--- /dev/null
+++ b/paddle/fluid/inference/analysis/helper.h
@@ -0,0 +1,74 @@
+/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License. */
+
+#pragma once
+
+#include
+#include
+#include
+
+#include "paddle/fluid/platform/enforce.h"
+
+namespace paddle {
+namespace inference {
+namespace analysis {
+
+template
+class iterator_range {
+ IteratorT begin_, end_;
+
+ public:
+ template
+ explicit iterator_range(Container &&c) : begin_(c.begin()), end_(c.end()) {}
+
+ iterator_range(const IteratorT &begin, const IteratorT &end)
+ : begin_(begin), end_(end) {}
+
+ const IteratorT &begin() const { return begin_; }
+ const IteratorT &end() const { return end_; }
+};
+
+/*
+ * A registry helper class that keeps its records in registration order.
+ */
+template
+class OrderedRegistry {
+ public:
+ T *Register(const std::string &name, T *x) {
+ PADDLE_ENFORCE(!dic_.count(name));
+ dic_[name] = data_.size();
+ data_.emplace_back(std::unique_ptr(x));
+ return data_.back().get();
+ }
+
+ T *Lookup(const std::string &name) {
+ auto it = dic_.find(name);
+ if (it == dic_.end()) return nullptr;
+ return data_[it->second].get();
+ }
+
+ protected:
+ std::unordered_map dic_;
+ std::vector> data_;
+};
+
+} // namespace analysis
+} // namespace inference
+} // namespace paddle
+
+#define PADDLE_DISALLOW_COPY_AND_ASSIGN(type__) \
+ \
+ type__(const type__ &) = delete; \
+ \
+ void operator=(const type__ &) = delete;
diff --git a/paddle/fluid/inference/analysis/node.cc b/paddle/fluid/inference/analysis/node.cc
new file mode 100644
index 0000000000000000000000000000000000000000..fe060526080b1ee01aa98f2ff06fb2191eddf9da
--- /dev/null
+++ b/paddle/fluid/inference/analysis/node.cc
@@ -0,0 +1,67 @@
+/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#include "paddle/fluid/inference/analysis/node.h"
+#include "glog/logging.h"
+#include "paddle/fluid/platform/enforce.h"
+
+namespace paddle {
+namespace inference {
+namespace analysis {
+
+std::vector Value::dot_attrs() const {
+ return std::vector({Dot::Attr("style", "filled,rounded"),
+ Dot::Attr("shape", "box"),
+ Dot::Attr("fillcolor", "red")});
+}
+
+std::vector Function::dot_attrs() const {
+ return std::vector({Dot::Attr("style", "filled,rounded"),
+ Dot::Attr("shape", "diamond"),
+ Dot::Attr("fillcolor", "yellow")});
+}
+
+Node *NodeMap::Create(Node::Type type) {
+ switch (type) {
+ case Node::Type::kFunction:
+ nodes_.emplace_back(new Function);
+ break;
+ case Node::Type::kValue:
+ nodes_.emplace_back(new Value);
+ break;
+ default:
+ PADDLE_THROW("Not supported node type.");
+ }
+ nodes_.back()->id_ = size() - 1;
+ return nodes_.back().get();
+}
+
+Node *NodeMap::GetMutable(size_t id) {
+ PADDLE_ENFORCE_GT(size(), id);
+ return nodes_[id].get();
+}
+
+const Node &NodeMap::Get(size_t id) const {
+ PADDLE_ENFORCE_GT(size(), id);
+ return *nodes_[id].get();
+}
+
+void NodeMap::Delete(size_t id) {
+ PADDLE_ENFORCE_LT(id, size());
+ nodes_[id]->SetDeleted();
+}
+
+} // namespace analysis
+} // namespace inference
+} // namespace paddle
diff --git a/paddle/fluid/inference/analysis/node.h b/paddle/fluid/inference/analysis/node.h
new file mode 100644
index 0000000000000000000000000000000000000000..59ba977798481684114d1189056be00bbb7777cf
--- /dev/null
+++ b/paddle/fluid/inference/analysis/node.h
@@ -0,0 +1,234 @@
+/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+/*
+ * This file defines the Node class and its subclasses. A Node is the basic
+ * analysis element in a computation graph.
+ * There are basically two kinds of nodes, the function node and value node.
+ */
+#pragma once
+
+#include
+#include
+#include
+#include
+
+#include "paddle/fluid/inference/analysis/device.h"
+#include "paddle/fluid/inference/analysis/dot.h"
+#include "paddle/fluid/inference/analysis/helper.h"
+
+namespace paddle {
+namespace inference {
+namespace analysis {
+
+class NodeMap;
+
+/*
+ * Node Representation.
+ *
+ * This is a very important class for analysis. It is the base class of all
+ * nodes computed by a program that may be used as operands to other nodes.
+ * Node is the super class of other important classes such as Function and
+ * Value, some nodes can have a name.
+ */
+class Node {
+ public:
+ // Node type. NOTE the new node types should add here.
+ enum class Type { kNone = -1, kFunction, kValue, kFunctionBlock };
+
+ Node() = default;
+
+ struct Attr;
+
+ // Cast to a subclass type, Function for example.
+ template
+ Subclass &As() {
+ return *dynamic_cast(this);
+ }
+
+ // Formatted representation of this Node.
+ virtual std::string repr() const {
+ return name() + "(" + std::to_string(id()) + ")";
+ }
+
+ // DOT node representation. One Node type can customize its own node
+ // representation.
+ virtual std::vector dot_attrs() const {
+ return std::vector({Dot::Attr("style", "filled")});
+ }
+
+ // Get an additional attribute and convert it to the T data type. NOTE this
+ // will silently create a new attribute if it does not exist.
+ Attr &attr(const std::string &name) { return attrs_[name]; }
+
+ int id() const { return id_; }
+
+ bool deleted() const { return deleted_; }
+ void SetDeleted() { deleted_ = true; }
+
+ void SetName(const std::string &name) { name_ = name; }
+ const std::string &name() const { return name_; }
+
+ void SetType(Type type) { type_ = type; }
+ Type type() const { return type_; }
+
+ void *extra_info() const { return extra_info_; }
+ void SetExtraInfo(void *extra_info) { extra_info_ = extra_info; }
+
+ // Input links.
+ std::vector inlinks;
+ // Output links.
+ std::vector outlinks;
+
+ // A helper class to maintain the status from Pass.
+ // TODO(superjomn) add a checker here to ensure the T is primary.
+ struct Attr {
+ // NOTE T should be a primary type or a struct combined by several primary
+ // types.
+ // NOTE the STL containers should not use here.
+ // Some usages
+ // Attr attr;
+ // T data;
+ // attr.data.assign((char*)data, sizeof(data));
+
+ bool &Bool() { return As(); }
+ float &Float() { return As(); }
+ int32_t &Int32() { return As(); }
+ int64_t &Int64() { return As(); }
+
+ private:
+ template
+ T &As() {
+ // init storage in the first usage.
+ if (data_.empty()) {
+ VLOG(4) << "resize data to " << sizeof(T);
+ type_hash_ = typeid(T).hash_code();
+ data_.resize(sizeof(T));
+ }
+ PADDLE_ENFORCE(type_hash_ == typeid(T).hash_code(), "type not matched");
+ PADDLE_ENFORCE_EQ(data_.size(), sizeof(T), "Node attr type recast error");
+ return *reinterpret_cast(&data_[0]);
+ }
+
+ private:
+ std::string data_;
+ size_t type_hash_{std::numeric_limits::max()};
+ };
+
+ virtual ~Node() {}
+
+ friend class NodeMap;
+
+ PADDLE_DISALLOW_COPY_AND_ASSIGN(Node);
+
+ protected:
+ // The id number not the name is a node's unique identifier in the computation
+ // graph.
+ int id_{-1};
+ std::string name_;
+ Type type_{Type::kNone};
+ // Mark this node is deleted by some pass.
+ bool deleted_{false};
+
+ void *extra_info_;
+
+ mutable std::unordered_map attrs_;
+};
+
+class Function;
+/*
+ * Value represents a value node, it has some attributes including dims, data
+ * type and so on.
+ */
+class Value : public Node {
+ public:
+ enum class DataType { kInt32, kInt64, kFloat32, kFloat64 };
+ using Dims = std::vector;
+
+ void SetDataType(DataType data_type) { data_type_ = data_type; }
+ DataType data_type() const { return data_type_; }
+
+ void SetDims(const Dims &dims) { dims_ = dims; }
+ const Dims &dims() const { return dims_; }
+
+ Device device() const { return device_; }
+ void SetDevice(Device device) { device_ = device; }
+
+ std::vector dot_attrs() const override;
+
+ PADDLE_DISALLOW_COPY_AND_ASSIGN(Value);
+
+ protected:
+ Value() { SetType(Node::Type::kValue); }
+ friend class NodeMap;
+
+ private:
+ DataType data_type_;
+ Dims dims_;
+ Device device_;
+};
+
+/*
+ * Function represents any kind of executable concepts that takes several Values
+ * as input, and outputs several Values.
+ */
+class Function : public Node {
+ public:
+ std::vector dot_attrs() const override;
+
+ // Get the operator's type from Desc.
+ const std::string &func_type() const { return func_type_; }
+ // Set the operator's type.
+ void SetFuncType(const std::string &func_type) { func_type_ = func_type; }
+
+ PADDLE_DISALLOW_COPY_AND_ASSIGN(Function);
+
+ protected:
+ std::string func_type_;
+ Function() { SetType(Node::Type::kFunction); }
+ friend class NodeMap;
+};
+
+/*
+ * FunctionBlock is a Node that contains a sub-graph of multiple Nodes.
+ */
+struct FunctionBlock : public Node {
+ std::string repr() const override { return "block-" + std::to_string(id()); }
+ std::vector subgraph;
+};
+
+class NodeMap {
+ public:
+ // Create a new node with type.
+ Node *Create(Node::Type type);
+
+ // Get a node by its id.
+ Node *GetMutable(size_t id);
+
+ const Node &Get(size_t id) const;
+
+ void Delete(size_t id);
+
+ const std::vector> &nodes() { return nodes_; }
+
+ size_t size() const { return nodes_.size(); }
+
+ private:
+ std::vector> nodes_;
+ std::unordered_map map_;
+};
+
+} // namespace analysis
+} // namespace inference
+} // namespace paddle
diff --git a/paddle/fluid/inference/analysis/node_tester.cc b/paddle/fluid/inference/analysis/node_tester.cc
new file mode 100644
index 0000000000000000000000000000000000000000..47fea0fdff808c930ca73edb25f5b16fef397e9a
--- /dev/null
+++ b/paddle/fluid/inference/analysis/node_tester.cc
@@ -0,0 +1,34 @@
+/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License. */
+
+#include "paddle/fluid/inference/analysis/node.h"
+
+#include
+
+namespace paddle {
+namespace inference {
+namespace analysis {
+
+TEST(Node, Attr) {
+ // Node is an abstract class, use Value instead for they share the same Attr
+ // logic.
+ NodeMap nodes;
+ auto* node = nodes.Create(Node::Type::kValue);
+ node->attr("v0").Int32() = 2008;
+ ASSERT_EQ(node->attr("v0").Int32(), 2008);
+}
+
+} // namespace analysis
+} // namespace inference
+} // namespace paddle
diff --git a/paddle/fluid/inference/engine.h b/paddle/fluid/inference/engine.h
index de0375551e16ec53b90414c7446234fda98bf706..ce2b8161715a3fa2278ce950dbac82c6d0042bef 100644
--- a/paddle/fluid/inference/engine.h
+++ b/paddle/fluid/inference/engine.h
@@ -14,6 +14,7 @@ limitations under the License. */
#pragma once
+#include
#include "paddle/fluid/framework/framework.pb.h"
namespace paddle {
@@ -58,8 +59,8 @@ class EngineBase {
struct Buffer {
void* buffer{nullptr}; // buffer should be allocated only once.
- int max_size; // buffer allocated space.
- int size; // data size.
+ size_t max_size; // buffer allocated space.
+ size_t size; // data size.
DeviceType device{DeviceType::UNK}; // tells which device this buffer is on.
};
diff --git a/paddle/fluid/inference/tensorrt/CMakeLists.txt b/paddle/fluid/inference/tensorrt/CMakeLists.txt
index 677b3e04af8e7f5662a15fb32e3b03f45d262733..b52d083f280e5e7713600a7b748dedd37aca0a1e 100644
--- a/paddle/fluid/inference/tensorrt/CMakeLists.txt
+++ b/paddle/fluid/inference/tensorrt/CMakeLists.txt
@@ -1,5 +1,4 @@
nv_library(tensorrt_engine SRCS engine.cc DEPS framework_proto)
nv_test(test_tensorrt SRCS test_tensorrt.cc DEPS dynload_cuda device_context dynamic_loader)
nv_test(test_tensorrt_engine SRCS test_engine.cc DEPS dynload_cuda tensorrt_engine)
-
add_subdirectory(convert)
diff --git a/paddle/fluid/inference/tensorrt/convert/CMakeLists.txt b/paddle/fluid/inference/tensorrt/convert/CMakeLists.txt
index 5178c54c08400125d190078dac6c52d021f8488b..4fb4511d99179e4ea14cde66feb13bc9e114581a 100644
--- a/paddle/fluid/inference/tensorrt/convert/CMakeLists.txt
+++ b/paddle/fluid/inference/tensorrt/convert/CMakeLists.txt
@@ -1,4 +1,4 @@
nv_test(test_op_converter SRCS test_op_converter.cc mul_op.cc conv2d_op.cc DEPS ${FLUID_CORE_MODULES})
-nv_test(test_trt_activation_op SRCS test_activation_op.cc activation_op.cc
+nv_test(test_trt_activation_op SRCS test_activation_op.cc activation_op.cc io_converter.cc
DEPS ${FLUID_CORE_MODULES} activation_op tensorrt_engine)
nv_test(test_io_converter SRCS test_io_converter.cc io_converter.cc DEPS dynload_cuda dynamic_loader lod_tensor)
diff --git a/paddle/fluid/inference/tensorrt/convert/activation_op.cc b/paddle/fluid/inference/tensorrt/convert/activation_op.cc
index 543784289cfc51048057e467d36fdd1f334eb903..6297051e5a30f1daa512d25d5aa3ab3b2f79f1d1 100644
--- a/paddle/fluid/inference/tensorrt/convert/activation_op.cc
+++ b/paddle/fluid/inference/tensorrt/convert/activation_op.cc
@@ -21,15 +21,18 @@ namespace tensorrt {
class ReluOpConverter : public OpConverter {
public:
ReluOpConverter() {}
- void operator()(const framework::OpDesc& op) override {
+ void operator()(const framework::proto::OpDesc& op) override {
+ // The two nullptr arguments look strange; that is because the
+ // framework::OpDesc constructor itself has an unusual signature.
+ framework::OpDesc op_desc(op, nullptr, nullptr);
LOG(INFO) << "convert a fluid relu op to tensorrt activation layer whose "
"type is Relu";
const nvinfer1::ITensor* input_tensor =
- engine_->GetITensor(op.Input("X")[0]);
+ engine_->GetITensor(op_desc.Input("X")[0]);
nvinfer1::IActivationLayer* layer = TRT_ENGINE_ADD_LAYER(
engine_, Activation, *const_cast(input_tensor),
nvinfer1::ActivationType::kRELU);
- engine_->SetITensor(op.Output("Out")[0], layer->getOutput(0));
+ engine_->SetITensor(op_desc.Output("Out")[0], layer->getOutput(0));
}
};
diff --git a/paddle/fluid/inference/tensorrt/convert/conv2d_op.cc b/paddle/fluid/inference/tensorrt/convert/conv2d_op.cc
index 431500b90e144e2a30fe705b72e93452f806ca65..209936c3bafb0d31546856dc36c1b48053a0634b 100644
--- a/paddle/fluid/inference/tensorrt/convert/conv2d_op.cc
+++ b/paddle/fluid/inference/tensorrt/convert/conv2d_op.cc
@@ -21,7 +21,7 @@ namespace tensorrt {
class Conv2dOpConverter : public OpConverter {
public:
Conv2dOpConverter() {}
- void operator()(const framework::OpDesc& op) override {
+ void operator()(const framework::proto::OpDesc& op) override {
LOG(INFO)
<< "convert a fluid conv2d op to tensorrt conv layer without bias";
}
diff --git a/paddle/fluid/inference/tensorrt/convert/io_converter.cc b/paddle/fluid/inference/tensorrt/convert/io_converter.cc
index 32e8631fde3f748669d2008b4a060455a37e154e..854f434d93e81237dc85c5df62debcf3b3824b78 100644
--- a/paddle/fluid/inference/tensorrt/convert/io_converter.cc
+++ b/paddle/fluid/inference/tensorrt/convert/io_converter.cc
@@ -23,26 +23,42 @@ namespace tensorrt {
using platform::is_gpu_place;
using platform::is_cpu_place;
-class DefaultInputConverter : public EngineInputConverter {
+class DefaultIOConverter : public EngineIOConverter {
public:
- DefaultInputConverter() {}
+ DefaultIOConverter() {}
// NOTE out is GPU memory.
virtual void operator()(const LoDTensor& in, void* out,
size_t max_size) override {
PADDLE_ENFORCE(out != nullptr);
- PADDLE_ENFORCE_LE(in.memory_size(), max_size);
+ PADDLE_ENFORCE(stream_ != nullptr);
const auto& place = in.place();
+ size_t size = in.memory_size();
+ PADDLE_ENFORCE_LE(size, max_size);
if (is_cpu_place(place)) {
- PADDLE_ENFORCE(stream_ != nullptr);
- PADDLE_ENFORCE_EQ(0,
- cudaMemcpyAsync(out, in.data(), in.memory_size(),
- cudaMemcpyHostToDevice, *stream_));
-
+ PADDLE_ENFORCE_EQ(0, cudaMemcpyAsync(out, in.data(), size,
+ cudaMemcpyHostToDevice, *stream_));
} else if (is_gpu_place(place)) {
- PADDLE_ENFORCE_EQ(0,
- cudaMemcpyAsync(out, in.data(), in.memory_size(),
- cudaMemcpyHostToHost, *stream_));
-
+ PADDLE_ENFORCE_EQ(0, cudaMemcpyAsync(out, in.data(), size,
+ cudaMemcpyDeviceToDevice, *stream_));
+ } else {
+ PADDLE_THROW("Unknown device for converter");
+ }
+ cudaStreamSynchronize(*stream_);
+ }
+ // NOTE in is GPU memory.
+ virtual void operator()(const void* in, LoDTensor* out,
+ size_t max_size) override {
+ PADDLE_ENFORCE(in != nullptr);
+ PADDLE_ENFORCE(stream_ != nullptr);
+ const auto& place = out->place();
+ size_t size = out->memory_size();
+ PADDLE_ENFORCE_LE(size, max_size);
+ if (is_cpu_place(place)) {
+ PADDLE_ENFORCE_EQ(0, cudaMemcpyAsync(out->data(), in, size,
+ cudaMemcpyDeviceToHost, *stream_));
+ } else if (is_gpu_place(place)) {
+ PADDLE_ENFORCE_EQ(0, cudaMemcpyAsync(out->data(), in, size,
+ cudaMemcpyDeviceToDevice, *stream_));
} else {
PADDLE_THROW("Unknown device for converter");
}
@@ -50,7 +66,8 @@ class DefaultInputConverter : public EngineInputConverter {
}
};
-REGISTER_TENSORRT_INPUT_CONVERTER(default, DefaultInputConverter);
+// fluid LodTensor <-> tensorrt ITensor
+REGISTER_TENSORRT_IO_CONVERTER(default, DefaultIOConverter);
} // namespace tensorrt
} // namespace inference
diff --git a/paddle/fluid/inference/tensorrt/convert/io_converter.h b/paddle/fluid/inference/tensorrt/convert/io_converter.h
index 8972dae92be2c2d261a13c48d98e675f64e51d31..71c48e085d25d2bc6720d93735f661f9e3af7b40 100644
--- a/paddle/fluid/inference/tensorrt/convert/io_converter.h
+++ b/paddle/fluid/inference/tensorrt/convert/io_converter.h
@@ -14,6 +14,7 @@ limitations under the License. */
#pragma once
+#include
#include
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/inference/utils/singleton.h"
@@ -25,43 +26,57 @@ namespace tensorrt {
using framework::LoDTensor;
/*
- * Convert Input from Fluid to an Engine.
- * TensorRT's ITensor follows row major, NCHW. Fluid is also row major, so in
- * most cases just need to copy the data.
+ * Convert Input from Fluid to TensorRT Engine.
+ * Convert Output from TensorRT Engine to Fluid.
+ *
+ * Note that TensorRT's ITensor follows row major, NCHW. Fluid is also row
+ * major,
+ * so in the default case just need to copy the data.
*/
-class EngineInputConverter {
+class EngineIOConverter {
public:
- EngineInputConverter() {}
+ EngineIOConverter() {}
virtual void operator()(const LoDTensor& in, void* out, size_t max_size) {}
+ virtual void operator()(const void* in, LoDTensor* out, size_t max_size) {}
void SetStream(cudaStream_t* stream) { stream_ = stream; }
- static void Run(const std::string& in_op_type, const LoDTensor& in, void* out,
- size_t max_size, cudaStream_t* stream) {
+ static void ConvertInput(const std::string& op_type, const LoDTensor& in,
+ void* out, size_t max_size, cudaStream_t* stream) {
PADDLE_ENFORCE(stream != nullptr);
- auto* converter = Registry::Lookup(
- in_op_type, "default" /* default_type */);
+ auto* converter = Registry::Lookup(
+ op_type, "default" /* default_type */);
PADDLE_ENFORCE_NOT_NULL(converter);
converter->SetStream(stream);
(*converter)(in, out, max_size);
}
- virtual ~EngineInputConverter() {}
+ static void ConvertOutput(const std::string& op_type, const void* in,
+ LoDTensor* out, size_t max_size,
+ cudaStream_t* stream) {
+ PADDLE_ENFORCE(stream != nullptr);
+ auto* converter = Registry::Lookup(
+ op_type, "default" /* default_type */);
+ PADDLE_ENFORCE_NOT_NULL(converter);
+ converter->SetStream(stream);
+ (*converter)(in, out, max_size);
+ }
+
+ virtual ~EngineIOConverter() {}
protected:
cudaStream_t* stream_{nullptr};
};
+#define REGISTER_TENSORRT_IO_CONVERTER(op_type__, Converter__) \
+ struct trt_io_##op_type__##_converter { \
+ trt_io_##op_type__##_converter() { \
+ Registry::Register(#op_type__); \
+ } \
+ }; \
+ trt_io_##op_type__##_converter trt_io_##op_type__##_converter__;
+
} // namespace tensorrt
} // namespace inference
} // namespace paddle
-
-#define REGISTER_TENSORRT_INPUT_CONVERTER(in_op_type__, Converter__) \
- struct trt_input_##in_op_type__##_converter { \
- trt_input_##in_op_type__##_converter() { \
- ::paddle::inference::Registry::Register< \
- Converter__>(#in_op_type__); \
- } \
- }; \
- trt_input_##in_op_type__##_converter trt_input_##in_op_type__##_converter__;
diff --git a/paddle/fluid/inference/tensorrt/convert/mul_op.cc b/paddle/fluid/inference/tensorrt/convert/mul_op.cc
index f9834ab156c9dcc11f4e89075b7bf5457cf00268..3ca58b139bd3af1947ae7f063060e11d2ea7d577 100644
--- a/paddle/fluid/inference/tensorrt/convert/mul_op.cc
+++ b/paddle/fluid/inference/tensorrt/convert/mul_op.cc
@@ -21,7 +21,7 @@ namespace tensorrt {
class MulOpConverter : public OpConverter {
public:
MulOpConverter() {}
- void operator()(const framework::OpDesc& op) override {
+ void operator()(const framework::proto::OpDesc& op) override {
LOG(INFO) << "convert a fluid mul op to tensorrt fc layer without bias";
}
};
diff --git a/paddle/fluid/inference/tensorrt/convert/op_converter.h b/paddle/fluid/inference/tensorrt/convert/op_converter.h
index 77c788550b2c7df1f483b926661789b2a54d8fff..abc9ebf472498f6653d5bb1113ae2f3ce7e5a923 100644
--- a/paddle/fluid/inference/tensorrt/convert/op_converter.h
+++ b/paddle/fluid/inference/tensorrt/convert/op_converter.h
@@ -31,10 +31,10 @@ namespace tensorrt {
class OpConverter {
public:
OpConverter() {}
- virtual void operator()(const framework::OpDesc& op) {}
+ virtual void operator()(const framework::proto::OpDesc& op) {}
- void Run(const framework::OpDesc& op, TensorRTEngine* engine) {
- std::string type = op.Type();
+ void Run(const framework::proto::OpDesc& op, TensorRTEngine* engine) {
+ std::string type = op.type();
auto* it = Registry::Lookup(type);
PADDLE_ENFORCE_NOT_NULL(it, "no OpConverter for optype [%s]", type);
it->SetEngine(engine);
@@ -42,14 +42,16 @@ class OpConverter {
}
// convert fluid op to tensorrt layer
- void ConvertOp(const framework::OpDesc& op, TensorRTEngine* engine) {
+ void ConvertOp(const framework::proto::OpDesc& op, TensorRTEngine* engine) {
OpConverter::Run(op, engine);
}
// convert fluid block to tensorrt network
- void ConvertBlock(const framework::BlockDesc& block, TensorRTEngine* engine) {
- for (auto op : block.AllOps()) {
- OpConverter::Run(*op, engine);
+ void ConvertBlock(const framework::proto::BlockDesc& block,
+ TensorRTEngine* engine) {
+ for (size_t i = 0; i < block.ops_size(); i++) {
+ const auto& op = block.ops(i);
+ OpConverter::Run(op, engine);
}
}
diff --git a/paddle/fluid/inference/tensorrt/convert/test_activation_op.cc b/paddle/fluid/inference/tensorrt/convert/test_activation_op.cc
index 23e3435c21725328d3765fae0d158a83ac21478b..ec33f97c8240dfc09a203d68599bffe78a4abb12 100644
--- a/paddle/fluid/inference/tensorrt/convert/test_activation_op.cc
+++ b/paddle/fluid/inference/tensorrt/convert/test_activation_op.cc
@@ -16,6 +16,7 @@ limitations under the License. */
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/program_desc.h"
+#include "paddle/fluid/inference/tensorrt/convert/io_converter.h"
#include "paddle/fluid/inference/tensorrt/convert/op_converter.h"
#include "paddle/fluid/platform/device_context.h"
#include "paddle/fluid/platform/place.h"
@@ -26,7 +27,7 @@ namespace paddle {
namespace inference {
namespace tensorrt {
-void Compare(float input, float expect) {
+void Compare(const std::string op_type, float input, float expect) {
framework::Scope scope;
platform::CUDAPlace place;
platform::CUDADeviceContext ctx(place);
@@ -35,6 +36,7 @@ void Compare(float input, float expect) {
auto x_var = scope.Var("X");
auto x_tensor = x_var->GetMutable();
x_tensor->Resize({1, 1});
+ x_tensor->mutable_data(place);
std::vector init;
init.push_back(input);
framework::TensorFromVector(init, ctx, x_tensor);
@@ -45,14 +47,15 @@ void Compare(float input, float expect) {
out_tensor->mutable_data(place);
framework::OpDesc op_desc;
- op_desc.SetType("relu");
+ op_desc.SetType(op_type);
op_desc.SetInput("X", {"X"});
op_desc.SetOutput("Out", {"Out"});
- auto relu_op = framework::OpRegistry::CreateOp(op_desc);
+ auto op = framework::OpRegistry::CreateOp(*op_desc.Proto());
// run fluid op
- relu_op->Run(scope, place);
+ op->Run(scope, place);
+ // get fluid output
std::vector out1;
framework::TensorToVector(*out_tensor, ctx, &out1);
@@ -63,21 +66,28 @@ void Compare(float input, float expect) {
engine->InitNetwork();
engine->DeclareInput("X", nvinfer1::DataType::kFLOAT,
nvinfer1::DimsCHW{1, 1, 1});
-
+ // convert op
OpConverter op_converter;
- op_converter.ConvertOp(op_desc, engine);
+ op_converter.ConvertOp(*op_desc.Proto(), engine);
engine->DeclareOutput("Out");
engine->FreezeNetwork();
- engine->SetInputFromCPU("X", &input, 1 * sizeof(float));
- // run tensorrt op
+ // convert LoDTensor to ITensor
+ size_t size = x_tensor->memory_size();
+ EngineIOConverter::ConvertInput(op_type, *x_tensor,
+ engine->buffer("X").buffer, size, &stream);
+ // run tensorrt op
engine->Execute(1);
-
- float out2;
- engine->GetOutputInCPU("Out", &out2, 1 * sizeof(float));
-
- ASSERT_EQ(out1[0], out2);
+ // convert ITensor to LoDTensor
+ EngineIOConverter::ConvertOutput(op_type, engine->buffer("Out").buffer,
+ out_tensor, size, &stream);
+ // get tensorrt output
+ std::vector out2;
+ framework::TensorToVector(*out_tensor, ctx, &out2);
+
+ // compare
+ ASSERT_EQ(out1[0], out2[0]);
ASSERT_EQ(out1[0], expect);
delete engine;
@@ -85,8 +95,8 @@ void Compare(float input, float expect) {
}
TEST(OpConverter, ConvertRelu) {
- Compare(1, 1); // relu(1) = 1
- Compare(-5, 0); // relu(-5) = 0
+ Compare("relu", 1, 1); // relu(1) = 1
+ Compare("relu", -5, 0); // relu(-5) = 0
}
} // namespace tensorrt
diff --git a/paddle/fluid/inference/tensorrt/convert/test_io_converter.cc b/paddle/fluid/inference/tensorrt/convert/test_io_converter.cc
index afcc516e6b76d58e37ce0e60746704cf3933fac7..8f91309a0a00d5131268f026c319e25ba3cb964a 100644
--- a/paddle/fluid/inference/tensorrt/convert/test_io_converter.cc
+++ b/paddle/fluid/inference/tensorrt/convert/test_io_converter.cc
@@ -12,40 +12,63 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
+#include
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/inference/tensorrt/convert/io_converter.h"
-#include
-
namespace paddle {
namespace inference {
namespace tensorrt {
-class EngineInputConverterTester : public ::testing::Test {
- public:
- void SetUp() override { tensor.Resize({10, 10}); }
+void IOConverterTester(const platform::DeviceContext& ctx) {
+ cudaStream_t stream;
+ ASSERT_EQ(0, cudaStreamCreate(&stream));
- framework::LoDTensor tensor;
-};
+ // init fluid in_tensor
+ framework::LoDTensor in_tensor;
+ in_tensor.Resize({10, 10});
+ auto place = ctx.GetPlace();
+ in_tensor.mutable_data(place);
+ std::vector init;
+ for (int64_t i = 0; i < 10 * 10; ++i) {
+ init.push_back(i);
+ }
+ framework::TensorFromVector(init, ctx, &in_tensor);
-TEST_F(EngineInputConverterTester, DefaultCPU) {
+ // init tensorrt buffer
void* buffer;
- tensor.mutable_data(platform::CPUPlace());
- ASSERT_EQ(cudaMalloc(&buffer, tensor.memory_size()), 0);
+ size_t size = in_tensor.memory_size();
+ ASSERT_EQ(cudaMalloc(&buffer, size), 0);
- cudaStream_t stream;
- EngineInputConverter::Run("test", tensor, buffer, tensor.memory_size(),
- &stream);
+ // convert fluid in_tensor to tensorrt buffer
+ EngineIOConverter::ConvertInput("test", in_tensor, buffer, size, &stream);
+
+ // convert tensorrt buffer to fluid out_tensor
+ framework::LoDTensor out_tensor;
+ out_tensor.Resize({10, 10});
+ out_tensor.mutable_data(place);
+ EngineIOConverter::ConvertOutput("test", buffer, &out_tensor, size, &stream);
+
+ // compare in_tensor and out_tensor
+ std::vector result;
+ framework::TensorToVector(out_tensor, ctx, &result);
+ EXPECT_EQ(init.size(), result.size());
+ for (size_t i = 0; i < init.size(); i++) {
+ EXPECT_EQ(init[i], result[i]);
+ }
+ cudaStreamDestroy(stream);
}
-TEST_F(EngineInputConverterTester, DefaultGPU) {
- void* buffer;
- tensor.mutable_data(platform::CUDAPlace());
- ASSERT_EQ(cudaMalloc(&buffer, tensor.memory_size()), 0);
+TEST(EngineIOConverterTester, DefaultCPU) {
+ platform::CPUPlace place;
+ platform::CPUDeviceContext ctx(place);
+ IOConverterTester(ctx);
+}
- cudaStream_t stream;
- EngineInputConverter::Run("test", tensor, buffer, tensor.memory_size(),
- &stream);
+TEST(EngineIOConverterTester, DefaultGPU) {
+ platform::CUDAPlace place;
+ platform::CUDADeviceContext ctx(place);
+ IOConverterTester(ctx);
}
} // namespace tensorrt
diff --git a/paddle/fluid/inference/tensorrt/convert/test_op_converter.cc b/paddle/fluid/inference/tensorrt/convert/test_op_converter.cc
index aa5fb726f1129eda65a6f39791330b795aad660d..8d66543eb7637c5a8ae670b89ef5996954ba2e7b 100644
--- a/paddle/fluid/inference/tensorrt/convert/test_op_converter.cc
+++ b/paddle/fluid/inference/tensorrt/convert/test_op_converter.cc
@@ -29,7 +29,7 @@ TEST(OpConverter, ConvertBlock) {
conv2d_op->SetType("conv2d");
OpConverter converter;
- converter.ConvertBlock(*block, nullptr /*TensorRTEngine*/);
+ converter.ConvertBlock(*block->Proto(), nullptr /*TensorRTEngine*/);
}
} // namespace tensorrt
diff --git a/paddle/fluid/inference/tests/book/test_inference_image_classification.cc b/paddle/fluid/inference/tests/book/test_inference_image_classification.cc
index c4fd1e298b0daea85db2a407d04ad2d7bcdee0f0..60c761c5281e2f535aab0200c93fb738addcdb87 100644
--- a/paddle/fluid/inference/tests/book/test_inference_image_classification.cc
+++ b/paddle/fluid/inference/tests/book/test_inference_image_classification.cc
@@ -16,7 +16,6 @@ limitations under the License. */
#include "gtest/gtest.h"
#include "paddle/fluid/inference/tests/test_helper.h"
-DEFINE_string(data_set, "cifar10", "Data set to test");
DEFINE_string(dirname, "", "Directory of the inference model.");
DEFINE_string(fp16_dirname, "", "Directory of the float16 inference model.");
DEFINE_int32(batch_size, 1, "Batch size of input data");
@@ -35,19 +34,19 @@ TEST(inference, image_classification) {
// 0. Call `paddle::framework::InitDevices()` initialize all the devices
// In unittests, this is done in paddle/testing/paddle_gtest_main.cc
+ const bool is_combined = false;
+ std::vector> feed_target_shapes =
+ GetFeedTargetShapes(dirname, is_combined);
+
paddle::framework::LoDTensor input;
// Use normilized image pixels as input data,
// which should be in the range [0.0, 1.0].
- if (FLAGS_data_set == "cifar10") {
- SetupTensor(&input, {FLAGS_batch_size, 3, 32, 32},
- static_cast(0), static_cast(1));
- } else if (FLAGS_data_set == "imagenet") {
- SetupTensor(&input, {FLAGS_batch_size, 3, 224, 224},
- static_cast(0), static_cast(1));
- } else {
- LOG(FATAL) << "Only cifar10 or imagenet is supported.";
- }
-
+ feed_target_shapes[0][0] = FLAGS_batch_size;
+ paddle::framework::DDim input_dims =
+ paddle::framework::make_ddim(feed_target_shapes[0]);
+ LOG(INFO) << input_dims;
+ SetupTensor(&input, input_dims, static_cast(0),
+ static_cast(1));
std::vector cpu_feeds;
cpu_feeds.push_back(&input);
@@ -60,7 +59,7 @@ TEST(inference, image_classification) {
LOG(INFO) << "--- CPU Runs: ---";
LOG(INFO) << "Batch size is " << FLAGS_batch_size;
TestInference(
- dirname, cpu_feeds, cpu_fetchs1, FLAGS_repeat);
+ dirname, cpu_feeds, cpu_fetchs1, FLAGS_repeat, is_combined);
LOG(INFO) << output1.dims();
}
@@ -73,7 +72,7 @@ TEST(inference, image_classification) {
LOG(INFO) << "--- GPU Runs: ---";
LOG(INFO) << "Batch size is " << FLAGS_batch_size;
TestInference(
- dirname, cpu_feeds, cpu_fetchs2, FLAGS_repeat);
+ dirname, cpu_feeds, cpu_fetchs2, FLAGS_repeat, is_combined);
LOG(INFO) << output2.dims();
if (!FLAGS_skip_cpu) {
diff --git a/paddle/fluid/inference/tests/test_helper.h b/paddle/fluid/inference/tests/test_helper.h
index af2a7a5620487a10c1df6152fc4e4bf67b150752..b02e5c99f00eaf03c3753e43575cbc67e834774e 100644
--- a/paddle/fluid/inference/tests/test_helper.h
+++ b/paddle/fluid/inference/tests/test_helper.h
@@ -89,6 +89,50 @@ void CheckError(const paddle::framework::LoDTensor& output1,
EXPECT_EQ(count, 0U) << "There are " << count << " different elements.";
}
+std::unique_ptr InitProgram(
+ paddle::framework::Executor* executor, paddle::framework::Scope* scope,
+ const std::string& dirname, const bool is_combined = false) {
+ std::unique_ptr inference_program;
+ if (is_combined) {
+ // All parameters are saved in a single file.
+ // Hard-coding the file names of program and parameters in unittest.
+ // The file names should be consistent with that used in Python API
+ // `fluid.io.save_inference_model`.
+ std::string prog_filename = "__model_combined__";
+ std::string param_filename = "__params_combined__";
+ inference_program =
+ paddle::inference::Load(executor, scope, dirname + "/" + prog_filename,
+ dirname + "/" + param_filename);
+ } else {
+ // Parameters are saved in separate files sited in the specified
+ // `dirname`.
+ inference_program = paddle::inference::Load(executor, scope, dirname);
+ }
+ return inference_program;
+}
+
+std::vector> GetFeedTargetShapes(
+ const std::string& dirname, const bool is_combined = false) {
+ auto place = paddle::platform::CPUPlace();
+ auto executor = paddle::framework::Executor(place);
+ auto* scope = new paddle::framework::Scope();
+
+ auto inference_program = InitProgram(&executor, scope, dirname, is_combined);
+ auto& global_block = inference_program->Block(0);
+
+ const std::vector& feed_target_names =
+ inference_program->GetFeedTargetNames();
+ std::vector> feed_target_shapes;
+ for (size_t i = 0; i < feed_target_names.size(); ++i) {
+ auto* var = global_block.FindVar(feed_target_names[i]);
+ std::vector var_shape = var->GetShape();
+ feed_target_shapes.push_back(var_shape);
+ }
+
+ delete scope;
+ return feed_target_shapes;
+}
+
template
void TestInference(const std::string& dirname,
const std::vector& cpu_feeds,
@@ -124,22 +168,7 @@ void TestInference(const std::string& dirname,
paddle::platform::RecordEvent record_event(
"init_program",
paddle::platform::DeviceContextPool::Instance().Get(place));
-
- if (is_combined) {
- // All parameters are saved in a single file.
- // Hard-coding the file names of program and parameters in unittest.
- // The file names should be consistent with that used in Python API
- // `fluid.io.save_inference_model`.
- std::string prog_filename = "__model_combined__";
- std::string param_filename = "__params_combined__";
- inference_program = paddle::inference::Load(
- &executor, scope, dirname + "/" + prog_filename,
- dirname + "/" + param_filename);
- } else {
- // Parameters are saved in separate files sited in the specified
- // `dirname`.
- inference_program = paddle::inference::Load(&executor, scope, dirname);
- }
+ inference_program = InitProgram(&executor, scope, dirname, is_combined);
}
// Disable the profiler and print the timing information
paddle::platform::DisableProfiler(paddle::platform::EventSortingKey::kDefault,
diff --git a/paddle/fluid/operators/CMakeLists.txt b/paddle/fluid/operators/CMakeLists.txt
index c14a2b7786f9f7c06d59479d3bbce9c5d542e495..d38a9ce58726a1d045d6905354b0b592166c0110 100644
--- a/paddle/fluid/operators/CMakeLists.txt
+++ b/paddle/fluid/operators/CMakeLists.txt
@@ -186,6 +186,11 @@ endif()
add_subdirectory(detail)
if(WITH_DISTRIBUTE)
+ if(WITH_GPU)
+ op_library(gen_nccl_id_op DEPS nccl_common)
+ else()
+ set(DEPS_OPS ${DEPS_OPS} gen_nccl_id_op)
+ endif()
set(DISTRIBUTE_DEPS sendrecvop_grpc grpc++_unsecure grpc_unsecure gpr cares zlib protobuf)
set(DISTRIBUTE_COMPILE_FLAGS "-Wno-non-virtual-dtor -Wno-error=non-virtual-dtor -Wno-error=delete-non-virtual-dtor")
op_library(send_op DEPS ${DISTRIBUTE_DEPS})
@@ -202,8 +207,9 @@ if(WITH_DISTRIBUTE)
set_source_files_properties(send_barrier_op.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
set_source_files_properties(send_recv_op_test.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
cc_test(test_send_recv SRCS send_recv_op_test.cc DEPS prefetch_op send_op listen_and_serv_op sum_op executor)
+ cc_test(test_send_nccl_id SRCS test_send_nccl_id.cc DEPS send_op listen_and_serv_op executor)
else()
- set(DEPS_OPS ${DEPS_OPS} send_op prefetch_op recv_op listen_and_serv_op send_vars_op send_barrier_op)
+ set(DEPS_OPS ${DEPS_OPS} send_op prefetch_op recv_op listen_and_serv_op send_vars_op send_barrier_op gen_nccl_id_op)
endif()
op_library(cross_entropy_op DEPS cross_entropy)
diff --git a/paddle/fluid/operators/detail/grpc_client.cc b/paddle/fluid/operators/detail/grpc_client.cc
index 661dfa69fe1580ff3890f12defcd124225be0c06..ae60ab15325ef101feb7270a4f5d840cb2112be0 100644
--- a/paddle/fluid/operators/detail/grpc_client.cc
+++ b/paddle/fluid/operators/detail/grpc_client.cc
@@ -52,7 +52,7 @@ bool RPCClient::AsyncSendVariable(const std::string& ep,
// stub context
SendProcessor* s = new SendProcessor(ch);
s->Prepare(var_h, time_out);
- s->response_call_back_ = NULL;
+ s->response_call_back_ = nullptr;
auto call = s->stub_g_.PrepareUnaryCall(
s->context_.get(), "/sendrecv.SendRecvService/SendVariable", req, &cq_);
diff --git a/paddle/fluid/operators/detail/grpc_client.h b/paddle/fluid/operators/detail/grpc_client.h
index f6229b71bc01a6de51f50f5fe880ada6e15e74dd..dabce7414d2f0dca74193f1cd10c341793c10ec9 100644
--- a/paddle/fluid/operators/detail/grpc_client.h
+++ b/paddle/fluid/operators/detail/grpc_client.h
@@ -57,7 +57,9 @@ void ProcGetResponse(const VarHandle& var_h, const grpc::ByteBuffer& msg);
class BaseProcessor {
public:
- explicit BaseProcessor(std::shared_ptr ch) { context_ = NULL; }
+ explicit BaseProcessor(std::shared_ptr ch) {
+ context_ = nullptr;
+ }
virtual ~BaseProcessor() {}
@@ -105,7 +107,7 @@ class SendProcessor : public BaseProcessor {
::grpc::GenericStub stub_g_;
::grpc::ByteBuffer reply_;
- RequestSendCallBack response_call_back_ = NULL;
+ RequestSendCallBack response_call_back_ = nullptr;
};
typedef std::function
diff --git a/paddle/fluid/operators/detail/grpc_server.h b/paddle/fluid/operators/detail/grpc_server.h
index 7f9cae21ccca8dd51f9fbe98148d01a51ac6eb84..18f1bc53d0f561f412a5bbbe018bc3d427ac9ef9 100644
--- a/paddle/fluid/operators/detail/grpc_server.h
+++ b/paddle/fluid/operators/detail/grpc_server.h
@@ -47,6 +47,7 @@ class AsyncGRPCServer final {
explicit AsyncGRPCServer(const std::string &address, bool sync_mode)
: address_(address), sync_mode_(sync_mode), ready_(0) {}
+ ~AsyncGRPCServer() {}
void WaitServerReady();
void RunSyncUpdate();
diff --git a/paddle/fluid/operators/detail/send_recv.proto b/paddle/fluid/operators/detail/send_recv.proto
index fffa9ae7a43ea5cd7b2bda6fbbf6ef9f7d23009d..9478c5702bcbf99fc88207b8c4843dbccf8a5925 100644
--- a/paddle/fluid/operators/detail/send_recv.proto
+++ b/paddle/fluid/operators/detail/send_recv.proto
@@ -32,6 +32,7 @@ service SendRecvService {
enum VarType {
LOD_TENSOR = 0;
SELECTED_ROWS = 1;
+ NCCL_ID = 2;
}
// NOTICE(gongwb):don't modify this proto if you are not
diff --git a/paddle/fluid/operators/detail/sendrecvop_utils.cc b/paddle/fluid/operators/detail/sendrecvop_utils.cc
index 1a8a1af20fa446dbd537944409ef0ca1e3e9116f..07c43554bc6a0d71d688a5a5772d0ab3d2de319a 100644
--- a/paddle/fluid/operators/detail/sendrecvop_utils.cc
+++ b/paddle/fluid/operators/detail/sendrecvop_utils.cc
@@ -14,6 +14,9 @@ limitations under the License. */
#include "paddle/fluid/operators/detail/sendrecvop_utils.h"
+#ifdef PADDLE_WITH_CUDA
+#include
+#endif
#include
#include // NOLINT
@@ -129,6 +132,10 @@ void SerializeToByteBuffer(const std::string& name, framework::Variable* var,
} else if (var->IsType()) {
request.set_type(::sendrecv::SELECTED_ROWS);
GetSelectedRowsPayload(var, ctx, &request, &payload, &payload_size);
+#ifdef PADDLE_WITH_CUDA
+ } else if (var->IsType()) {
+ request.set_type(::sendrecv::NCCL_ID);
+#endif
} else {
PADDLE_THROW("Serialize does not support type: %s",
typeid(var->Type()).name());
@@ -149,6 +156,24 @@ void SerializeToByteBuffer(const std::string& name, framework::Variable* var,
void* buf = buffer.get();
ProtoEncodeHelper e(static_cast(buf), 1024);
e.WriteRawBytes(std::string(header.data(), header.size()));
+// NCCLID is copied directly to the message, return bytebuffer
+// with only one slice if serializing NCCLID.
+#ifdef PADDLE_WITH_CUDA
+ if (var->IsType()) {
+ e.WriteVarlengthBeginning(VarMsg::kSerializedFieldNumber,
+ NCCL_UNIQUE_ID_BYTES);
+ const ncclUniqueId& uid = var->Get();
+ e.WriteRawBytes(std::string(uid.internal, NCCL_UNIQUE_ID_BYTES));
+
+ // for serializing NCCL_ID: pack it as a single slice
+ ::grpc::Slice slices(e.size());
+ memcpy(const_cast(slices.begin()), e.data(), e.size());
+ ::grpc::ByteBuffer tmp(&slices, 1);
+ msg->Swap(&tmp);
+ return;
+ }
+#endif
+
e.WriteVarlengthBeginning(VarMsg::kSerializedFieldNumber, payload_size);
// steal reference of tensor data
::grpc::Slice slices[4]; // metadata, tensor, rows meta, rows
diff --git a/paddle/fluid/operators/detail/variable_response.cc b/paddle/fluid/operators/detail/variable_response.cc
index 99602a05d023f30c2eed8df25e7534fdc9ef2ced..462e303096e609c6797ca8cc16266ec3621623fc 100644
--- a/paddle/fluid/operators/detail/variable_response.cc
+++ b/paddle/fluid/operators/detail/variable_response.cc
@@ -17,6 +17,9 @@
#include
#include
#include
+#ifdef PADDLE_WITH_CUDA
+#include
+#endif
#include "paddle/fluid/platform/profiler.h"
#include "paddle/fluid/operators/detail/send_recv.pb.h"
@@ -368,7 +371,8 @@ int VariableResponse::Parse(Source* source) {
}
case sendrecv::VariableMessage::kSerializedFieldNumber: {
PADDLE_ENFORCE((meta_.type() == sendrecv::SELECTED_ROWS ||
- meta_.type() == sendrecv::LOD_TENSOR) &&
+ meta_.type() == sendrecv::LOD_TENSOR ||
+ meta_.type() == sendrecv::NCCL_ID) &&
meta_.varname() != "",
"meta info should be got first!");
@@ -378,6 +382,22 @@ int VariableResponse::Parse(Source* source) {
return tag;
}
+ if (meta_.type() == sendrecv::NCCL_ID) {
+#ifdef PADDLE_WITH_CUDA
+ auto* var = scope_->FindVar(meta_.varname());
+ if (var != nullptr) {
+ ncclUniqueId* id = var->GetMutable();
+ if (!ReadRaw(&input, *dev_ctx_, platform::CPUPlace(), id->internal,
+ num_bytes)) {
+ return tag;
+ }
+ }
+ break;
+#else
+ PADDLE_THROW("Not compiled with CUDA!");
+#endif
+ }
+
framework::DDim dims = GetDims(meta_.dims());
if (meta_.type() == sendrecv::LOD_TENSOR) {
PADDLE_ENFORCE(meta_.lod_size() >= 0,
diff --git a/paddle/fluid/operators/gen_nccl_id_op.cc b/paddle/fluid/operators/gen_nccl_id_op.cc
new file mode 100644
index 0000000000000000000000000000000000000000..a5678f63466d368b3dd59380c18f9625cabd368b
--- /dev/null
+++ b/paddle/fluid/operators/gen_nccl_id_op.cc
@@ -0,0 +1,128 @@
+/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#include
+#include
+#include
+#include
+
+#include "paddle/fluid/framework/executor.h"
+#include "paddle/fluid/framework/lod_tensor.h"
+#include "paddle/fluid/framework/op_registry.h"
+#include "paddle/fluid/framework/threadpool.h"
+#include "paddle/fluid/operators/detail/grpc_client.h"
+#include "paddle/fluid/operators/detail/grpc_server.h"
+#include "paddle/fluid/platform/nccl_helper.h"
+
+namespace paddle {
+namespace operators {
+
+class GenNCCLIdOp : public framework::OperatorBase {
+ public:
+ GenNCCLIdOp(const std::string& type, const framework::VariableNameMap& inputs,
+ const framework::VariableNameMap& outputs,
+ const framework::AttributeMap& attrs)
+ : OperatorBase(type, inputs, outputs, attrs) {}
+
+ void RunImpl(const framework::Scope& scope,
+ const platform::Place& dev_place) const override {
+ platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance();
+ // put nccl id in CPUPlace
+ auto& dev_ctx = *pool.Get(platform::CPUPlace());
+ int trainer_id = Attr("trainer_id");
+ framework::Scope& local_scope = scope.NewScope();
+
+ if (trainer_id == 0) {
+ GenerateAndSend(&local_scope, dev_ctx);
+ } else {
+ GetIdByServer(&local_scope, dev_ctx);
+ }
+ }
+
+ private:
+ void GenerateAndSend(framework::Scope* scope,
+ const platform::DeviceContext& dev_ctx) const {
+ auto var = scope->FindVar(NCCL_ID_VARNAME);
+ PADDLE_ENFORCE_NOT_NULL(var);
+ auto id = var->GetMutable();
+ PADDLE_ENFORCE(platform::dynload::ncclGetUniqueId(id));
+
+ std::vector endpoint_list =
+ Attr>("endpoint_list");
+ detail::RPCClient client;
+ for (auto& ep : endpoint_list) {
+ VLOG(3) << "sending nccl id to " << ep;
+ client.AsyncSendVariable(ep, dev_ctx, *scope, NCCL_ID_VARNAME);
+ }
+ client.Wait();
+ VLOG(3) << "sending completed...";
+ }
+
+ void GetIdByServer(framework::Scope* scope,
+ const platform::DeviceContext& dev_ctx) const {
+ std::string endpoint = Attr("endpoint");
+ // NOTE: Cannot use unique_ptr here because the default
+ // deleter will call GRPC Server's base class's dtor and
+ // that will cause a weird crash.
+ detail::AsyncGRPCServer rpc_service(endpoint, true);
+ framework::ProgramDesc empty_program;
+ framework::Executor executor(dev_ctx.GetPlace());
+ rpc_service.SetScope(scope);
+ rpc_service.SetDevCtx(&dev_ctx);
+ rpc_service.SetProgram(&empty_program);
+ rpc_service.SetExecutor(&executor);
+
+ std::thread server_thread(
+ std::bind(&detail::AsyncGRPCServer::RunSyncUpdate, &rpc_service));
+ rpc_service.SetCond(0);
+ VLOG(3) << "start getting nccl id from trainer 0...";
+ auto recv = rpc_service.Get();
+ VLOG(3) << "got nccl id and stop server...";
+ rpc_service.ShutDown();
+ VLOG(3) << "rpc server stopped";
+ server_thread.join();
+ }
+};
+
+class GenNCCLIdOpMaker : public framework::OpProtoAndCheckerMaker {
+ public:
+ void Make() override {
+ AddOutput("NCCLID", "Raw variable contains a NCCL UniqueId instaces.");
+ AddComment(R"DOC(
+GenNCCLId operator
+
+For trainer 0: generate a new UniqueId and send it to all the other trainers.
+For trainer 1~n: start a gRPC server to get the UniqueId, once got, stop the server.
+)DOC");
+ AddAttr("endpoint",
+ "(string), e.g. 127.0.0.1:6175 "
+ "current listen endpoint");
+ AddAttr>(
+ "endpoint_list",
+ "['trainer1_ip:port', 'trainer2_ip:port', ...] "
+ "list of trainer endpoints start from trainer 1")
+ .SetDefault({});
+ AddAttr("trainer_id",
+ "(int default 0) "
+ "The index of the trainer in distributed training.")
+ .SetDefault(0);
+ }
+};
+
+} // namespace operators
+} // namespace paddle
+
+namespace ops = paddle::operators;
+
+REGISTER_OPERATOR(gen_nccl_id, ops::GenNCCLIdOp, ops::GenNCCLIdOpMaker);
diff --git a/paddle/fluid/operators/math/sequence2batch.h b/paddle/fluid/operators/math/sequence2batch.h
index 0abda999a52bcbb94e6503692bd11aff26e849ba..62e6307ae9f4236a38c49daaf09fc05c54268159 100644
--- a/paddle/fluid/operators/math/sequence2batch.h
+++ b/paddle/fluid/operators/math/sequence2batch.h
@@ -64,18 +64,22 @@ class LoDTensor2BatchFunctor {
bool is_reverse = false) const {
if (!is_cal_batch_lod) {
auto lods = batch->lod();
- PADDLE_ENFORCE_GT(lods.size(), 2UL);
- PADDLE_ENFORCE_EQ(lods[1].size(),
- static_cast(lod_tensor.dims()[0]));
+ PADDLE_ENFORCE_GT(lods.size(), 2UL,
+ "The LoD of LoDTensor should inlcude at least 2-level "
+ "sequence information.");
+ PADDLE_ENFORCE_EQ(
+ lods[1].size(), static_cast(lod_tensor.dims()[0]),
+ "The LoD information should be consistent with the dims.");
CopyMatrixRowsFunctor to_batch;
to_batch(context, lod_tensor, lods[1], batch, true);
return;
}
auto lods = lod_tensor.lod();
- auto lod = lods[0];
PADDLE_ENFORCE_EQ(lods.size(), 1UL, "Only support one level sequence now.");
+ auto lod = lods[0];
+
std::vector seq_info;
for (size_t seq_id = 0; seq_id < lod.size() - 1; ++seq_id) {
int length = lod[seq_id + 1] - lod[seq_id];
@@ -157,9 +161,12 @@ class Batch2LoDTensorFunctor {
const framework::LoDTensor& batch,
framework::LoDTensor* lod_tensor) const {
auto in_lod = batch.lod();
- PADDLE_ENFORCE_GT(in_lod.size(), 2UL);
- PADDLE_ENFORCE_EQ(in_lod[1].size(),
- static_cast(lod_tensor->dims()[0]));
+ PADDLE_ENFORCE_GT(in_lod.size(), 2UL,
+ "The LoD of LoDTensor should inlcude at least 2-level "
+ "sequence information.");
+ PADDLE_ENFORCE_EQ(
+ in_lod[1].size(), static_cast(lod_tensor->dims()[0]),
+ "The LoD information should be consistent with the dims.");
CopyMatrixRowsFunctor to_seq;
to_seq(context, batch, in_lod[1], lod_tensor, false);
}
diff --git a/paddle/fluid/operators/reshape_op.h b/paddle/fluid/operators/reshape_op.h
index ccd7063fe69e0f21b4d2a821bb70902b39c9b9de..3dd8c7c11eca241e747bfa129962032d882ce44c 100644
--- a/paddle/fluid/operators/reshape_op.h
+++ b/paddle/fluid/operators/reshape_op.h
@@ -92,14 +92,16 @@ class ReshapeOp : public framework::OperatorWithKernel {
}
if (unk_dim_idx != -1) {
- output_shape[unk_dim_idx] = -in_size / capacity;
- // in_size < 0 and is un-determinate in compile time, skip the check,
- // for example, in_dims = [-1, 8, 1, 1], shape = [-1, 3, 8],
- // capacity = -24, in_size = -8, output_shape[0] = 0
- // the following check will fail.
if (in_size > 0) {
+ // in_size < 0 and is un-determinate in compile time, skip the check,
+ // for example, in_dims = [-1, 8, 1, 1], shape = [-1, 3, 8],
+ // capacity = -24, in_size = -8, output_shape[0] = 0
+ // the following check will fail.
+ output_shape[unk_dim_idx] = -in_size / capacity;
PADDLE_ENFORCE_EQ(output_shape[unk_dim_idx] * capacity, -in_size,
"Invalid shape is given.");
+ } else {
+ output_shape[unk_dim_idx] = -1;
}
} else {
PADDLE_ENFORCE_EQ(capacity, in_size, "Invalid shape is given.");
@@ -122,7 +124,10 @@ class ReshapeKernel : public framework::OpKernel {
void Compute(const framework::ExecutionContext &ctx) const {
auto *out = ctx.Output("Out");
auto *in = ctx.Input("X");
- auto *shape_tensor = ctx.Input("Shape");
+
+ auto *shape_tensor = ctx.HasInput("Shape")
+ ? ctx.Input("Shape")
+ : nullptr;
framework::DDim out_dims = out->dims();
diff --git a/paddle/fluid/operators/test_send_nccl_id.cc b/paddle/fluid/operators/test_send_nccl_id.cc
new file mode 100644
index 0000000000000000000000000000000000000000..bbae1d54aa3524fd45cb8ab13c86df8d54b8e643
--- /dev/null
+++ b/paddle/fluid/operators/test_send_nccl_id.cc
@@ -0,0 +1,94 @@
+/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#include
+#include
+#include // NOLINT
+
+#include "gtest/gtest.h"
+#include "paddle/fluid/framework/op_registry.h"
+#include "paddle/fluid/framework/operator.h"
+#include "paddle/fluid/framework/program_desc.h"
+#include "paddle/fluid/operators/detail/grpc_client.h"
+#include "paddle/fluid/operators/listen_and_serv_op.h"
+#include "paddle/fluid/operators/math/math_function.h"
+#include "paddle/fluid/operators/math/selected_rows_functor.h"
+#include "paddle/fluid/platform/nccl_helper.h"
+#include "paddle/fluid/string/printf.h"
+
+USE_NO_KERNEL_OP(listen_and_serv);
+
+namespace f = paddle::framework;
+namespace p = paddle::platform;
+namespace m = paddle::operators::math;
+namespace detail = paddle::operators::detail;
+namespace string = paddle::string;
+
+std::unique_ptr rpc_service;
+
+void StartServer(std::atomic* initialized) {
+ f::Scope scope;
+ p::CPUPlace place;
+ scope.Var(NCCL_ID_VARNAME);
+ p::DeviceContextPool& pool = p::DeviceContextPool::Instance();
+ auto& dev_ctx = *pool.Get(p::CPUPlace());
+
+ rpc_service.reset(new detail::AsyncGRPCServer("127.0.0.1:0", true));
+
+ f::ProgramDesc empty_program;
+ f::Executor executor(dev_ctx.GetPlace());
+ rpc_service->SetScope(&scope);
+ rpc_service->SetDevCtx(&dev_ctx);
+ rpc_service->SetProgram(&empty_program);
+ rpc_service->SetExecutor(&executor);
+
+ std::thread server_thread(
+ std::bind(&detail::AsyncGRPCServer::RunSyncUpdate, rpc_service.get()));
+ *initialized = true;
+ rpc_service->SetCond(0);
+ auto recv = rpc_service->Get();
+ LOG(INFO) << "got nccl id and stop server...";
+ rpc_service->ShutDown();
+ server_thread.join();
+}
+
+TEST(SendNcclId, Normal) {
+ std::atomic initialized{false};
+ std::thread server_thread(StartServer, &initialized);
+ while (!initialized) {
+ }
+ // wait server to start
+ // sleep(2);
+ rpc_service->WaitServerReady();
+
+ f::Scope scope;
+ p::CPUPlace place;
+ p::DeviceContextPool& pool = p::DeviceContextPool::Instance();
+ auto& dev_ctx = *pool.Get(p::CPUPlace());
+
+ auto var = scope.Var(NCCL_ID_VARNAME);
+ // var->SetType(f::proto::VarType_Type_RAW);
+ auto id = var->GetMutable();
+ p::dynload::ncclGetUniqueId(id);
+
+ int port = rpc_service->GetSelectedPort();
+ std::string ep = string::Sprintf("127.0.0.1:%d", port);
+ detail::RPCClient client;
+
+ client.AsyncSendVariable(ep, dev_ctx, scope, NCCL_ID_VARNAME);
+ client.Wait();
+ server_thread.join();
+ auto* ptr = rpc_service.release();
+ delete ptr;
+}
diff --git a/paddle/fluid/platform/nccl_helper.h b/paddle/fluid/platform/nccl_helper.h
index 0013597fd516d15c7d502370eec77e1a6a5dca88..e30c1a9ebf08365a9856fb32b1ce5790869e2b33 100644
--- a/paddle/fluid/platform/nccl_helper.h
+++ b/paddle/fluid/platform/nccl_helper.h
@@ -14,12 +14,15 @@
 #pragma once
 
+#include <stdio.h>
 #include <thread>  // NOLINT
 #include <typeindex>
 #include <vector>
 
 #include "paddle/fluid/platform/dynload/nccl.h"
 #include "paddle/fluid/platform/enforce.h"
 
+#define NCCL_ID_VARNAME "NCCLID"
+
 namespace paddle {
 namespace platform {
 
@@ -73,7 +76,9 @@ struct NCCLContextMap {
   std::unordered_map<int, NCCLContext> contexts_;
   std::vector<int> order_;
 
-  explicit NCCLContextMap(const std::vector<platform::Place> &places) {
+  explicit NCCLContextMap(const std::vector<platform::Place> &places,
+                          ncclUniqueId *nccl_id = nullptr,
+                          size_t num_trainers = 1, size_t trainer_id = 0) {
     PADDLE_ENFORCE(!places.empty());
     order_.reserve(places.size());
     for (auto &p : places) {
@@ -85,18 +90,34 @@ struct NCCLContextMap {
         order_.size(), contexts_.size(),
         "NCCL Context Map does not support contain two or more same device");
 
-    if (places.size() > 1) {
-      std::unique_ptr<ncclComm_t[]> comms(new ncclComm_t[order_.size()]);
+    if (places.size() <= 1) {
+      return;
+    }
+    std::unique_ptr<ncclComm_t[]> comms(new ncclComm_t[order_.size()]);
+    // if pass nccl_id here, can assume we are doing multi node training
+    if (nccl_id == nullptr) {
+      std::lock_guard<std::mutex> guard(NCCLGroupGuard::NCCLMutex());
+      PADDLE_ENFORCE(platform::dynload::ncclCommInitAll(
+          comms.get(), static_cast<int>(order_.size()), order_.data()));
+    } else {
+      PADDLE_ENFORCE_GT(num_trainers, 1);
+      // TODO(wuyi): need to ensure each node have same number of GPUs
       {
-        std::lock_guard<std::mutex> guard(NCCLGroupGuard::NCCLMutex());
-        PADDLE_ENFORCE(platform::dynload::ncclCommInitAll(
-            comms.get(), static_cast<int>(order_.size()), order_.data()));
-      }
-      int i = 0;
-      for (auto &dev_id : order_) {
-        contexts_.at(dev_id).comm_ = comms[i++];
+        int nranks = num_trainers * order_.size();
+        NCCLGroupGuard guard;
+        for (auto &gpu_id : order_) {
+          int rank = trainer_id * order_.size() + gpu_id;
+          VLOG(3) << "init nccl rank: " << rank << " nranks: " << nranks;
+          PADDLE_ENFORCE(cudaSetDevice(gpu_id));
+          PADDLE_ENFORCE(platform::dynload::ncclCommInitRank(
+              comms.get() + gpu_id, nranks, *nccl_id, rank));
+        }
       }
     }
+    int i = 0;
+    for (auto &dev_id : order_) {
+      contexts_.at(dev_id).comm_ = comms[i++];
+    }
   }
 
   NCCLContextMap(const NCCLContextMap &other) = delete;
diff --git a/paddle/fluid/pybind/pybind.cc b/paddle/fluid/pybind/pybind.cc
index 3e2eed31b446b83843fba943e4f2bc9e3787d7f6..b62291a99d34457dd17bf2bcafc1fc611419f086 100644
--- a/paddle/fluid/pybind/pybind.cc
+++ b/paddle/fluid/pybind/pybind.cc
@@ -503,12 +503,13 @@ All parameter, weight, gradient are variables in Paddle.
              const ProgramDesc &main_program, const std::string &loss_var_name,
              Scope *scope, std::vector<Scope *> &local_scopes,
              bool allow_op_delay, bool use_default_grad_scale,
-             bool balance_parameter_opt_between_cards) {
+             bool balance_parameter_opt_between_cards, size_t num_trainers,
+             size_t trainer_id) {
            new (&self) ParallelExecutor(
                num_threads, use_event, places, params, bcast_vars,
                main_program, loss_var_name, scope, local_scopes,
                allow_op_delay, use_default_grad_scale,
-               balance_parameter_opt_between_cards);
+               balance_parameter_opt_between_cards, num_trainers, trainer_id);
          })
       .def("bcast_params", &ParallelExecutor::BCastParamsToGPUs)
       // NOTE: even we return a vec* to Python use reference policy.
diff --git a/patches/mkldnn.hpp b/patches/mkldnn.hpp
new file mode 100644
index 0000000000000000000000000000000000000000..fe01ad8a10ebd223da75bf857617c4ad36b2634e
--- /dev/null
+++ b/patches/mkldnn.hpp
@@ -0,0 +1,4252 @@
+// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+/*******************************************************************************
+* Copyright 2016-2018 Intel Corporation
+*
+* Licensed under the Apache License, Version 2.0 (the "License");
+* you may not use this file except in compliance with the License.
+* You may obtain a copy of the License at
+*
+* http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing, software
+* distributed under the License is distributed on an "AS IS" BASIS,
+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+* See the License for the specific language governing permissions and
+* limitations under the License.
+*******************************************************************************/
+
+#ifndef MKLDNN_HPP
+#define MKLDNN_HPP
+
+#ifndef DOXYGEN_SHOULD_SKIP_THIS
+#include <stdlib.h>
+#include <memory>
+#include <vector>
+#include <algorithm>
+#include <iterator>
+#include <string>
+
+#include "mkldnn.h"
+#endif
+
+namespace mkldnn {
+
+/// @addtogroup cpp_api C++ API
+/// @{
+
+/// @addtogroup cpp_api_utils Utils
+/// @{
+
+/// A class that provides the destructor for an Intel(R) MKL-DNN C handle
+template <typename T>
+class handle_traits {};
+
+/// A class for wrapping an Intel(R) MKL-DNN handle. It is used as the base
+/// class for primitive (#mkldnn_primitive_t), engine (#mkldnn_engine_t), and
+/// stream (#mkldnn_stream_t) handles. An object of the #mkldnn::handle class
+/// can be passed by value. This class enables wrapping:
+/// - Newly constructed handles.
+/// @n In this case, the constructed handle uses reference counting provided
+/// by @p std::shared_ptr with a proper deleter function specified through
+/// the @p handle_traits class.
+/// - Pre-existing handles returned by the Intel(R) MKL-DNN C API (for
+/// example, through #mkldnn_primitive_get_output()).
+/// @n In this case, an Intel(R) MKL-DNN C API handle is wrapped without a
+/// deleter because it is assumed that the handle wrapper for the original
+/// object deletes the handle (this model is similar to @p std::weak_ptr).
+template <typename T, typename traits = handle_traits<T>>
+class handle {
+private:
+  std::shared_ptr<typename std::remove_pointer<T>::type> _data;
+  handle(const handle &&) = delete;
+  handle &operator=(const handle &&other) = delete;
+
+protected:
+  /// Constructs a C handle wrapper.
+  /// @param t The C handle to wrap.
+  /// @param weak A flag to specify whether to construct a weak wrapper.
+  handle(T t = 0, bool weak = false) : _data(0) { reset(t, weak); }
+
+  bool operator==(const T other) const { return other == _data.get(); }
+  bool operator!=(const T other) const { return !(*this == other); }
+
+public:
+  handle(const handle &other) : _data(other._data) {}
+  handle &operator=(const handle &other) {
+    _data = other._data;
+    return *this;
+  }
+  /// Resets the value of a C handle.
+  /// @param t The new value of the C handle.
+  /// @param weak A flag to specify whether the wrapper should be weak.
+  void reset(T t, bool weak = false) {
+    auto dummy_destructor = [](T) {
+      return decltype(traits::destructor(0))(0);
+    };
+    _data.reset(t, weak ? dummy_destructor : traits::destructor);
+  }
+
+  /// Returns the value of the underlying C handle.
+  T get() const { return _data.get(); }
+
+  bool operator==(const handle &other) const {
+    return other._data.get() == _data.get();
+  }
+  bool operator!=(const handle &other) const { return !(*this == other); }
+};
+
+#ifndef DOXYGEN_SHOULD_SKIP_THIS
+template <>
+struct handle_traits<mkldnn_primitive_desc_t> {
+  static constexpr auto destructor = &mkldnn_primitive_desc_destroy;
+};
+
+template <>
+struct handle_traits<mkldnn_primitive_t> {
+  static constexpr auto destructor = &mkldnn_primitive_destroy;
+};
+#endif
+
+/// Base class for all computational primitives.
+class primitive : public handle<mkldnn_primitive_t> {
+  friend struct error;
+  friend struct stream;
+  friend class primitive_at;
+  using handle::handle;
+
+public:
+  /// A proxy to C primitive kind enum
+  enum class kind {
+    undefined_primitive = mkldnn_undefined_primitive,
+    memory = mkldnn_memory,
+    view = mkldnn_view,
+    reorder = mkldnn_reorder,
+    concat = mkldnn_concat,
+    concat_inplace = mkldnn_concat_inplace,
+    sum = mkldnn_sum,
+    convolution = mkldnn_convolution,
+    deconvolution = mkldnn_deconvolution,
+    eltwise = mkldnn_eltwise,
+    relu = mkldnn_relu,
+    softmax = mkldnn_softmax,
+    pooling = mkldnn_pooling,
+    lrn = mkldnn_lrn,
+    batch_normalization = mkldnn_batch_normalization,
+    inner_product = mkldnn_inner_product,
+    convolution_relu = mkldnn_convolution_relu,
+    rnn = mkldnn_rnn,
+  };
+
+  /// A wrapper structure to specify a particular output of a primitive.
+  struct at {
+    /// The underlying C API structure.
+    mkldnn_primitive_at_t data;
+    /// Constructs a wrapper specifying @p aprimitive output with index @p
+    /// at.
+    ///
+    /// @param aprimitive The target primitive.
+    /// @param at The output index.
+
+    at(const primitive &aprimitive, size_t at = 0)
+        : data(mkldnn_primitive_at(aprimitive.get(), at)) {}
+    /// Returns the specified output.
+    inline operator primitive() const;
+  };
+
+  /// Returns the descriptor of the underlying C API primitive
+  inline const_mkldnn_primitive_desc_t get_primitive_desc() const;
+  // TODO: use the C++ API wrapper structure.
+};
+
+inline mkldnn_primitive_kind_t convert_to_c(primitive::kind akind) {
+  return static_cast<mkldnn_primitive_kind_t>(akind);
+}
+
+/// Intel(R) MKL-DNN exception class.
+///
+/// This class captures the status returned by the failed C API function, error
+/// message, and, optionally, handle of the primitive that caused the error.
+struct error : public std::exception {
+  mkldnn_status_t status;
+  std::string message;
+  primitive error_primitive;
+
+  /// Constructs an error instance.
+  ///
+  /// @param astatus The error status returned by the C API.
+  /// @param amessage The error message.
+  /// @param aerror_primitive (optional) A C handle of the primitive that
+  /// caused the error.
+
+  // NOTE(review): the primitive handle is wrapped as weak (second ctor
+  // argument `true`), so the error object does not own or destroy it.
+  error(mkldnn_status_t astatus,
+        std::string amessage,
+        mkldnn_primitive_t aerror_primitive = 0)
+      : status(astatus),
+        message(amessage),
+        error_primitive(aerror_primitive, true) {}
+
+  /// A convenience function for wrapping calls to the C API. Checks the
+  /// return status and throws an #error in case of failure.
+  ///
+  /// @param status The error status returned by the C API.
+  /// @param message The error message.
+  /// @param error_primitive (optional) A C handle of the primitive that
+  /// caused the error.
+
+  // NOTE(review): std::exception::what() is not overridden here, so callers
+  // must read `message` directly to obtain the human-readable text.
+  static void wrap_c_api(mkldnn_status_t status,
+                         std::string message,
+                         mkldnn_primitive_t *error_primitive = 0) {
+    if (status != mkldnn_success) {
+      if (nullptr != error_primitive)
+        throw error(status, message, *error_primitive);
+      else
+        throw error(status, message, nullptr);
+    }
+  }
+};
+
+inline primitive::at::operator primitive() const {
+  const_mkldnn_primitive_t output;
+  error::wrap_c_api(
+      mkldnn_primitive_get_output(data.primitive, data.output_index, &output),
+      "could not get an output primitive");
+  return primitive(const_cast<mkldnn_primitive_t>(output), true);
+}
+
+const_mkldnn_primitive_desc_t primitive::get_primitive_desc() const {
+  const_mkldnn_primitive_desc_t pd;
+  error::wrap_c_api(mkldnn_primitive_get_primitive_desc(get(), &pd),
+                    "could not get primitive descriptor by primitive");
+  return pd;
+}
+/// @}
+
+/// @addtogroup cpp_api_enums Common data types and enumerations
+/// @{
+
+enum round_mode {
+  round_nearest = mkldnn_round_nearest,
+  round_down = mkldnn_round_down,
+};
+
+inline mkldnn_round_mode_t convert_to_c(round_mode mode) {
+  return static_cast<mkldnn_round_mode_t>(mode);
+}
+
+enum padding_kind { zero = mkldnn_padding_zero };
+
+inline mkldnn_padding_kind_t convert_to_c(padding_kind kind) {
+  return static_cast<mkldnn_padding_kind_t>(kind);
+}
+
+enum prop_kind {
+  forward_training = mkldnn_forward_training,
+  forward_scoring = mkldnn_forward_scoring,
+  forward_inference = mkldnn_forward_inference,
+  forward = mkldnn_forward,
+  backward = mkldnn_backward,
+  backward_data = mkldnn_backward_data,
+  backward_weights = mkldnn_backward_weights,
+  backward_bias = mkldnn_backward_bias
+};
+
+inline mkldnn_prop_kind_t convert_to_c(prop_kind kind) {
+  return static_cast<mkldnn_prop_kind_t>(kind);
+}
+
+enum algorithm {
+  algorithm_undef = mkldnn_alg_kind_undef,
+  convolution_direct = mkldnn_convolution_direct,
+  convolution_winograd = mkldnn_convolution_winograd,
+  deconvolution_direct = mkldnn_deconvolution_direct,
+  deconvolution_winograd = mkldnn_deconvolution_winograd,
+  eltwise_relu = mkldnn_eltwise_relu,
+  eltwise_tanh = mkldnn_eltwise_tanh,
+  eltwise_elu = mkldnn_eltwise_elu,
+  eltwise_square = mkldnn_eltwise_square,
+  eltwise_abs = mkldnn_eltwise_abs,
+  eltwise_sqrt = mkldnn_eltwise_sqrt,
+  eltwise_linear = mkldnn_eltwise_linear,
+  eltwise_bounded_relu = mkldnn_eltwise_bounded_relu,
+  eltwise_soft_relu = mkldnn_eltwise_soft_relu,
+  eltwise_logistic = mkldnn_eltwise_logistic,
+  lrn_across_channels = mkldnn_lrn_across_channels,
+  lrn_within_channel = mkldnn_lrn_within_channel,
+  pooling_max = mkldnn_pooling_max,
+  pooling_avg = mkldnn_pooling_avg,
+  pooling_avg_include_padding = mkldnn_pooling_avg_include_padding,
+  pooling_avg_exclude_padding = mkldnn_pooling_avg_exclude_padding,
+  vanilla_rnn = mkldnn_vanilla_rnn,
+  vanilla_lstm = mkldnn_vanilla_lstm,
+  vanilla_gru = mkldnn_vanilla_gru,
+};
+
+inline mkldnn_alg_kind_t convert_to_c(algorithm aalgorithm) {
+  return static_cast<mkldnn_alg_kind_t>(aalgorithm);
+}
+
+enum batch_normalization_flag {
+  use_global_stats = mkldnn_use_global_stats,
+  use_scale_shift = mkldnn_use_scaleshift,
+  omit_stats = mkldnn_omit_stats,
+  fuse_bn_relu = mkldnn_fuse_bn_relu
+};
+
+inline mkldnn_batch_normalization_flag_t convert_to_c(
+    batch_normalization_flag aflag) {
+  return static_cast<mkldnn_batch_normalization_flag_t>(aflag);
+}
+
+enum rnn_direction {
+  unidirectional_left2right = mkldnn_unidirectional_left2right,
+  unidirectional_right2left = mkldnn_unidirectional_right2left,
+  unidirectional = mkldnn_unidirectional,
+  bidirectional_concat = mkldnn_bidirectional_concat,
+  bidirectional_sum = mkldnn_bidirectional_sum,
+};
+
+inline mkldnn_rnn_direction_t convert_to_c(rnn_direction adir) {
+  return static_cast<mkldnn_rnn_direction_t>(adir);
+}
+
+enum query {
+  undef = mkldnn_query_undef,
+
+  eengine = mkldnn_query_engine,
+  primitive_kind = mkldnn_query_primitive_kind,
+
+  num_of_inputs_s32 = mkldnn_query_num_of_inputs_s32,
+  num_of_outputs_s32 = mkldnn_query_num_of_outputs_s32,
+
+  time_estimate_f64 = mkldnn_query_time_estimate_f64,
+  memory_consumption_s64 = mkldnn_query_memory_consumption_s64,
+
+  impl_info_str = mkldnn_query_impl_info_str,
+
+  memory_d = mkldnn_query_memory_d,
+  convolution_d = mkldnn_query_convolution_d,
+  deconvolution_d = mkldnn_query_deconvolution_d,
+  eltwise_d = mkldnn_query_eltwise_d,
+  relu_d = mkldnn_query_relu_d,
+  softmax_d = mkldnn_query_softmax_d,
+  pooling_d = mkldnn_query_pooling_d,
+  lrn_d = mkldnn_query_lrn_d,
+  batch_normalization_d = mkldnn_query_batch_normalization_d,
+  inner_product_d = mkldnn_query_inner_product_d,
+  convolution_relu_d = mkldnn_query_convolution_relu_d,
+  rnn_d = mkldnn_query_rnn_d,
+
+  input_pd = mkldnn_query_input_pd,
+  output_pd = mkldnn_query_output_pd,
+  src_pd = mkldnn_query_src_pd,
+  diff_src_pd = mkldnn_query_diff_src_pd,
+  weights_pd = mkldnn_query_weights_pd,
+  diff_weights_pd = mkldnn_query_diff_weights_pd,
+  dst_pd = mkldnn_query_dst_pd,
+  diff_dst_pd = mkldnn_query_diff_dst_pd,
+  workspace_pd = mkldnn_query_workspace_pd,
+};
+
+inline mkldnn_query_t convert_to_c(query aquery) {
+  return static_cast<mkldnn_query_t>(aquery);
+}
+
+/// @}
+
+/// @addtogroup cpp_api_attr Attributes
+/// @{
+
+#ifndef DOXYGEN_SHOULD_SKIP_THIS
+template <>
+struct handle_traits<mkldnn_post_ops_t> {
+  static constexpr auto destructor = &mkldnn_post_ops_destroy;
+};
+#endif
+
+struct post_ops : public handle<mkldnn_post_ops_t> {
+  post_ops() {
+    mkldnn_post_ops_t result;
+    error::wrap_c_api(mkldnn_post_ops_create(&result),
+                      "could not create post operation sequence");
+    reset(result);
+  }
+
+  int len() const { return mkldnn_post_ops_len(get()); }
+
+  primitive::kind kind(int index) const {
+    error::wrap_c_api(index < len() ? mkldnn_success : mkldnn_invalid_arguments,
+                      "post_ops index is out of range");
+    return static_cast<primitive::kind>(mkldnn_post_ops_get_kind(get(), index));
+  }
+
+  void append_sum(float scale = 1.) {
+    error::wrap_c_api(mkldnn_post_ops_append_sum(get(), scale),
+                      "could not append sum");
+  }
+
+  void get_params_sum(int index, float &scale) const {
+    error::wrap_c_api(mkldnn_post_ops_get_params_sum(get(), index, &scale),
+                      "could not get sum params");
+  }
+
+  void append_eltwise(float scale, algorithm alg, float alpha, float beta) {
+    error::wrap_c_api(mkldnn_post_ops_append_eltwise(
+                          get(), scale, convert_to_c(alg), alpha, beta),
+                      "could not append eltwise");
+  }
+
+  void get_params_eltwise(int index,
+                          float &scale,
+                          algorithm &alg,
+                          float &alpha,
+                          float &beta) const {
+    mkldnn_alg_kind_t c_alg;
+    error::wrap_c_api(mkldnn_post_ops_get_params_eltwise(
+                          get(), index, &scale, &c_alg, &alpha, &beta),
+                      "could not get eltwise params");
+    alg = static_cast<algorithm>(c_alg);
+  }
+};
+
+#ifndef DOXYGEN_SHOULD_SKIP_THIS
+template <>
+struct handle_traits<mkldnn_primitive_attr_t> {
+  static constexpr auto destructor = &mkldnn_primitive_attr_destroy;
+};
+#endif
+
+struct primitive_attr : public handle<mkldnn_primitive_attr_t> {
+  primitive_attr() {
+    mkldnn_primitive_attr_t result;
+    error::wrap_c_api(mkldnn_primitive_attr_create(&result),
+                      "could not create a primitive attr");
+    reset(result);
+  }
+
+  round_mode get_int_output_round_mode() const {
+    mkldnn_round_mode_t result;
+    error::wrap_c_api(
+        mkldnn_primitive_attr_get_int_output_round_mode(get(), &result),
+        "could not get int output round mode");
+    return round_mode(result);
+  }
+
+  void set_int_output_round_mode(round_mode mode) {
+    error::wrap_c_api(mkldnn_primitive_attr_set_int_output_round_mode(
+                          get(), mkldnn::convert_to_c(mode)),
+                      "could not set int output round mode");
+  }
+
+  void get_output_scales(int &mask, std::vector<float> &scales) const {
+    int count, c_mask;
+    const float *c_scales;
+    error::wrap_c_api(mkldnn_primitive_attr_get_output_scales(
+                          get(), &count, &c_mask, &c_scales),
+                      "could not get int output scales");
+    scales.resize(count);
+
+    mask = c_mask;
+    for (int c = 0; c < count; ++c) scales[c] = c_scales[c];
+  }
+
+  void set_output_scales(int mask, const std::vector<float> &scales) {
+    error::wrap_c_api(mkldnn_primitive_attr_set_output_scales(
+                          get(), (int)scales.size(), mask, &scales[0]),
+                      "could not set int output scales");
+  }
+
+  const post_ops get_post_ops() const {
+    post_ops result;
+    const_mkldnn_post_ops_t c_result;
+    error::wrap_c_api(mkldnn_primitive_attr_get_post_ops(get(), &c_result),
+                      "could not get post operation sequence");
+    result.reset(const_cast<mkldnn_post_ops_t>(c_result), true);
+    return result;
+  }
+
+  void set_post_ops(post_ops ops) {
+    error::wrap_c_api(mkldnn_primitive_attr_set_post_ops(get(), ops.get()),
+                      "could not set post operation sequence");
+  }
+};
+
+/// @}
+
+/// @addtogroup cpp_api_engine Engine
+/// @{
+
+#ifndef DOXYGEN_SHOULD_SKIP_THIS
+template <>
+struct handle_traits<mkldnn_engine_t> {
+  static constexpr auto destructor = &mkldnn_engine_destroy;
+};
+#endif
+
+/// An execution engine.
+struct engine : public handle<mkldnn_engine_t> {
+  friend class primitive;
+  // gcc bug??? using handle::handle;
+
+  /// Kinds of engines
+  enum kind {
+    /// An unspecified engine
+    any = mkldnn_any_engine,
+    /// CPU engine
+    cpu = mkldnn_cpu,
+  };
+
+  /// Returns the number of engines of a certain kind.
+  ///
+  /// @param akind The kind of engines to count.
+
+  static size_t get_count(kind akind) {
+    return mkldnn_engine_get_count(convert_to_c(akind));
+  }
+
+  /// Constructs an engine.
+  ///
+  /// @param akind The kind of engine to construct.
+  /// @param index The index of the engine. Must be less than the value
+  /// returned by #get_count() for this particular kind of engine.
+
+  engine(kind akind, size_t index) {
+    mkldnn_engine_t aengine;
+    error::wrap_c_api(
+        mkldnn_engine_create(&aengine, convert_to_c(akind), index),
+        "could not create an engine");
+    reset(aengine);
+  }
+
+  explicit engine(const mkldnn_engine_t &aengine) : handle(aengine, true) {}
+
+  engine(const handle<mkldnn_primitive_desc_t> &pd) {
+    mkldnn_engine_t engine_q;
+    error::wrap_c_api(
+        mkldnn_primitive_desc_query(
+            pd.get(), mkldnn::convert_to_c(eengine), 0, &engine_q),
+        "could not get engine from primitive_desc");
+    reset(engine_q, true);
+  }
+
+  template <class primitive_desc>
+  static engine query(const primitive_desc &pd) {
+    mkldnn_engine_t engine_q;
+    error::wrap_c_api(
+        mkldnn_primitive_desc_query(
+            pd.get(), mkldnn::convert_to_c(eengine), 0, &engine_q),
+        "could not get engine from primitive_desc");
+
+    return engine(engine_q);
+  }
+
+private:
+  static mkldnn_engine_kind_t convert_to_c(kind akind) {
+    return static_cast<mkldnn_engine_kind_t>(akind);
+  }
+};
+
+/// @}
+
+/// @addtogroup cpp_api_primitives Primitives
+/// @{
+
+/// @addtogroup cpp_api_memory Memory
+/// @{
+
+/// Memory primitive that describes the data.
+struct memory : public primitive {
+private:
+  std::shared_ptr<char> _handle;
+
+public:
+  typedef std::vector<std::remove_extent<mkldnn_dims_t>::type> dims;
+
+  template <typename T>
+  static void validate_dims(std::vector<T> v) {
+    if (v.size() > TENSOR_MAX_DIMS)
+      throw error(mkldnn_invalid_arguments, "invalid dimensions");
+  }
+
+  /// Data type specification. See #mkldnn_data_type_t for a detailed
+  /// description.
+  enum data_type {
+    data_undef = mkldnn_data_type_undef,
+    f32 = mkldnn_f32,
+    s32 = mkldnn_s32,
+    s16 = mkldnn_s16,
+    s8 = mkldnn_s8,
+    u8 = mkldnn_u8,
+  };
+
+  /// Memory format specification. See #mkldnn_memory_format_t
+  /// for a detailed description.
+  enum format {
+    format_undef = mkldnn_format_undef,
+    any = mkldnn_any,
+    blocked = mkldnn_blocked,
+    x = mkldnn_x,
+    nc = mkldnn_nc,
+    nchw = mkldnn_nchw,
+    nhwc = mkldnn_nhwc,
+    chwn = mkldnn_chwn,
+    nChw8c = mkldnn_nChw8c,
+    nChw16c = mkldnn_nChw16c,
+    ncdhw = mkldnn_ncdhw,
+    ndhwc = mkldnn_ndhwc,
+    nCdhw16c = mkldnn_nCdhw16c,
+    oi = mkldnn_oi,
+    io = mkldnn_io,
+    oihw = mkldnn_oihw,
+    ihwo = mkldnn_ihwo,
+    hwio = mkldnn_hwio,
+    oidhw = mkldnn_oidhw,
+    OIdhw16i16o = mkldnn_OIdhw16i16o,
+    OIdhw16o16i = mkldnn_OIdhw16o16i,
+    Oidhw16o = mkldnn_Oidhw16o,
+    Odhwi16o = mkldnn_Odhwi16o,
+    oIhw8i = mkldnn_oIhw8i,
+    oIhw16i = mkldnn_oIhw16i,
+    OIhw8i8o = mkldnn_OIhw8i8o,
+    OIhw16i16o = mkldnn_OIhw16i16o,
+    OIhw8o8i = mkldnn_OIhw8o8i,
+    OIhw16o16i = mkldnn_OIhw16o16i,
+    IOhw16o16i = mkldnn_IOhw16o16i,
+    OIhw8i16o2i = mkldnn_OIhw8i16o2i,
+    OIhw8o16i2o = mkldnn_OIhw8o16i2o,
+    OIhw4i16o4i = mkldnn_OIhw4i16o4i,
+    Oihw8o = mkldnn_Oihw8o,
+    Oihw16o = mkldnn_Oihw16o,
+    Ohwi8o = mkldnn_Ohwi8o,
+    Ohwi16o = mkldnn_Ohwi16o,
+    OhIw16o4i = mkldnn_OhIw16o4i,
+    goihw = mkldnn_goihw,
+    hwigo = mkldnn_hwigo,
+    gOIhw8i8o = mkldnn_gOIhw8i8o,
+    gOIhw16i16o = mkldnn_gOIhw16i16o,
+    gOIhw8i16o2i = mkldnn_gOIhw8i16o2i,
+    gOIhw8o16i2o = mkldnn_gOIhw8o16i2o,
+    gOIhw4i16o4i = mkldnn_gOIhw4i16o4i,
+    gOihw8o = mkldnn_gOihw8o,
+    gOihw16o = mkldnn_gOihw16o,
+    gOhwi8o = mkldnn_gOhwi8o,
+    gOhwi16o = mkldnn_gOhwi16o,
+    Goihw8g = mkldnn_Goihw8g,
+    Goihw16g = mkldnn_Goihw16g,
+    gOIhw8o8i = mkldnn_gOIhw8o8i,
+    gOIhw16o16i = mkldnn_gOIhw16o16i,
+    gIOhw16o16i = mkldnn_gIOhw16o16i,
+    gOhIw16o4i = mkldnn_gOhIw16o4i,
+    goidhw = mkldnn_goidhw,
+    gOIdhw16i16o = mkldnn_gOIdhw16i16o,
+    gOIdhw16o16i = mkldnn_gOIdhw16o16i,
+    gOidhw16o = mkldnn_gOidhw16o,
+    gOdhwi16o = mkldnn_gOdhwi16o,
+    ntc = mkldnn_ntc,
+    tnc = mkldnn_tnc,
+    ldsnc = mkldnn_ldsnc,
+    ldigo = mkldnn_ldigo,
+    ldigo_p = mkldnn_ldigo_p,
+    ldgoi = mkldnn_ldgoi,
+    ldgoi_p = mkldnn_ldgoi_p,
+    ldgo = mkldnn_ldgo,
+    wino_fmt = mkldnn_wino_fmt,
+    format_last = mkldnn_format_last,
+  };
+
+  /// A memory descriptor.
+  struct desc {
+    friend struct memory;
+    /// The underlying C API data structure.
+    mkldnn_memory_desc_t data;
+
+    /// Constructs a memory descriptor.
+    ///
+    /// @param adims Data dimensions
+    /// @param adata_type Data precision/type.
+    /// @param aformat Data layout format.
+    desc(dims adims, data_type adata_type, format aformat) {
+      validate_dims(adims);
+      error::wrap_c_api(
+          mkldnn_memory_desc_init(&data,
+                                  (int)adims.size(),
+                                  adims.size() == 0 ? nullptr : &adims[0],
+                                  convert_to_c(adata_type),
+                                  convert_to_c(aformat)),
+          "could not initialize a memory descriptor");
+    }
+
+    /// Constructs a memory descriptor from a C API data structure.
+    ///
+    /// @param adata A C API #mkldnn_memory_desc_t structure.
+    desc(const mkldnn_memory_desc_t &adata) : data(adata) {}
+  };
+
+  /// A memory primitive descriptor.
+  struct primitive_desc : public handle<mkldnn_primitive_desc_t> {
+    friend struct memory;
+
+    // TODO: make private
+    primitive_desc() {}
+
+    /// Constructs a memory primitive descriptor.
+    primitive_desc(const desc &adesc, const engine &aengine) {
+      mkldnn_primitive_desc_t result;
+      error::wrap_c_api(mkldnn_memory_primitive_desc_create(
+                            &result, &adesc.data, aengine.get()),
+                        "could not initialize a memory primitive descriptor");
+      reset(result);
+    }
+
+    /// Returns the memory primitive descriptor.
+    memory::desc desc() {
+      auto memory_d = mkldnn_primitive_desc_query_memory_d(get());
+      return memory::desc(*memory_d);
+    }
+
+    /// Returns the number of bytes required to allocate the memory described
+    /// including the padding area.
+    size_t get_size() const {
+      return mkldnn_memory_primitive_desc_get_size(get());
+    }
+
+    bool operator==(const primitive_desc &other) const {
+      return mkldnn_memory_primitive_desc_equal(get(), other.get());
+    }
+
+    bool operator!=(const primitive_desc &other) const {
+      return !operator==(other);
+    }
+
+    engine get_engine() { return engine::query(*this); }
+  };
+
+  /// Constructs a memory primitive from a generic primitive.
+  ///
+  /// @param aprimitive The primitive to treat as memory.
+  memory(const primitive &aprimitive) : primitive(aprimitive) {}
+  /// Constructs a memory primitive.
+  ///
+  /// @param adesc Memory primitive descriptor.
+  memory(const primitive_desc &adesc) {
+    mkldnn_primitive_t result;
+    error::wrap_c_api(
+        mkldnn_primitive_create(&result, adesc.get(), nullptr, nullptr),
+        "could not create a memory primitive");
+    reset(result);
+    auto _malloc = [](size_t size, int alignment) {
+      void *ptr;
+#ifdef _WIN32
+      ptr = _aligned_malloc(size, alignment);
+      int rc = ((ptr) ? 0 : errno);
+#else
+      int rc = ::posix_memalign(&ptr, alignment, size);
+#endif /* _WIN32 */
+      return (rc == 0) ? (char *)ptr : nullptr;
+    };
+    auto _free = [](char *p) {
+#ifdef _WIN32
+      _aligned_free((void *)p);
+#else
+      ::free((void *)p);
+#endif /* _WIN32 */
+    };
+    _handle.reset(_malloc(adesc.get_size(), 4096), _free);
+    set_data_handle(_handle.get());
+  }
+
+  memory(const primitive_desc &adesc, void *ahandle) {
+    mkldnn_primitive_t result;
+    error::wrap_c_api(
+        mkldnn_primitive_create(&result, adesc.get(), nullptr, nullptr),
+        "could not create a memory primitive");
+    reset(result);
+    set_data_handle(ahandle);
+  }
+
+  /// Returns the descriptor of the memory primitive.
+  primitive_desc get_primitive_desc() const {
+    primitive_desc adesc;
+    const_mkldnn_primitive_desc_t cdesc;
+    error::wrap_c_api(
+        mkldnn_primitive_get_primitive_desc(get(), &cdesc),
+        "could not get primitive descriptor from a memory primitive");
+    /* FIXME: no const_cast should be here */
+    adesc.reset(const_cast<mkldnn_primitive_desc_t>(cdesc), true);
+    return adesc;
+  }
+
+  /// Returns a handle of the data contained in the memory primitive. On
+  /// the CPU engine, this is a pointer to the allocated memory.
+  inline void *get_data_handle() const {
+    void *handle;
+    error::wrap_c_api(mkldnn_memory_get_data_handle(get(), &handle),
+                      "could not get native handle");
+    return handle;
+  }
+
+  inline void set_data_handle(void *handle) const {
+    error::wrap_c_api(mkldnn_memory_set_data_handle(get(), handle),
+                      "could not set native handle");
+  }
+
+  // Must go away or be private:
+  static mkldnn_data_type_t convert_to_c(data_type adata_type) {
+    return static_cast<mkldnn_data_type_t>(adata_type);
+  }
+  static mkldnn_memory_format_t convert_to_c(format aformat) {
+    return static_cast<mkldnn_memory_format_t>(aformat);
+  }
+};
+
+inline memory::desc zero_md() {
+  // Zero-initialize the whole C descriptor so fields such as `ndims` are 0;
+  // is_null_memory() below relies on ndims == 0 to recognize the null marker,
+  // so leaving the aggregate uninitialized would read indeterminate data.
+  mkldnn_memory_desc_t zero = {};
+  zero.primitive_kind = mkldnn_memory;
+  return memory::desc(zero);
+}
+
+inline memory null_memory(engine eng) {
+  mkldnn::memory::desc zero = zero_md();
+  return memory({zero, eng}, nullptr);
+}
+
+inline bool is_null_memory(const const_mkldnn_primitive_t &aprimitive) {
+  const_mkldnn_primitive_desc_t aprimitive_pd;
+  mkldnn_primitive_get_primitive_desc(aprimitive, &aprimitive_pd);
+  const mkldnn_memory_desc_t *aprimitive_md =
+      mkldnn_primitive_desc_query_memory_d(aprimitive_pd);
+
+  return ((aprimitive_md != nullptr) && (aprimitive_md->ndims == 0));
+}
+
+// Convenience comparisons between the raw C enum types and their C++ wrapper
+// enums; each pair compares after converting the C++ value to its C
+// counterpart via memory::convert_to_c.
+inline bool operator==(mkldnn_data_type_t a, memory::data_type b) {
+  return a == memory::convert_to_c(b);
+}
+inline bool operator!=(mkldnn_data_type_t a, memory::data_type b) {
+  return !(a == b);
+}
+inline bool operator==(memory::data_type a, mkldnn_data_type_t b) {
+  return b == a;
+}
+inline bool operator!=(memory::data_type a, mkldnn_data_type_t b) {
+  return !(a == b);
+}
+
+inline bool operator==(mkldnn_memory_format_t a, memory::format b) {
+  return a == memory::convert_to_c(b);
+}
+inline bool operator!=(mkldnn_memory_format_t a, memory::format b) {
+  return !(a == b);
+}
+inline bool operator==(memory::format a, mkldnn_memory_format_t b) {
+  return b == a;
+}
+inline bool operator!=(memory::format a, mkldnn_memory_format_t b) {
+  return !(a == b);
+}
+
+/// @}
+
+/// @addtogroup cpp_api_reorder Reorder
+/// @{
+
+struct reorder : public primitive {
+  struct primitive_desc : public handle<mkldnn_primitive_desc_t> {
+    primitive_desc(const memory::primitive_desc &input,
+                   const memory::primitive_desc &output) {
+      mkldnn_primitive_desc_t result;
+      error::wrap_c_api(mkldnn_reorder_primitive_desc_create(
+                            &result, input.get(), output.get()),
+                        "could not create a reorder primitive descriptor");
+      reset(result);
+    }
+
+    primitive_desc(const memory::primitive_desc &input,
+                   const memory::primitive_desc &output,
+                   const primitive_attr &aattr) {
+      mkldnn_primitive_desc_t result;
+      error::wrap_c_api(mkldnn_reorder_primitive_desc_create_v2(
+                            &result, input.get(), output.get(), aattr.get()),
+                        "could not create a reorder primitive descriptor");
+      reset(result);
+    }
+
+    engine get_engine() { return engine::query(*this); }
+  };
+
+  reorder(const primitive_desc &aprimitive_desc,
+          const primitive::at &input,
+          const memory &output) {
+    mkldnn_primitive_t result;
+    mkldnn_primitive_at_t inputs[] = {input.data};
+    const_mkldnn_primitive_t outputs[] = {output.get()};
+    error::wrap_c_api(mkldnn_primitive_create(
+                          &result, aprimitive_desc.get(), inputs, outputs),
+                      "could not create a reorder primitive");
+    reset(result);
+  }
+
+  reorder(const primitive::at &input, const memory &output) {
+    auto input_mpd = memory(input).get_primitive_desc();
+    auto output_mpd = output.get_primitive_desc();
+
+    auto reorder_d = primitive_desc(input_mpd, output_mpd);
+
+    mkldnn_primitive_t result;
+    mkldnn_primitive_at_t inputs[] = {input.data};
+    const_mkldnn_primitive_t outputs[] = {output.get()};
+    error::wrap_c_api(
+        mkldnn_primitive_create(&result, reorder_d.get(), inputs, outputs),
+        "could not create a reorder primitive");
+    reset(result);
+  }
+};
+
+/// @}
+
+/// @addtogroup cpp_api_view View
+/// @{
+
+struct view : public primitive {
+  struct primitive_desc : public handle<mkldnn_primitive_desc_t> {
+    primitive_desc(const memory::primitive_desc &input,
+                   memory::dims dims,
+                   memory::dims offsets) {
+      mkldnn_primitive_desc_t result;
+
+      error::wrap_c_api(mkldnn_view_primitive_desc_create(
+                            &result, input.get(), &dims[0], &offsets[0]),
+                        "could not create a view primitive descriptor");
+      reset(result);
+    }
+
+    memory::primitive_desc dst_primitive_desc() const {
+      memory::primitive_desc adesc;
+      mkldnn_primitive_desc_t cdesc;
+      const_mkldnn_primitive_desc_t const_cdesc =
+          mkldnn_primitive_desc_query_pd(
+              get(), mkldnn::convert_to_c(dst_pd), 0);
+      error::wrap_c_api(mkldnn_primitive_desc_clone(&cdesc, const_cdesc),
+                        "could not clone a dst primitive descriptor");
+      adesc.reset(cdesc);
+      return adesc;
+    }
+
+    engine get_engine() { return engine::query(*this); }
+  };
+
+  view(const primitive_desc &view_pd, primitive::at input) {
+    mkldnn_primitive_t result;
+    mkldnn_primitive_at_t inputs[] = {input.data};
+    error::wrap_c_api(
+        mkldnn_primitive_create(&result, view_pd.get(), inputs, nullptr),
+        "could not create a view primitive");
+    reset(result);
+  }
+
+  view(memory input, memory::dims dims, memory::dims offsets) {
+    mkldnn_primitive_t result;
+    primitive_desc view_pd(input.get_primitive_desc(), dims, offsets);
+    mkldnn_primitive_at_t inputs[] = {primitive::at(input).data};
+    error::wrap_c_api(
+        mkldnn_primitive_create(&result, view_pd.get(), inputs, nullptr),
+        "could not create a view primitive");
+    reset(result);
+  }
+};
+
+/// @}
+
+/// @addtogroup cpp_api_concat Concat
+/// @{
+
+struct concat : public primitive {
+ struct primitive_desc : public handle {
+ std::vector cpp_to_c(
+ std::vector inputs) {
+ std::vector c_api_inputs;
+ c_api_inputs.reserve(inputs.size());
+ auto convert_to_c = [](memory::primitive_desc d) { return d.get(); };
+ std::transform(inputs.begin(),
+ inputs.end(),
+ std::back_inserter(c_api_inputs),
+ convert_to_c);
+ return c_api_inputs;
+ }
+
+ primitive_desc(const memory::desc &output,
+ int concat_dimension,
+ std::vector inputs) {
+ mkldnn_primitive_desc_t result;
+
+ auto c_api_inputs = cpp_to_c(inputs);
+
+ error::wrap_c_api(
+ mkldnn_concat_primitive_desc_create(&result,
+ &output.data,
+ (int)c_api_inputs.size(),
+ concat_dimension,
+ &c_api_inputs[0]),
+ "could not create a concat primitive descriptor");
+ reset(result);
+ }
+
+ primitive_desc(int concat_dimension,
+ std::vector inputs) {
+ mkldnn_primitive_desc_t result;
+
+ auto c_api_inputs = cpp_to_c(inputs);
+
+ error::wrap_c_api(
+ mkldnn_concat_primitive_desc_create(&result,
+ nullptr,
+ (int)c_api_inputs.size(),
+ concat_dimension,
+ &c_api_inputs[0]),
+ "could not create a concat primitive descriptor");
+ reset(result);
+ }
+
+ memory::primitive_desc dst_primitive_desc() const {
+ memory::primitive_desc adesc;
+ mkldnn_primitive_desc_t cdesc;
+ const_mkldnn_primitive_desc_t const_cdesc =
+ mkldnn_primitive_desc_query_pd(
+ get(), mkldnn::convert_to_c(dst_pd), 0);
+ error::wrap_c_api(mkldnn_primitive_desc_clone(&cdesc, const_cdesc),
+ "could not clone a dst primitive descriptor");
+ adesc.reset(cdesc);
+ return adesc;
+ }
+
+ engine get_engine() { return engine::query(*this); }
+ };
+
+ concat(const primitive_desc &concat_pd,
+ std::vector &inputs,
+ const memory &output) {
+ mkldnn_primitive_t result;
+
+ std::vector p_inputs;
+ for (size_t i = 0; i < inputs.size(); i++)
+ p_inputs.push_back(inputs[i].data);
+ const_mkldnn_primitive_t outputs[] = {output.get()};
+
+ error::wrap_c_api(mkldnn_primitive_create(
+ &result, concat_pd.get(), &p_inputs[0], outputs),
+ "could not create a concat primitive");
+ reset(result);
+ }
+};
+
+/// @}
+
+/// @addtogroup cpp_api_sum Sum
+/// @{
+
+struct sum : public primitive {
+ struct primitive_desc : public handle {
+ std::vector cpp_to_c(
+ std::vector inputs) {
+ std::vector c_api_inputs;
+ c_api_inputs.reserve(inputs.size());
+ auto convert_to_c = [](memory::primitive_desc d) { return d.get(); };
+ std::transform(inputs.begin(),
+ inputs.end(),
+ std::back_inserter(c_api_inputs),
+ convert_to_c);
+ return c_api_inputs;
+ }
+
+ primitive_desc(const memory::desc &output,
+ const std::vector &scales,
+ std::vector inputs) {
+ mkldnn_primitive_desc_t result;
+
+ auto c_api_inputs = cpp_to_c(inputs);
+
+ error::wrap_c_api(
+ mkldnn_sum_primitive_desc_create(&result,
+ &output.data,
+ (int)c_api_inputs.size(),
+ &scales[0],
+ &c_api_inputs[0]),
+ "could not create a sum primitive descriptor");
+ reset(result);
+ }
+
+ primitive_desc(const std::vector &scales,
+ std::vector inputs) {
+ mkldnn_primitive_desc_t result;
+
+ auto c_api_inputs = cpp_to_c(inputs);
+
+ error::wrap_c_api(
+ mkldnn_sum_primitive_desc_create(&result,
+ nullptr,
+ (int)c_api_inputs.size(),
+ &scales[0],
+ &c_api_inputs[0]),
+ "could not create a sum primitive descriptor");
+ reset(result);
+ }
+
+ /** @deprecated: api backwards compatibility for double scales type */
+ MKLDNN_DEPRECATED
+ primitive_desc(const memory::desc &output,
+ std::vector scale,
+ std::vector inputs) {
+ mkldnn_primitive_desc_t result;
+
+ auto c_api_inputs = cpp_to_c(inputs);
+ auto scale_f = scale_to_float(scale);
+
+ error::wrap_c_api(
+ mkldnn_sum_primitive_desc_create(&result,
+ &output.data,
+ (int)c_api_inputs.size(),
+ &scale_f[0],
+ &c_api_inputs[0]),
+ "could not create a sum primitive descriptor");
+ reset(result);
+ }
+
+ /** @deprecated: api backwards compatibility for double scales type */
+ MKLDNN_DEPRECATED
+ primitive_desc(std::vector scale,
+ std::vector inputs) {
+ mkldnn_primitive_desc_t result;
+
+ auto c_api_inputs = cpp_to_c(inputs);
+ auto scale_f = scale_to_float(scale);
+
+ error::wrap_c_api(
+ mkldnn_sum_primitive_desc_create(&result,
+ nullptr,
+ (int)c_api_inputs.size(),
+ &scale_f[0],
+ &c_api_inputs[0]),
+ "could not create a sum primitive descriptor");
+ reset(result);
+ }
+
+ memory::primitive_desc dst_primitive_desc() const {
+ memory::primitive_desc adesc;
+ mkldnn_primitive_desc_t cdesc;
+ const_mkldnn_primitive_desc_t const_cdesc =
+ mkldnn_primitive_desc_query_pd(
+ get(), mkldnn::convert_to_c(dst_pd), 0);
+ error::wrap_c_api(mkldnn_primitive_desc_clone(&cdesc, const_cdesc),
+ "could not clone a dst primitive descriptor");
+ adesc.reset(cdesc);
+ return adesc;
+ }
+
+ engine get_engine() { return engine::query(*this); }
+ };
+
+ sum(const primitive_desc &sum_pd,
+ std::vector &inputs,
+ const memory &output) {
+ mkldnn_primitive_t result;
+
+ std::vector p_inputs;
+ for (size_t i = 0; i < inputs.size(); i++)
+ p_inputs.push_back(inputs[i].data);
+ const_mkldnn_primitive_t outputs[] = {output.get()};
+
+ error::wrap_c_api(
+ mkldnn_primitive_create(&result, sum_pd.get(), &p_inputs[0], outputs),
+ "could not create a sum primitive");
+ reset(result);
+ }
+
+private:
+ static std::vector scale_to_float(const std::vector &vd) {
+ std::vector vf(vd.size());
+ std::transform(
+ vd.begin(), vd.end(), vf.begin(), [=](double x) { return (float)x; });
+ return vf;
+ }
+};
+
+/// @}
+
+/// @addtogroup cpp_api_convolution Convolution
+/// @{
+
+struct convolution_forward : public primitive {
+ struct desc {
+ mkldnn_convolution_desc_t data;
+ desc(prop_kind aprop_kind,
+ algorithm aalgorithm,
+ const memory::desc &src_desc,
+ const memory::desc &weights_desc,
+ const memory::desc &bias_desc,
+ const memory::desc &dst_desc,
+ const memory::dims strides,
+ const memory::dims padding_l,
+ const memory::dims padding_r,
+ const padding_kind apadding_kind) {
+ memory::validate_dims(strides);
+ memory::validate_dims(padding_l);
+ memory::validate_dims(padding_r);
+ error::wrap_c_api(mkldnn_convolution_forward_desc_init(
+ &data,
+ mkldnn::convert_to_c(aprop_kind),
+ convert_to_c(aalgorithm),
+ &src_desc.data,
+ &weights_desc.data,
+ &bias_desc.data,
+ &dst_desc.data,
+ &strides[0],
+ &padding_l[0],
+ &padding_r[0],
+ mkldnn::convert_to_c(apadding_kind)),
+ "could not create a convolution forward descriptor");
+ }
+ desc(prop_kind aprop_kind,
+ algorithm aalgorithm,
+ const memory::desc &src_desc,
+ const memory::desc &weights_desc,
+ const memory::desc &dst_desc,
+ const memory::dims strides,
+ const memory::dims padding_l,
+ const memory::dims padding_r,
+ const padding_kind apadding_kind) {
+ memory::validate_dims(strides);
+ memory::validate_dims(padding_l);
+ memory::validate_dims(padding_r);
+ error::wrap_c_api(mkldnn_convolution_forward_desc_init(
+ &data,
+ mkldnn::convert_to_c(aprop_kind),
+ convert_to_c(aalgorithm),
+ &src_desc.data,
+ &weights_desc.data,
+ nullptr,
+ &dst_desc.data,
+ &strides[0],
+ &padding_l[0],
+ &padding_r[0],
+ mkldnn::convert_to_c(apadding_kind)),
+ "could not create a convolution forward descriptor");
+ }
+ desc(prop_kind aprop_kind,
+ algorithm aalgorithm,
+ const memory::desc &src_desc,
+ const memory::desc &weights_desc,
+ const memory::desc &bias_desc,
+ const memory::desc &dst_desc,
+ const memory::dims strides,
+ const memory::dims dilates,
+ const memory::dims padding_l,
+ const memory::dims padding_r,
+ const padding_kind apadding_kind) {
+ memory::validate_dims(strides);
+ memory::validate_dims(dilates);
+ memory::validate_dims(padding_l);
+ memory::validate_dims(padding_r);
+ error::wrap_c_api(
+ mkldnn_dilated_convolution_forward_desc_init(
+ &data,
+ mkldnn::convert_to_c(aprop_kind),
+ convert_to_c(aalgorithm),
+ &src_desc.data,
+ &weights_desc.data,
+ &bias_desc.data,
+ &dst_desc.data,
+ &strides[0],
+ &dilates[0],
+ &padding_l[0],
+ &padding_r[0],
+ mkldnn::convert_to_c(apadding_kind)),
+ "could not create a dilated convolution forward descriptor");
+ }
+ desc(prop_kind aprop_kind,
+ algorithm aalgorithm,
+ const memory::desc &src_desc,
+ const memory::desc &weights_desc,
+ const memory::desc &dst_desc,
+ const memory::dims strides,
+ const memory::dims dilates,
+ const memory::dims padding_l,
+ const memory::dims padding_r,
+ const padding_kind apadding_kind) {
+ memory::validate_dims(strides);
+ memory::validate_dims(dilates);
+ memory::validate_dims(padding_l);
+ memory::validate_dims(padding_r);
+ error::wrap_c_api(
+ mkldnn_dilated_convolution_forward_desc_init(
+ &data,
+ mkldnn::convert_to_c(aprop_kind),
+ convert_to_c(aalgorithm),
+ &src_desc.data,
+ &weights_desc.data,
+ nullptr,
+ &dst_desc.data,
+ &strides[0],
+ &dilates[0],
+ &padding_l[0],
+ &padding_r[0],
+ mkldnn::convert_to_c(apadding_kind)),
+ "could not create a dilated convolution forward descriptor");
+ }
+ };
+ struct primitive_desc : public handle {
+ primitive_desc(const desc &adesc, const engine &aengine) {
+ mkldnn_primitive_desc_t result;
+ error::wrap_c_api(
+ mkldnn_primitive_desc_create(
+ &result, &adesc.data, aengine.get(), nullptr),
+ "could not create a convolution forward primitive descriptor");
+ reset(result);
+ }
+
+ primitive_desc(const desc &adesc,
+ const primitive_attr &aattr,
+ const engine &aengine) {
+ mkldnn_primitive_desc_t result;
+ error::wrap_c_api(
+ mkldnn_primitive_desc_create_v2(
+ &result, &adesc.data, aattr.get(), aengine.get(), nullptr),
+ "could not create a convolution forward primitive descriptor");
+ reset(result);
+ }
+
+ memory::primitive_desc src_primitive_desc() const {
+ memory::primitive_desc adesc;
+ mkldnn_primitive_desc_t cdesc;
+ const_mkldnn_primitive_desc_t const_cdesc =
+ mkldnn_primitive_desc_query_pd(
+ get(), mkldnn::convert_to_c(src_pd), 0);
+ error::wrap_c_api(mkldnn_primitive_desc_clone(&cdesc, const_cdesc),
+ "could not clone a src primititve descriptor");
+ adesc.reset(cdesc);
+ return adesc;
+ }
+
+ memory::primitive_desc weights_primitive_desc() const {
+ memory::primitive_desc adesc;
+ mkldnn_primitive_desc_t cdesc;
+ const_mkldnn_primitive_desc_t const_cdesc =
+ mkldnn_primitive_desc_query_pd(
+ get(), mkldnn::convert_to_c(weights_pd), 0);
+ error::wrap_c_api(mkldnn_primitive_desc_clone(&cdesc, const_cdesc),
+ "could not clone a weights primitive descriptor");
+ adesc.reset(cdesc);
+ return adesc;
+ }
+
+ memory::primitive_desc bias_primitive_desc() const {
+ memory::primitive_desc adesc;
+ mkldnn_primitive_desc_t cdesc;
+ const_mkldnn_primitive_desc_t const_cdesc =
+ mkldnn_primitive_desc_query_pd(
+ get(), mkldnn::convert_to_c(weights_pd), 1);
+ error::wrap_c_api(mkldnn_primitive_desc_clone(&cdesc, const_cdesc),
+ "could not clone a bias primitive descriptor");
+ adesc.reset(cdesc);
+ return adesc;
+ }
+
+ memory::primitive_desc dst_primitive_desc() const {
+ memory::primitive_desc adesc;
+ mkldnn_primitive_desc_t cdesc;
+ const_mkldnn_primitive_desc_t const_cdesc =
+ mkldnn_primitive_desc_query_pd(
+ get(), mkldnn::convert_to_c(dst_pd), 0);
+ error::wrap_c_api(mkldnn_primitive_desc_clone(&cdesc, const_cdesc),
+ "could not clone a dst primitive descriptor");
+ adesc.reset(cdesc);
+ return adesc;
+ }
+
+ engine get_engine() { return engine::query(*this); }
+ };
+
+ convolution_forward(const primitive_desc &aprimitive_desc,
+ const primitive::at &src,
+ const primitive::at &weights,
+ const primitive::at &bias,
+ const memory &dst) {
+ mkldnn_primitive_t result;
+ mkldnn_primitive_at_t inputs[] = {src.data, weights.data, bias.data};
+ const_mkldnn_primitive_t outputs[] = {dst.get()};
+ error::wrap_c_api(mkldnn_primitive_create(
+ &result, aprimitive_desc.get(), inputs, outputs),
+ "could not create a convolution forward bias primitive");
+ reset(result);
+ }
+
+ convolution_forward(const primitive_desc &aprimitive_desc,
+ const primitive::at &src,
+ const primitive::at &weights,
+ const memory &dst) {
+ mkldnn_primitive_t result;
+ mkldnn_primitive_at_t inputs[] = {src.data, weights.data};
+ const_mkldnn_primitive_t outputs[] = {dst.get()};
+ error::wrap_c_api(mkldnn_primitive_create(
+ &result, aprimitive_desc.get(), inputs, outputs),
+ "could not create a convolution forward primitive");
+ reset(result);
+ }
+};
+
+struct convolution_backward_data : public primitive {
+ struct desc {
+ mkldnn_convolution_desc_t data;
+ desc(algorithm aalgorithm,
+ const memory::desc &diff_src_desc,
+ const memory::desc &weights_desc,
+ const memory::desc &diff_dst_desc,
+ const memory::dims strides,
+ const memory::dims padding_l,
+ const memory::dims padding_r,
+ const padding_kind apadding_kind) {
+ memory::validate_dims(strides);
+ memory::validate_dims(padding_l);
+ memory::validate_dims(padding_r);
+ error::wrap_c_api(
+ mkldnn_convolution_backward_data_desc_init(
+ &data,
+ convert_to_c(aalgorithm),
+ &diff_src_desc.data,
+ &weights_desc.data,
+ &diff_dst_desc.data,
+ &strides[0],
+ &padding_l[0],
+ &padding_r[0],
+ mkldnn::convert_to_c(apadding_kind)),
+ "could not create a convolution backward data descriptor");
+ }
+ desc(algorithm aalgorithm,
+ const memory::desc &diff_src_desc,
+ const memory::desc &weights_desc,
+ const memory::desc &diff_dst_desc,
+ const memory::dims strides,
+ const memory::dims dilates,
+ const memory::dims padding_l,
+ const memory::dims padding_r,
+ const padding_kind apadding_kind) {
+ memory::validate_dims(strides);
+ memory::validate_dims(dilates);
+ memory::validate_dims(padding_l);
+ memory::validate_dims(padding_r);
+ error::wrap_c_api(
+ mkldnn_dilated_convolution_backward_data_desc_init(
+ &data,
+ convert_to_c(aalgorithm),
+ &diff_src_desc.data,
+ &weights_desc.data,
+ &diff_dst_desc.data,
+ &strides[0],
+ &dilates[0],
+ &padding_l[0],
+ &padding_r[0],
+ mkldnn::convert_to_c(apadding_kind)),
+ "could not create a convolution backward data descriptor");
+ }
+ };
+ struct primitive_desc : public handle {
+ primitive_desc(
+ const desc &adesc,
+ const engine &aengine,
+ const convolution_forward::primitive_desc &hint_fwd_primitive_desc) {
+ mkldnn_primitive_desc_t result;
+ error::wrap_c_api(
+ mkldnn_primitive_desc_create(&result,
+ &adesc.data,
+ aengine.get(),
+ hint_fwd_primitive_desc.get()),
+ "could not create a convolution backward data primitive descriptor");
+ reset(result);
+ }
+ memory::primitive_desc diff_src_primitive_desc() const {
+ memory::primitive_desc adesc;
+ mkldnn_primitive_desc_t cdesc;
+ const_mkldnn_primitive_desc_t const_cdesc =
+ mkldnn_primitive_desc_query_pd(
+ get(), mkldnn::convert_to_c(diff_src_pd), 0);
+ error::wrap_c_api(mkldnn_primitive_desc_clone(&cdesc, const_cdesc),
+ "could not clone a diff_src primititve descriptor");
+ adesc.reset(cdesc);
+ return adesc;
+ }
+
+ memory::primitive_desc weights_primitive_desc() const {
+ memory::primitive_desc adesc;
+ mkldnn_primitive_desc_t cdesc;
+ const_mkldnn_primitive_desc_t const_cdesc =
+ mkldnn_primitive_desc_query_pd(
+ get(), mkldnn::convert_to_c(weights_pd), 0);
+ error::wrap_c_api(mkldnn_primitive_desc_clone(&cdesc, const_cdesc),
+ "could not clone a weights primitive descriptor");
+ adesc.reset(cdesc);
+ return adesc;
+ }
+
+ memory::primitive_desc diff_dst_primitive_desc() const {
+ memory::primitive_desc adesc;
+ mkldnn_primitive_desc_t cdesc;
+ const_mkldnn_primitive_desc_t const_cdesc =
+ mkldnn_primitive_desc_query_pd(
+ get(), mkldnn::convert_to_c(diff_dst_pd), 0);
+ error::wrap_c_api(mkldnn_primitive_desc_clone(&cdesc, const_cdesc),
+ "could not clone a diff_dst primitive descriptor");
+ adesc.reset(cdesc);
+ return adesc;
+ }
+
+ engine get_engine() { return engine::query(*this); }
+ };
+
+ convolution_backward_data(const primitive_desc &aprimitive_desc,
+ const primitive::at &diff_dst,
+ const primitive::at &weights,
+ const memory &diff_src) {
+ mkldnn_primitive_t result;
+ mkldnn_primitive_at_t inputs[] = {diff_dst.data, weights.data};
+ const_mkldnn_primitive_t outputs[] = {diff_src.get()};
+ error::wrap_c_api(mkldnn_primitive_create(
+ &result, aprimitive_desc.get(), inputs, outputs),
+ "could not create a convolution backward data primitive");
+ reset(result);
+ }
+};
+
+struct convolution_backward_weights : public primitive {
+ struct desc {
+ mkldnn_convolution_desc_t data;
+ desc(algorithm aalgorithm,
+ const memory::desc &src_desc,
+ const memory::desc &diff_weights_desc,
+ const memory::desc &diff_bias_desc,
+ const memory::desc &diff_dst_desc,
+ const memory::dims strides,
+ const memory::dims padding_l,
+ const memory::dims padding_r,
+ const padding_kind apadding_kind) {
+ memory::validate_dims(strides);
+ memory::validate_dims(padding_l);
+ memory::validate_dims(padding_r);
+ error::wrap_c_api(
+ mkldnn_convolution_backward_weights_desc_init(
+ &data,
+ convert_to_c(aalgorithm),
+ &src_desc.data,
+ &diff_weights_desc.data,
+ &diff_bias_desc.data,
+ &diff_dst_desc.data,
+ &strides[0],
+ &padding_l[0],
+ &padding_r[0],
+ mkldnn::convert_to_c(apadding_kind)),
+ "could not create a convolution backward weights descriptor");
+ }
+ desc(algorithm aalgorithm,
+ const memory::desc &src_desc,
+ const memory::desc &diff_weights_desc,
+ const memory::desc &diff_dst_desc,
+ const memory::dims strides,
+ const memory::dims padding_l,
+ const memory::dims padding_r,
+ const padding_kind apadding_kind) {
+ memory::validate_dims(strides);
+ memory::validate_dims(padding_l);
+ memory::validate_dims(padding_r);
+ error::wrap_c_api(
+ mkldnn_convolution_backward_weights_desc_init(
+ &data,
+ convert_to_c(aalgorithm),
+ &src_desc.data,
+ &diff_weights_desc.data,
+ nullptr,
+ &diff_dst_desc.data,
+ &strides[0],
+ &padding_l[0],
+ &padding_r[0],
+ mkldnn::convert_to_c(apadding_kind)),
+ "could not create a convolution backward weights descriptor");
+ }
+ desc(algorithm aalgorithm,
+ const memory::desc &src_desc,
+ const memory::desc &diff_weights_desc,
+ const memory::desc &diff_bias_desc,
+ const memory::desc &diff_dst_desc,
+ const memory::dims strides,
+ const memory::dims dilates,
+ const memory::dims padding_l,
+ const memory::dims padding_r,
+ const padding_kind apadding_kind) {
+ memory::validate_dims(strides);
+ memory::validate_dims(dilates);
+ memory::validate_dims(padding_l);
+ memory::validate_dims(padding_r);
+ error::wrap_c_api(
+ mkldnn_dilated_convolution_backward_weights_desc_init(
+ &data,
+ convert_to_c(aalgorithm),
+ &src_desc.data,
+ &diff_weights_desc.data,
+ &diff_bias_desc.data,
+ &diff_dst_desc.data,
+ &strides[0],
+ &dilates[0],
+ &padding_l[0],
+ &padding_r[0],
+ mkldnn::convert_to_c(apadding_kind)),
+ "could not create a convolution backward weights descriptor");
+ }
+ desc(algorithm aalgorithm,
+ const memory::desc &src_desc,
+ const memory::desc &diff_weights_desc,
+ const memory::desc &diff_dst_desc,
+ const memory::dims strides,
+ const memory::dims dilates,
+ const memory::dims padding_l,
+ const memory::dims padding_r,
+ const padding_kind apadding_kind) {
+ memory::validate_dims(strides);
+ memory::validate_dims(dilates);
+ memory::validate_dims(padding_l);
+ memory::validate_dims(padding_r);
+ error::wrap_c_api(
+ mkldnn_dilated_convolution_backward_weights_desc_init(
+ &data,
+ convert_to_c(aalgorithm),
+ &src_desc.data,
+ &diff_weights_desc.data,
+ nullptr,
+ &diff_dst_desc.data,
+ &strides[0],
+ &dilates[0],
+ &padding_l[0],
+ &padding_r[0],
+ mkldnn::convert_to_c(apadding_kind)),
+ "could not create a convolution backward weights descriptor");
+ }
+ };
+
+ struct primitive_desc : public handle {
+ primitive_desc(
+ const desc &adesc,
+ const engine &aengine,
+ const convolution_forward::primitive_desc &hint_fwd_primitive_desc) {
+ mkldnn_primitive_desc_t result;
+ error::wrap_c_api(
+ mkldnn_primitive_desc_create(&result,
+ &adesc.data,
+ aengine.get(),
+ hint_fwd_primitive_desc.get()),
+ "could not create a convolution backward weights primitive "
+ "descriptor");
+ reset(result);
+ }
+ memory::primitive_desc src_primitive_desc() const {
+ memory::primitive_desc adesc;
+ mkldnn_primitive_desc_t cdesc;
+ const_mkldnn_primitive_desc_t const_cdesc =
+ mkldnn_primitive_desc_query_pd(
+ get(), mkldnn::convert_to_c(src_pd), 0);
+ error::wrap_c_api(mkldnn_primitive_desc_clone(&cdesc, const_cdesc),
+ "could not clone a src primititve descriptor");
+ adesc.reset(cdesc);
+ return adesc;
+ }
+
+ memory::primitive_desc diff_weights_primitive_desc() const {
+ memory::primitive_desc adesc;
+ mkldnn_primitive_desc_t cdesc;
+ const_mkldnn_primitive_desc_t const_cdesc =
+ mkldnn_primitive_desc_query_pd(
+ get(), mkldnn::convert_to_c(diff_weights_pd), 0);
+ error::wrap_c_api(mkldnn_primitive_desc_clone(&cdesc, const_cdesc),
+ "could not clone a diff_weights primitive descriptor");
+ adesc.reset(cdesc);
+ return adesc;
+ }
+
+ memory::primitive_desc diff_bias_primitive_desc() const {
+ memory::primitive_desc adesc;
+ mkldnn_primitive_desc_t cdesc;
+ const_mkldnn_primitive_desc_t const_cdesc =
+ mkldnn_primitive_desc_query_pd(
+ get(), mkldnn::convert_to_c(diff_weights_pd), 1);
+ error::wrap_c_api(mkldnn_primitive_desc_clone(&cdesc, const_cdesc),
+ "could not clone a diff_bias primitive descriptor");
+ adesc.reset(cdesc);
+ return adesc;
+ }
+
+ memory::primitive_desc diff_dst_primitive_desc() const {
+ memory::primitive_desc adesc;
+ mkldnn_primitive_desc_t cdesc;
+ const_mkldnn_primitive_desc_t const_cdesc =
+ mkldnn_primitive_desc_query_pd(
+ get(), mkldnn::convert_to_c(diff_dst_pd), 0);
+ error::wrap_c_api(mkldnn_primitive_desc_clone(&cdesc, const_cdesc),
+ "could not clone a diff_dst primitive descriptor");
+ adesc.reset(cdesc);
+ return adesc;
+ }
+
+ engine get_engine() { return engine::query(*this); }
+ };
+
+ convolution_backward_weights(const primitive_desc &aprimitive_desc,
+ const primitive::at &src,
+ const primitive::at &diff_dst,
+ const memory &diff_weights,
+ const memory &diff_bias) {
+ mkldnn_primitive_t result;
+ mkldnn_primitive_at_t inputs[] = {src.data, diff_dst.data};
+ const_mkldnn_primitive_t outputs[] = {diff_weights.get(), diff_bias.get()};
+ error::wrap_c_api(
+ mkldnn_primitive_create(
+ &result, aprimitive_desc.get(), inputs, outputs),
+ "could not create a convolution backward weights primitive");
+ reset(result);
+ }
+ convolution_backward_weights(const primitive_desc &aprimitive_desc,
+ const primitive::at &src,
+ const primitive::at &diff_dst,
+ const memory &diff_weights) {
+ mkldnn_primitive_t result;
+ mkldnn_primitive_at_t inputs[] = {src.data, diff_dst.data};
+ const_mkldnn_primitive_t outputs[] = {diff_weights.get()};
+ error::wrap_c_api(
+ mkldnn_primitive_create(
+ &result, aprimitive_desc.get(), inputs, outputs),
+ "could not create a convolution backward weights primitive");
+ reset(result);
+ }
+};
+
+struct convolution_relu_forward : public primitive {
+ struct desc {
+ mkldnn_convolution_relu_desc_t data;
+ desc(const convolution_forward::desc conv_desc,
+ const float negative_slope) {
+ error::wrap_c_api(
+ mkldnn_convolution_relu_desc_init(
+ &data, &conv_desc.data, negative_slope),
+ "could not create a convolution_relu_forward descriptor");
+ }
+ };
+
+ struct primitive_desc : public handle {
+ primitive_desc(const desc &adesc, const engine &aengine) {
+ mkldnn_primitive_desc_t result;
+ error::wrap_c_api(
+ mkldnn_primitive_desc_create(
+ &result, &adesc.data, aengine.get(), nullptr),
+ "could not create a convolution relu forward descriptor");
+ reset(result);
+ }
+
+ engine get_engine() { return engine::query(*this); }
+ };
+
+ convolution_relu_forward(const primitive_desc &aprimitive_desc,
+ const primitive::at &src,
+ const primitive::at &weights,
+ const primitive::at &bias,
+ const memory &dst) {
+ mkldnn_primitive_t result;
+ mkldnn_primitive_at_t inputs[] = {src.data, weights.data, bias.data};
+ const_mkldnn_primitive_t outputs[] = {dst.get()};
+ error::wrap_c_api(mkldnn_primitive_create(
+ &result, aprimitive_desc.get(), inputs, outputs),
+ "could not create a convolution relu forward primitive");
+ reset(result);
+ }
+
+ convolution_relu_forward(const primitive_desc &aprimitive_desc,
+ const primitive::at &src,
+ const primitive::at &weights,
+ const memory &dst) {
+ mkldnn_primitive_t result;
+ mkldnn_primitive_at_t inputs[] = {src.data, weights.data};
+ const_mkldnn_primitive_t outputs[] = {dst.get()};
+ error::wrap_c_api(mkldnn_primitive_create(
+ &result, aprimitive_desc.get(), inputs, outputs),
+ "could not create a convolution relu forward primitive");
+ reset(result);
+ }
+};
+
+/// @}
+//
+/// @addtogroup cpp_api_deconvolution Deconvolution
+/// @{
+
+struct deconvolution_forward : public primitive {
+ struct desc {
+ mkldnn_deconvolution_desc_t data;
+ desc(prop_kind aprop_kind,
+ algorithm aalgorithm,
+ const memory::desc &src_desc,
+ const memory::desc &weights_desc,
+ const memory::desc &bias_desc,
+ const memory::desc &dst_desc,
+ const memory::dims strides,
+ const memory::dims padding_l,
+ const memory::dims padding_r,
+ const padding_kind apadding_kind) {
+ memory::validate_dims(strides);
+ memory::validate_dims(padding_l);
+ memory::validate_dims(padding_r);
+ error::wrap_c_api(mkldnn_deconvolution_forward_desc_init(
+ &data,
+ mkldnn::convert_to_c(aprop_kind),
+ convert_to_c(aalgorithm),
+ &src_desc.data,
+ &weights_desc.data,
+ &bias_desc.data,
+ &dst_desc.data,
+ &strides[0],
+ &padding_l[0],
+ &padding_r[0],
+ mkldnn::convert_to_c(apadding_kind)),
+ "could not create a deconvolution forward descriptor");
+ }
+ desc(prop_kind aprop_kind,
+ algorithm aalgorithm,
+ const memory::desc &src_desc,
+ const memory::desc &weights_desc,
+ const memory::desc &dst_desc,
+ const memory::dims strides,
+ const memory::dims padding_l,
+ const memory::dims padding_r,
+ const padding_kind apadding_kind) {
+ memory::validate_dims(strides);
+ memory::validate_dims(padding_l);
+ memory::validate_dims(padding_r);
+ error::wrap_c_api(mkldnn_deconvolution_forward_desc_init(
+ &data,
+ mkldnn::convert_to_c(aprop_kind),
+ convert_to_c(aalgorithm),
+ &src_desc.data,
+ &weights_desc.data,
+ nullptr,
+ &dst_desc.data,
+ &strides[0],
+ &padding_l[0],
+ &padding_r[0],
+ mkldnn::convert_to_c(apadding_kind)),
+ "could not create a deconvolution forward descriptor");
+ }
+ };
+ struct primitive_desc : public handle {
+ primitive_desc(const desc &adesc, const engine &aengine) {
+ mkldnn_primitive_desc_t result;
+ error::wrap_c_api(
+ mkldnn_primitive_desc_create(
+ &result, &adesc.data, aengine.get(), nullptr),
+ "could not create a deconvolution forward primitive descriptor");
+ reset(result);
+ }
+
+ primitive_desc(const desc &adesc,
+ const primitive_attr &aattr,
+ const engine &aengine) {
+ mkldnn_primitive_desc_t result;
+ error::wrap_c_api(
+ mkldnn_primitive_desc_create_v2(
+ &result, &adesc.data, aattr.get(), aengine.get(), nullptr),
+ "could not create a deconvolution forward primitive descriptor");
+ reset(result);
+ }
+
+ memory::primitive_desc src_primitive_desc() const {
+ memory::primitive_desc adesc;
+ mkldnn_primitive_desc_t cdesc;
+ const_mkldnn_primitive_desc_t const_cdesc =
+ mkldnn_primitive_desc_query_pd(
+ get(), mkldnn::convert_to_c(src_pd), 0);
+ error::wrap_c_api(mkldnn_primitive_desc_clone(&cdesc, const_cdesc),
+ "could not clone a src primititve descriptor");
+ adesc.reset(cdesc);
+ return adesc;
+ }
+
+ memory::primitive_desc weights_primitive_desc() const {
+ memory::primitive_desc adesc;
+ mkldnn_primitive_desc_t cdesc;
+ const_mkldnn_primitive_desc_t const_cdesc =
+ mkldnn_primitive_desc_query_pd(
+ get(), mkldnn::convert_to_c(weights_pd), 0);
+ error::wrap_c_api(mkldnn_primitive_desc_clone(&cdesc, const_cdesc),
+ "could not clone a weights primitive descriptor");
+ adesc.reset(cdesc);
+ return adesc;
+ }
+
+ memory::primitive_desc bias_primitive_desc() const {
+ memory::primitive_desc adesc;
+ mkldnn_primitive_desc_t cdesc;
+ const_mkldnn_primitive_desc_t const_cdesc =
+ mkldnn_primitive_desc_query_pd(
+ get(), mkldnn::convert_to_c(weights_pd), 1);
+ error::wrap_c_api(mkldnn_primitive_desc_clone(&cdesc, const_cdesc),
+ "could not clone a bias primitive descriptor");
+ adesc.reset(cdesc);
+ return adesc;
+ }
+
+ memory::primitive_desc dst_primitive_desc() const {
+ memory::primitive_desc adesc;
+ mkldnn_primitive_desc_t cdesc;
+ const_mkldnn_primitive_desc_t const_cdesc =
+ mkldnn_primitive_desc_query_pd(
+ get(), mkldnn::convert_to_c(dst_pd), 0);
+ error::wrap_c_api(mkldnn_primitive_desc_clone(&cdesc, const_cdesc),
+ "could not clone a dst primitive descriptor");
+ adesc.reset(cdesc);
+ return adesc;
+ }
+
+ engine get_engine() { return engine::query(*this); }
+ };
+
+ deconvolution_forward(const primitive_desc &aprimitive_desc,
+ const primitive::at &src,
+ const primitive::at &weights,
+ const primitive::at &bias,
+ const memory &dst) {
+ mkldnn_primitive_t result;
+ mkldnn_primitive_at_t inputs[] = {src.data, weights.data, bias.data};
+ const_mkldnn_primitive_t outputs[] = {dst.get()};
+ error::wrap_c_api(
+ mkldnn_primitive_create(
+ &result, aprimitive_desc.get(), inputs, outputs),
+ "could not create a deconvolution forward bias primitive");
+ reset(result);
+ }
+
+ deconvolution_forward(const primitive_desc &aprimitive_desc,
+ const primitive::at &src,
+ const primitive::at &weights,
+ const memory &dst) {
+ mkldnn_primitive_t result;
+ mkldnn_primitive_at_t inputs[] = {src.data, weights.data};
+ const_mkldnn_primitive_t outputs[] = {dst.get()};
+ error::wrap_c_api(mkldnn_primitive_create(
+ &result, aprimitive_desc.get(), inputs, outputs),
+ "could not create a deconvolution forward primitive");
+ reset(result);
+ }
+};
+
+struct deconvolution_backward_data : public primitive {
+ struct desc {
+ mkldnn_deconvolution_desc_t data;
+ desc(algorithm aalgorithm,
+ const memory::desc &diff_src_desc,
+ const memory::desc &weights_desc,
+ const memory::desc &diff_dst_desc,
+ const memory::dims strides,
+ const memory::dims padding_l,
+ const memory::dims padding_r,
+ const padding_kind apadding_kind) {
+ memory::validate_dims(strides);
+ memory::validate_dims(padding_l);
+ memory::validate_dims(padding_r);
+ error::wrap_c_api(
+ mkldnn_deconvolution_backward_data_desc_init(
+ &data,
+ convert_to_c(aalgorithm),
+ &diff_src_desc.data,
+ &weights_desc.data,
+ &diff_dst_desc.data,
+ &strides[0],
+ &padding_l[0],
+ &padding_r[0],
+ mkldnn::convert_to_c(apadding_kind)),
+ "could not create a deconvolution backward data descriptor");
+ }
+ };
+  // Primitive descriptor for the backward-data pass of deconvolution.
+  // Construction requires the matching forward primitive descriptor as an
+  // optimization hint. NOTE(review): base class restored to
+  // handle<mkldnn_primitive_desc_t> — the template argument was stripped
+  // during extraction; this matches the upstream MKL-DNN v0.14 header.
+  struct primitive_desc : public handle<mkldnn_primitive_desc_t> {
+    primitive_desc(
+        const desc &adesc,
+        const engine &aengine,
+        const deconvolution_forward::primitive_desc &hint_fwd_primitive_desc) {
+      mkldnn_primitive_desc_t result;
+      error::wrap_c_api(
+          mkldnn_primitive_desc_create(&result,
+                                       &adesc.data,
+                                       aengine.get(),
+                                       hint_fwd_primitive_desc.get()),
+          "could not create a deconvolution backward data primitive "
+          "descriptor");
+      reset(result);
+    }
+    // Clone of the diff_src memory primitive descriptor chosen by the
+    // implementation.
+    memory::primitive_desc diff_src_primitive_desc() const {
+      memory::primitive_desc adesc;
+      mkldnn_primitive_desc_t cdesc;
+      const_mkldnn_primitive_desc_t const_cdesc =
+          mkldnn_primitive_desc_query_pd(
+              get(), mkldnn::convert_to_c(diff_src_pd), 0);
+      error::wrap_c_api(mkldnn_primitive_desc_clone(&cdesc, const_cdesc),
+                        "could not clone a diff_src primititve descriptor");
+      adesc.reset(cdesc);
+      return adesc;
+    }
+
+    // Clone of the weights memory primitive descriptor.
+    memory::primitive_desc weights_primitive_desc() const {
+      memory::primitive_desc adesc;
+      mkldnn_primitive_desc_t cdesc;
+      const_mkldnn_primitive_desc_t const_cdesc =
+          mkldnn_primitive_desc_query_pd(
+              get(), mkldnn::convert_to_c(weights_pd), 0);
+      error::wrap_c_api(mkldnn_primitive_desc_clone(&cdesc, const_cdesc),
+                        "could not clone a weights primitive descriptor");
+      adesc.reset(cdesc);
+      return adesc;
+    }
+
+    // Clone of the diff_dst memory primitive descriptor.
+    memory::primitive_desc diff_dst_primitive_desc() const {
+      memory::primitive_desc adesc;
+      mkldnn_primitive_desc_t cdesc;
+      const_mkldnn_primitive_desc_t const_cdesc =
+          mkldnn_primitive_desc_query_pd(
+              get(), mkldnn::convert_to_c(diff_dst_pd), 0);
+      error::wrap_c_api(mkldnn_primitive_desc_clone(&cdesc, const_cdesc),
+                        "could not clone a diff_dst primitive descriptor");
+      adesc.reset(cdesc);
+      return adesc;
+    }
+
+    engine get_engine() { return engine::query(*this); }
+  };
+
+  // Creates the backward-data primitive: consumes diff_dst and weights,
+  // produces diff_src.
+  deconvolution_backward_data(const primitive_desc &aprimitive_desc,
+                              const primitive::at &diff_dst,
+                              const primitive::at &weights,
+                              const memory &diff_src) {
+    mkldnn_primitive_at_t in[] = {diff_dst.data, weights.data};
+    const_mkldnn_primitive_t out[] = {diff_src.get()};
+    mkldnn_primitive_t prim;
+    error::wrap_c_api(
+        mkldnn_primitive_create(&prim, aprimitive_desc.get(), in, out),
+        "could not create a deconvolution backward data primitive");
+    reset(prim);
+  }
+};
+
+// Backward-weights pass of deconvolution: given src and diff_dst, computes
+// diff_weights (and optionally diff_bias).
+struct deconvolution_backward_weights : public primitive {
+  struct desc {
+    mkldnn_deconvolution_desc_t data;
+    // With-bias overload: also describes a diff_bias output.
+    desc(algorithm aalgorithm,
+         const memory::desc &src_desc,
+         const memory::desc &diff_weights_desc,
+         const memory::desc &diff_bias_desc,
+         const memory::desc &diff_dst_desc,
+         const memory::dims strides,
+         const memory::dims padding_l,
+         const memory::dims padding_r,
+         const padding_kind apadding_kind) {
+      memory::validate_dims(strides);
+      memory::validate_dims(padding_l);
+      memory::validate_dims(padding_r);
+      error::wrap_c_api(
+          mkldnn_deconvolution_backward_weights_desc_init(
+              &data,
+              convert_to_c(aalgorithm),
+              &src_desc.data,
+              &diff_weights_desc.data,
+              &diff_bias_desc.data,
+              &diff_dst_desc.data,
+              &strides[0],
+              &padding_l[0],
+              &padding_r[0],
+              mkldnn::convert_to_c(apadding_kind)),
+          "could not create a deconvolution backward weights descriptor");
+    }
+    // Without-bias overload: passes nullptr for diff_bias to the C API.
+    desc(algorithm aalgorithm,
+         const memory::desc &src_desc,
+         const memory::desc &diff_weights_desc,
+         const memory::desc &diff_dst_desc,
+         const memory::dims strides,
+         const memory::dims padding_l,
+         const memory::dims padding_r,
+         const padding_kind apadding_kind) {
+      memory::validate_dims(strides);
+      memory::validate_dims(padding_l);
+      memory::validate_dims(padding_r);
+      error::wrap_c_api(
+          mkldnn_deconvolution_backward_weights_desc_init(
+              &data,
+              convert_to_c(aalgorithm),
+              &src_desc.data,
+              &diff_weights_desc.data,
+              nullptr,
+              &diff_dst_desc.data,
+              &strides[0],
+              &padding_l[0],
+              &padding_r[0],
+              mkldnn::convert_to_c(apadding_kind)),
+          "could not create a deconvolution backward weights descriptor");
+    }
+  };
+
+  // Primitive descriptor; requires the forward primitive descriptor as a
+  // hint. NOTE(review): base class restored to
+  // handle<mkldnn_primitive_desc_t> — the template argument was stripped
+  // during extraction; this matches the upstream MKL-DNN v0.14 header.
+  struct primitive_desc : public handle<mkldnn_primitive_desc_t> {
+    primitive_desc(
+        const desc &adesc,
+        const engine &aengine,
+        const deconvolution_forward::primitive_desc &hint_fwd_primitive_desc) {
+      mkldnn_primitive_desc_t result;
+      error::wrap_c_api(
+          mkldnn_primitive_desc_create(&result,
+                                       &adesc.data,
+                                       aengine.get(),
+                                       hint_fwd_primitive_desc.get()),
+          "could not create a deconvolution backward weights primitive "
+          "descriptor");
+      reset(result);
+    }
+    // Clone of the src memory primitive descriptor.
+    memory::primitive_desc src_primitive_desc() const {
+      memory::primitive_desc adesc;
+      mkldnn_primitive_desc_t cdesc;
+      const_mkldnn_primitive_desc_t const_cdesc =
+          mkldnn_primitive_desc_query_pd(
+              get(), mkldnn::convert_to_c(src_pd), 0);
+      error::wrap_c_api(mkldnn_primitive_desc_clone(&cdesc, const_cdesc),
+                        "could not clone a src primititve descriptor");
+      adesc.reset(cdesc);
+      return adesc;
+    }
+
+    // Clone of the diff_weights memory primitive descriptor (index 0).
+    memory::primitive_desc diff_weights_primitive_desc() const {
+      memory::primitive_desc adesc;
+      mkldnn_primitive_desc_t cdesc;
+      const_mkldnn_primitive_desc_t const_cdesc =
+          mkldnn_primitive_desc_query_pd(
+              get(), mkldnn::convert_to_c(diff_weights_pd), 0);
+      error::wrap_c_api(mkldnn_primitive_desc_clone(&cdesc, const_cdesc),
+                        "could not clone a diff_weights primitive descriptor");
+      adesc.reset(cdesc);
+      return adesc;
+    }
+
+    // Clone of the diff_bias memory primitive descriptor (index 1 of the
+    // diff_weights query).
+    memory::primitive_desc diff_bias_primitive_desc() const {
+      memory::primitive_desc adesc;
+      mkldnn_primitive_desc_t cdesc;
+      const_mkldnn_primitive_desc_t const_cdesc =
+          mkldnn_primitive_desc_query_pd(
+              get(), mkldnn::convert_to_c(diff_weights_pd), 1);
+      error::wrap_c_api(mkldnn_primitive_desc_clone(&cdesc, const_cdesc),
+                        "could not clone a diff_bias primitive descriptor");
+      adesc.reset(cdesc);
+      return adesc;
+    }
+
+    // Clone of the diff_dst memory primitive descriptor.
+    memory::primitive_desc diff_dst_primitive_desc() const {
+      memory::primitive_desc adesc;
+      mkldnn_primitive_desc_t cdesc;
+      const_mkldnn_primitive_desc_t const_cdesc =
+          mkldnn_primitive_desc_query_pd(
+              get(), mkldnn::convert_to_c(diff_dst_pd), 0);
+      error::wrap_c_api(mkldnn_primitive_desc_clone(&cdesc, const_cdesc),
+                        "could not clone a diff_dst primitive descriptor");
+      adesc.reset(cdesc);
+      return adesc;
+    }
+
+    engine get_engine() { return engine::query(*this); }
+  };
+
+  // With-bias primitive: inputs {src, diff_dst}, outputs
+  // {diff_weights, diff_bias}.
+  deconvolution_backward_weights(const primitive_desc &aprimitive_desc,
+                                 const primitive::at &src,
+                                 const primitive::at &diff_dst,
+                                 const memory &diff_weights,
+                                 const memory &diff_bias) {
+    mkldnn_primitive_t result;
+    mkldnn_primitive_at_t inputs[] = {src.data, diff_dst.data};
+    const_mkldnn_primitive_t outputs[] = {diff_weights.get(), diff_bias.get()};
+    error::wrap_c_api(
+        mkldnn_primitive_create(
+            &result, aprimitive_desc.get(), inputs, outputs),
+        "could not create a deconvolution backward weights primitive");
+    reset(result);
+  }
+  // Without-bias primitive: only diff_weights is produced.
+  deconvolution_backward_weights(const primitive_desc &aprimitive_desc,
+                                 const primitive::at &src,
+                                 const primitive::at &diff_dst,
+                                 const memory &diff_weights) {
+    mkldnn_primitive_t result;
+    mkldnn_primitive_at_t inputs[] = {src.data, diff_dst.data};
+    const_mkldnn_primitive_t outputs[] = {diff_weights.get()};
+    error::wrap_c_api(
+        mkldnn_primitive_create(
+            &result, aprimitive_desc.get(), inputs, outputs),
+        "could not create a deconvolution backward weights primitive");
+    reset(result);
+  }
+};
+
+/// @}
+
+/// @addtogroup cpp_api_lrn LRN
+/// @{
+
+// Forward local response normalization (LRN).
+struct lrn_forward : public primitive {
+  struct desc {
+    mkldnn_lrn_desc_t data;
+    desc(prop_kind aprop_kind,
+         algorithm aalgorithm,
+         const memory::desc &src_desc,
+         int local_size,
+         float alpha,
+         float beta,
+         float k) {
+      error::wrap_c_api(
+          mkldnn_lrn_forward_desc_init(&data,
+                                       mkldnn::convert_to_c(aprop_kind),
+                                       convert_to_c(aalgorithm),
+                                       &src_desc.data,
+                                       local_size,
+                                       alpha,
+                                       beta,
+                                       k),
+          "could not create a lrn forward descriptor");
+    }
+    // Overload defaulting the k parameter to 1.0.
+    desc(prop_kind aprop_kind,
+         algorithm aalgorithm,
+         const memory::desc &src_desc,
+         int local_size,
+         float alpha,
+         float beta) {
+      error::wrap_c_api(
+          mkldnn_lrn_forward_desc_init(&data,
+                                       mkldnn::convert_to_c(aprop_kind),
+                                       convert_to_c(aalgorithm),
+                                       &src_desc.data,
+                                       local_size,
+                                       alpha,
+                                       beta,
+                                       float(1.0)),
+          "could not create a lrn forward descriptor");
+    }
+  };
+
+  // NOTE(review): base class restored to handle<mkldnn_primitive_desc_t> —
+  // the template argument was stripped during extraction; this matches the
+  // upstream MKL-DNN v0.14 header.
+  struct primitive_desc : public handle<mkldnn_primitive_desc_t> {
+    primitive_desc(const desc &adesc, const engine &aengine) {
+      mkldnn_primitive_desc_t result;
+      error::wrap_c_api(mkldnn_primitive_desc_create(
+                            &result, &adesc.data, aengine.get(), nullptr),
+                        "could not create a lrn forward primitive descriptor");
+      reset(result);
+    }
+
+    // Clone of the src memory primitive descriptor.
+    memory::primitive_desc src_primitive_desc() const {
+      memory::primitive_desc adesc;
+      mkldnn_primitive_desc_t cdesc;
+      const_mkldnn_primitive_desc_t const_cdesc =
+          mkldnn_primitive_desc_query_pd(
+              get(), mkldnn::convert_to_c(src_pd), 0);
+      error::wrap_c_api(mkldnn_primitive_desc_clone(&cdesc, const_cdesc),
+                        "could not clone a src primitive descriptor");
+      adesc.reset(cdesc);
+      return adesc;
+    }
+
+    // Clone of the workspace memory primitive descriptor (present when the
+    // chosen propagation kind needs one for the backward pass).
+    memory::primitive_desc workspace_primitive_desc() const {
+      memory::primitive_desc adesc;
+      mkldnn_primitive_desc_t ldesc;
+      const_mkldnn_primitive_desc_t const_ldesc =
+          mkldnn_primitive_desc_query_pd(
+              get(), mkldnn::convert_to_c(workspace_pd), 0);
+      error::wrap_c_api(mkldnn_primitive_desc_clone(&ldesc, const_ldesc),
+                        "could not clone a workspace primitive descriptor");
+      adesc.reset(ldesc);
+      return adesc;
+    }
+
+    // Clone of the dst memory primitive descriptor.
+    memory::primitive_desc dst_primitive_desc() const {
+      memory::primitive_desc adesc;
+      mkldnn_primitive_desc_t cdesc;
+      const_mkldnn_primitive_desc_t const_cdesc =
+          mkldnn_primitive_desc_query_pd(
+              get(), mkldnn::convert_to_c(dst_pd), 0);
+      error::wrap_c_api(mkldnn_primitive_desc_clone(&cdesc, const_cdesc),
+                        "could not clone a dst primitive descriptor");
+      adesc.reset(cdesc);
+      return adesc;
+    }
+
+    engine get_engine() { return engine::query(*this); }
+  };
+
+  // With-workspace variant: outputs {dst, workspace}.
+  lrn_forward(const primitive_desc &aprimitive_desc,
+              const primitive::at &src,
+              const memory &workspace,
+              const memory &dst) {
+    mkldnn_primitive_t result;
+    mkldnn_primitive_at_t inputs[] = {src.data};
+    const_mkldnn_primitive_t outputs[] = {dst.get(), workspace.get()};
+    error::wrap_c_api(mkldnn_primitive_create(
+                          &result, aprimitive_desc.get(), inputs, outputs),
+                      "could not create a lrn forward primitive");
+    reset(result);
+  }
+
+  // Without-workspace variant: outputs {dst} only.
+  lrn_forward(const primitive_desc &aprimitive_desc,
+              const primitive::at &src,
+              const memory &dst) {
+    mkldnn_primitive_t result;
+    mkldnn_primitive_at_t inputs[] = {src.data};
+    const_mkldnn_primitive_t outputs[] = {dst.get()};
+    error::wrap_c_api(mkldnn_primitive_create(
+                          &result, aprimitive_desc.get(), inputs, outputs),
+                      "could not create a lrn forward primitive");
+    reset(result);
+  }
+};
+
+// Backward local response normalization (LRN).
+struct lrn_backward : public primitive {
+  struct desc {
+    mkldnn_lrn_desc_t data;
+    desc(algorithm aalgorithm,
+         const memory::desc &data_desc,
+         const memory::desc &diff_data_desc,
+         int local_size,
+         float alpha,
+         float beta,
+         float k) {
+      error::wrap_c_api(mkldnn_lrn_backward_desc_init(&data,
+                                                      convert_to_c(aalgorithm),
+                                                      &diff_data_desc.data,
+                                                      &data_desc.data,
+                                                      local_size,
+                                                      alpha,
+                                                      beta,
+                                                      k),
+                        "could not create a lrn backward descriptor");
+    }
+    // Overload defaulting the k parameter to 1.0.
+    desc(algorithm aalgorithm,
+         const memory::desc &data_desc,
+         const memory::desc &diff_data_desc,
+         int local_size,
+         float alpha,
+         float beta) {
+      error::wrap_c_api(mkldnn_lrn_backward_desc_init(&data,
+                                                      convert_to_c(aalgorithm),
+                                                      &diff_data_desc.data,
+                                                      &data_desc.data,
+                                                      local_size,
+                                                      alpha,
+                                                      beta,
+                                                      float(1.0)),
+                        "could not create a lrn backward descriptor");
+    }
+  };
+
+  // Primitive descriptor; requires the forward primitive descriptor as a
+  // hint. NOTE(review): base class restored to
+  // handle<mkldnn_primitive_desc_t> — the template argument was stripped
+  // during extraction; this matches the upstream MKL-DNN v0.14 header.
+  struct primitive_desc : public handle<mkldnn_primitive_desc_t> {
+    primitive_desc(const desc &adesc,
+                   const engine &aengine,
+                   const lrn_forward::primitive_desc &hint_fwd_primitive_desc) {
+      mkldnn_primitive_desc_t result;
+      error::wrap_c_api(
+          mkldnn_primitive_desc_create(&result,
+                                       &adesc.data,
+                                       aengine.get(),
+                                       hint_fwd_primitive_desc.get()),
+          "could not create a backward lrn primitive descriptor");
+      reset(result);
+    }
+
+    // Clone of the diff_src memory primitive descriptor.
+    memory::primitive_desc diff_src_primitive_desc() const {
+      memory::primitive_desc adesc;
+      mkldnn_primitive_desc_t cdesc;
+      const_mkldnn_primitive_desc_t const_cdesc =
+          mkldnn_primitive_desc_query_pd(
+              get(), mkldnn::convert_to_c(diff_src_pd), 0);
+      error::wrap_c_api(mkldnn_primitive_desc_clone(&cdesc, const_cdesc),
+                        "could not clone a diff_src primitive descriptor");
+      adesc.reset(cdesc);
+      return adesc;
+    }
+
+    // Clone of the workspace memory primitive descriptor.
+    memory::primitive_desc workspace_primitive_desc() const {
+      memory::primitive_desc adesc;
+      mkldnn_primitive_desc_t ldesc;
+      const_mkldnn_primitive_desc_t const_ldesc =
+          mkldnn_primitive_desc_query_pd(
+              get(), mkldnn::convert_to_c(workspace_pd), 0);
+      error::wrap_c_api(mkldnn_primitive_desc_clone(&ldesc, const_ldesc),
+                        "could not clone a workspace primitive descriptor");
+      adesc.reset(ldesc);
+      return adesc;
+    }
+
+    // Clone of the diff_dst memory primitive descriptor.
+    memory::primitive_desc diff_dst_primitive_desc() const {
+      memory::primitive_desc adesc;
+      mkldnn_primitive_desc_t cdesc;
+      const_mkldnn_primitive_desc_t const_cdesc =
+          mkldnn_primitive_desc_query_pd(
+              get(), mkldnn::convert_to_c(diff_dst_pd), 0);
+      error::wrap_c_api(mkldnn_primitive_desc_clone(&cdesc, const_cdesc),
+                        "could not clone a diff_dst primitive descriptor");
+      adesc.reset(cdesc);
+      return adesc;
+    }
+
+    engine get_engine() { return engine::query(*this); }
+  };
+
+  // With-workspace variant: inputs {src, diff_dst, workspace}.
+  lrn_backward(const primitive_desc &aprimitive_desc,
+               const primitive::at &src,
+               const primitive::at &diff_dst,
+               const primitive::at &workspace,
+               const memory &diff_src) {
+    mkldnn_primitive_t result;
+    mkldnn_primitive_at_t inputs[] = {src.data, diff_dst.data, workspace.data};
+    const_mkldnn_primitive_t outputs[] = {diff_src.get()};
+    error::wrap_c_api(mkldnn_primitive_create(
+                          &result, aprimitive_desc.get(), inputs, outputs),
+                      "could not create a lrn backward primitive");
+    reset(result);
+  }
+
+  // Without-workspace variant: inputs {src, diff_dst}.
+  lrn_backward(const primitive_desc &aprimitive_desc,
+               const primitive::at &src,
+               const primitive::at &diff_dst,
+               const memory &diff_src) {
+    mkldnn_primitive_t result;
+    mkldnn_primitive_at_t inputs[] = {src.data, diff_dst.data};
+    const_mkldnn_primitive_t outputs[] = {diff_src.get()};
+    error::wrap_c_api(mkldnn_primitive_create(
+                          &result, aprimitive_desc.get(), inputs, outputs),
+                      "could not create a lrn backward primitive");
+    reset(result);
+  }
+};
+
+/// @}
+
+/// @addtogroup cpp_api_pooling Pooling
+/// @{
+
+// Forward pooling (max/avg per the chosen algorithm).
+struct pooling_forward : public primitive {
+  struct desc {
+    mkldnn_pooling_desc_t data;
+    desc(prop_kind aprop_kind,
+         algorithm aalgorithm,
+         const memory::desc &src_desc,
+         const memory::desc &dst_desc,
+         const memory::dims strides,
+         const memory::dims kernel,
+         const memory::dims padding_l,
+         const memory::dims padding_r,
+         const padding_kind apadding_kind) {
+      memory::validate_dims(strides);
+      memory::validate_dims(kernel);
+      memory::validate_dims(padding_l);
+      memory::validate_dims(padding_r);
+      error::wrap_c_api(
+          mkldnn_pooling_forward_desc_init(&data,
+                                           mkldnn::convert_to_c(aprop_kind),
+                                           convert_to_c(aalgorithm),
+                                           &src_desc.data,
+                                           &dst_desc.data,
+                                           &strides[0],
+                                           &kernel[0],
+                                           &padding_l[0],
+                                           &padding_r[0],
+                                           mkldnn::convert_to_c(apadding_kind)),
+          "could not init a forward pooling descriptor");
+    }
+  };
+
+  // NOTE(review): base class restored to handle<mkldnn_primitive_desc_t> —
+  // the template argument was stripped during extraction; this matches the
+  // upstream MKL-DNN v0.14 header.
+  struct primitive_desc : public handle<mkldnn_primitive_desc_t> {
+    primitive_desc(const desc &adesc, const engine &aengine) {
+      mkldnn_primitive_desc_t result;
+      error::wrap_c_api(
+          mkldnn_primitive_desc_create(
+              &result, &adesc.data, aengine.get(), nullptr),
+          "could not create a forward pooling primitive descriptor");
+      reset(result);
+    }
+
+    // Clone of the workspace memory primitive descriptor (used by max
+    // pooling in training to record argmax indices).
+    memory::primitive_desc workspace_primitive_desc() const {
+      memory::primitive_desc adesc;
+      mkldnn_primitive_desc_t cdesc;
+      const_mkldnn_primitive_desc_t const_cdesc =
+          mkldnn_primitive_desc_query_pd(
+              get(), mkldnn::convert_to_c(workspace_pd), 0);
+      error::wrap_c_api(mkldnn_primitive_desc_clone(&cdesc, const_cdesc),
+                        "could not clone a workspace primititve descriptor");
+      adesc.reset(cdesc);
+      return adesc;
+    }
+
+    // Clone of the dst memory primitive descriptor.
+    memory::primitive_desc dst_primitive_desc() const {
+      memory::primitive_desc adesc;
+      mkldnn_primitive_desc_t cdesc;
+      const_mkldnn_primitive_desc_t const_cdesc =
+          mkldnn_primitive_desc_query_pd(
+              get(), mkldnn::convert_to_c(dst_pd), 0);
+      error::wrap_c_api(mkldnn_primitive_desc_clone(&cdesc, const_cdesc),
+                        "could not clone a dst primitive descriptor");
+      adesc.reset(cdesc);
+      return adesc;
+    }
+
+    engine get_engine() { return engine::query(*this); }
+  };
+
+  // Without-workspace variant; the workspace output slot is nullptr.
+  pooling_forward(const primitive_desc &aprimitive_desc,
+                  const primitive::at &src,
+                  const memory &dst) {
+    mkldnn_primitive_t result;
+    mkldnn_primitive_at_t inputs[] = {src.data};
+    const_mkldnn_primitive_t outputs[] = {dst.get(), nullptr};
+    error::wrap_c_api(mkldnn_primitive_create(
+                          &result, aprimitive_desc.get(), inputs, outputs),
+                      "could not create a pooling forward primitive");
+    reset(result);
+  }
+
+  // With-workspace variant: outputs {dst, workspace}.
+  pooling_forward(const primitive_desc &aprimitive_desc,
+                  const primitive::at &src,
+                  const memory &dst,
+                  const memory &workspace) {
+    mkldnn_primitive_t result;
+    mkldnn_primitive_at_t inputs[] = {src.data};
+    const_mkldnn_primitive_t outputs[] = {dst.get(), workspace.get()};
+    error::wrap_c_api(mkldnn_primitive_create(
+                          &result, aprimitive_desc.get(), inputs, outputs),
+                      "could not create a pooling forward primitive");
+    reset(result);
+  }
+};
+
+// Backward pooling: computes diff_src from diff_dst (plus the forward
+// workspace for max pooling).
+struct pooling_backward : public primitive {
+  struct desc {
+    mkldnn_pooling_desc_t data;
+    desc(algorithm aalgorithm,
+         const memory::desc &diff_src_desc,
+         const memory::desc &diff_dst_desc,
+         const memory::dims &strides,
+         const memory::dims &kernel,
+         const memory::dims &padding_l,
+         const memory::dims &padding_r,
+         const padding_kind apadding_kind) {
+      memory::validate_dims(strides);
+      memory::validate_dims(kernel);
+      memory::validate_dims(padding_l);
+      memory::validate_dims(padding_r);
+      error::wrap_c_api(mkldnn_pooling_backward_desc_init(
+                            &data,
+                            convert_to_c(aalgorithm),
+                            &diff_src_desc.data,
+                            &diff_dst_desc.data,
+                            &strides[0],
+                            &kernel[0],
+                            &padding_l[0],
+                            &padding_r[0],
+                            mkldnn::convert_to_c(apadding_kind)),
+                        "could not init a backward pooling descriptor");
+    }
+  };
+
+  // Primitive descriptor; requires the forward primitive descriptor as a
+  // hint. NOTE(review): base class restored to
+  // handle<mkldnn_primitive_desc_t> — the template argument was stripped
+  // during extraction; this matches the upstream MKL-DNN v0.14 header.
+  struct primitive_desc : public handle<mkldnn_primitive_desc_t> {
+    primitive_desc(
+        const desc &adesc,
+        const engine &aengine,
+        const pooling_forward::primitive_desc &hint_fwd_primitive_desc) {
+      mkldnn_primitive_desc_t result;
+      error::wrap_c_api(
+          mkldnn_primitive_desc_create(&result,
+                                       &adesc.data,
+                                       aengine.get(),
+                                       hint_fwd_primitive_desc.get()),
+          "could not create a backward pooling primitive descriptor");
+      reset(result);
+    }
+
+    // Clone of the diff_src memory primitive descriptor.
+    memory::primitive_desc diff_src_primitive_desc() const {
+      memory::primitive_desc adesc;
+      mkldnn_primitive_desc_t cdesc;
+      const_mkldnn_primitive_desc_t const_cdesc =
+          mkldnn_primitive_desc_query_pd(
+              get(), mkldnn::convert_to_c(diff_src_pd), 0);
+      error::wrap_c_api(mkldnn_primitive_desc_clone(&cdesc, const_cdesc),
+                        "could not clone a diff src primitive descriptor");
+      adesc.reset(cdesc);
+      return adesc;
+    }
+
+    engine get_engine() { return engine::query(*this); }
+  };
+
+  // Without-workspace variant: input {diff_dst}.
+  pooling_backward(const primitive_desc &aprimitive_desc,
+                   const primitive::at &diff_dst,
+                   const memory &diff_src) {
+    mkldnn_primitive_t result;
+    mkldnn_primitive_at_t inputs[] = {diff_dst.data};
+    const_mkldnn_primitive_t outputs[] = {diff_src.get()};
+    error::wrap_c_api(mkldnn_primitive_create(
+                          &result, aprimitive_desc.get(), inputs, outputs),
+                      "could not create a pooling backward primitive");
+    reset(result);
+  }
+
+  // With-workspace variant: inputs {diff_dst, workspace}.
+  pooling_backward(const primitive_desc &aprimitive_desc,
+                   const primitive::at &diff_dst,
+                   const primitive::at &workspace,
+                   const memory &diff_src) {
+    mkldnn_primitive_t result;
+    mkldnn_primitive_at_t inputs[] = {diff_dst.data, workspace.data};
+    const_mkldnn_primitive_t outputs[] = {diff_src.get()};
+    error::wrap_c_api(mkldnn_primitive_create(
+                          &result, aprimitive_desc.get(), inputs, outputs),
+                      "could not create a pooling backward primitive");
+    reset(result);
+  }
+};
+
+/// @}
+
+/// @addtogroup cpp_api_eltwise Eltwise
+/// @{
+
+struct eltwise_forward : public primitive {
+  // Operation descriptor for an element-wise (eltwise) forward primitive.
+  // alpha/beta parameterize the chosen algorithm; the C API takes them as
+  // float. NOTE(review): "template <typename T>" and the
+  // static_cast<float> target type were stripped during extraction and are
+  // restored here to match the upstream MKL-DNN v0.14 header.
+  struct desc {
+    mkldnn_eltwise_desc_t data;
+    template <typename T>
+    desc(prop_kind aprop_kind,
+         algorithm alg_kind,
+         const memory::desc &src_desc,
+         T alpha = 0,
+         T beta = 0) {
+      error::wrap_c_api(
+          mkldnn_eltwise_forward_desc_init(&data,
+                                           mkldnn::convert_to_c(aprop_kind),
+                                           mkldnn::convert_to_c(alg_kind),
+                                           &src_desc.data,
+                                           static_cast<float>(alpha),
+                                           static_cast<float>(beta)),
+          "could not create a eltwise forward descriptor");
+    }
+
+    /** @deprecated: api backward compatibility for relu */
+    template <typename T>
+    MKLDNN_DEPRECATED desc(prop_kind aprop_kind,
+                           const memory::desc &src_desc,
+                           T negative_slope)
+        : desc(aprop_kind, eltwise_relu, src_desc, negative_slope) {}
+  };
+
+ struct primitive_desc : public handle